Skip to content

Commit a82bec8

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent d9fc3c0 commit a82bec8

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

tensordict/_lazy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1538,7 +1538,9 @@ def densify(self, *, layout: torch.layout = torch.strided):
15381538
list_of_entries, layout=layout
15391539
)
15401540
else:
1541-
raise NotImplementedError
1541+
raise NotImplementedError(
1542+
f"stack_dim is {self.stack_dim} but not 0. Densify canot be done."
1543+
)
15421544
else:
15431545
tensor = self._get_str(key, None)
15441546
if tensor is not None:

test/test_tensorclass.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
import sys
1717
import weakref
18+
from dataclasses import field
1819
from multiprocessing import Pool
1920
from pathlib import Path
2021
from tempfile import TemporaryDirectory
@@ -744,6 +745,17 @@ class MyClass2:
744745
assert (a != c.clone().zero_()).any()
745746
assert (c != a.clone().zero_()).any()
746747

748+
def test_field(self):
749+
class Cls(TensorClass):
750+
a: torch.Tensor
751+
b: str
752+
c: dict = field(default_factory=dict)
753+
754+
obj = Cls(a=torch.arange(3), b="abc", batch_size=[3])
755+
assert obj[0].a == obj[1].a - 1
756+
assert obj[0].b == obj[1].b
757+
assert obj[0].c is obj[1].c
758+
747759
def test_from_dataclass(self):
748760
assert is_tensorclass(MyTensorClass_autocast)
749761
assert MyTensorClass_nocast is not MyDataClass

0 commit comments

Comments
 (0)