Skip to content

Commit 2e3c0a9

Browse files
committed
amend
1 parent df9b803 commit 2e3c0a9

File tree

3 files changed

+841
-99
lines changed

3 files changed

+841
-99
lines changed

tensordict/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4771,6 +4771,11 @@ def is_cuda(self):
47714771
def is_cpu(self):
47724772
return self.device is not None and self.device.type == "cpu"
47734773

4774+
@property
4775+
def is_meta(self):
4776+
dtype = self.dtype
4777+
return dtype is not None and self.dtype == torch.meta
4778+
47744779
# Serialization functionality
47754780
def state_dict(
47764781
self,

tensordict/tensorclass.pyi

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ class TensorClass:
8080
lock: bool = False,
8181
**kwargs,
8282
) -> None: ...
83-
@property
84-
def is_meta(self) -> bool: ...
8583
def __bool__(self) -> bool: ...
8684
def __ne__(self, other: object) -> T: ...
8785
def __xor__(self, other: TensorDictBase | float): ...
@@ -459,9 +457,11 @@ class TensorClass:
459457
def cpu(self, **kwargs) -> T: ...
460458
def cuda(self, device: int | None = None, **kwargs) -> T: ...
461459
@property
462-
def is_cuda(self): ...
460+
def is_cuda(self) -> bool: ...
463461
@property
464-
def is_cpu(self): ...
462+
def is_cpu(self) -> bool: ...
463+
@property
464+
def is_meta(self) -> bool: ...
465465
def state_dict(
466466
self,
467467
destination: Incomplete | None = None,

0 commit comments

Comments
 (0)