Skip to content

Commit eb2fd8e

Browse files
author
Vincent Moens
committed
[BugFix] Ensure that maybe_dense_stack preserves the TC type
ghstack-source-id: 8972977 Pull Request resolved: #1252
1 parent 9738142 commit eb2fd8e

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

tensordict/_torch_func.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ def _stack(
464464
return type(list_of_tensordicts[0])._stack_non_tensor(
465465
list_of_tensordicts, dim=dim
466466
)
467+
list_of_tensordicts_orig = list_of_tensordicts
467468
if is_tc:
468469
list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts]
469470
clz = type(list_of_tensordicts[0])
@@ -503,7 +504,7 @@ def _stack(
503504
if maybe_dense_stack:
504505
with set_lazy_legacy(True):
505506
return _stack(
506-
list_of_tensordicts,
507+
list_of_tensordicts_orig,
507508
dim=dim,
508509
maybe_dense_stack=maybe_dense_stack,
509510
)
@@ -537,7 +538,7 @@ def _stack(
537538
lazy_stack_dim += 1
538539
else:
539540
dim = dim - 1
540-
return LazyStackedTensorDict(
541+
result = LazyStackedTensorDict(
541542
*[
542543
_stack(
543544
list(subtds),
@@ -550,12 +551,16 @@ def _stack(
550551
],
551552
stack_dim=lazy_stack_dim,
552553
)
554+
if is_tc:
555+
return clz._from_tensordict(result)
556+
return result
557+
553558
lazy_stack_dim = list_of_tensordicts[0].stack_dim
554559
if dim <= lazy_stack_dim:
555560
lazy_stack_dim += 1
556561
else:
557562
dim = dim - 1
558-
return LazyStackedTensorDict(
563+
result = LazyStackedTensorDict(
559564
*[
560565
_stack(list_of_td, dim, maybe_dense_stack=maybe_dense_stack)
561566
for list_of_td in _zip_strict(
@@ -564,6 +569,9 @@ def _stack(
564569
],
565570
stack_dim=lazy_stack_dim,
566571
)
572+
if is_tc:
573+
return clz._from_tensordict(result)
574+
return result
567575

568576
out = {}
569577
for key in keys:
@@ -594,7 +602,7 @@ def _stack(
594602
if maybe_dense_stack:
595603
with set_lazy_legacy(True):
596604
return _stack(
597-
list_of_tensordicts,
605+
list_of_tensordicts_orig,
598606
dim=dim,
599607
maybe_dense_stack=maybe_dense_stack,
600608
)
@@ -641,6 +649,9 @@ def stack_fn(key, values, is_not_init, is_tensor):
641649
*list_of_tensordicts,
642650
stack_dim=dim,
643651
)
652+
if is_tc:
653+
return td_types[0]._from_tensordict(out)
654+
return out
644655
else:
645656
keys = _check_keys(list_of_tensordicts)
646657
batch_size = list(batch_size)

test/test_tensorclass.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,7 +1889,8 @@ class MyDataNested:
18891889
assert squeeze_tc.y.X.shape == torch.Size([4, 5])
18901890
assert squeeze_tc.z == squeeze_tc.y.z == z
18911891

1892-
@pytest.mark.parametrize("lazy", [True, False])
1892+
@set_capture_non_tensor_stack(False)
1893+
@pytest.mark.parametrize("lazy", [True, False, "maybe"])
18931894
def test_stack(self, lazy):
18941895
@tensorclass
18951896
class MyDataNested:
@@ -1898,23 +1899,40 @@ class MyDataNested:
18981899
y: "MyDataNested" = None
18991900

19001901
X = torch.ones(3, 4, 5)
1902+
if lazy:
1903+
Xb = torch.randn(3, 4, 4)
1904+
else:
1905+
Xb = X.clone()
19011906
z = "test_tensorclass"
19021907
batch_size = [3, 4]
19031908
data_nest = MyDataNested(X=X, z=z, batch_size=batch_size)
1909+
data_nest_b = MyDataNested(X=Xb, z=z, batch_size=batch_size)
19041910
data1 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
1905-
data2 = MyDataNested(X=X, y=data_nest, z=z, batch_size=batch_size)
1911+
data2 = MyDataNested(X=Xb, y=data_nest_b, z=z, batch_size=batch_size)
19061912

1907-
if lazy:
1913+
if lazy is True:
19081914
stacked_tc = LazyStackedTensorDict.lazy_stack([data1, data2], 0)
1915+
elif lazy == "maybe":
1916+
stacked_tc = LazyStackedTensorDict.maybe_dense_stack([data1, data2], 0)
19091917
else:
19101918
with set_capture_non_tensor_stack(True):
19111919
stacked_tc = torch.stack([data1, data2], 0)
19121920
assert type(stacked_tc) is type(data1)
19131921
assert isinstance(stacked_tc.y, type(data1.y))
1914-
assert stacked_tc.X.shape == torch.Size([2, 3, 4, 5])
1915-
assert stacked_tc.y.X.shape == torch.Size([2, 3, 4, 5])
1916-
assert (stacked_tc.X == 1).all()
1917-
assert (stacked_tc.y.X == 1).all()
1922+
if not lazy:
1923+
assert stacked_tc.X.shape == torch.Size([2, 3, 4, 5])
1924+
assert stacked_tc.y.X.shape == torch.Size([2, 3, 4, 5])
1925+
1926+
assert (stacked_tc.X == 1).all()
1927+
assert (stacked_tc.y.X == 1).all()
1928+
else:
1929+
assert stacked_tc[0].X.shape == torch.Size([3, 4, 5])
1930+
assert stacked_tc[0].y.X.shape == torch.Size([3, 4, 5])
1931+
assert stacked_tc[1].X.shape == torch.Size([3, 4, 4])
1932+
assert stacked_tc[1].y.X.shape == torch.Size([3, 4, 4])
1933+
assert (stacked_tc[0].X == 1).all()
1934+
assert (stacked_tc[0].y.X == 1).all()
1935+
19181936
if lazy_legacy() or lazy:
19191937
assert isinstance(stacked_tc._tensordict, LazyStackedTensorDict)
19201938
assert isinstance(stacked_tc.y._tensordict, LazyStackedTensorDict)

0 commit comments

Comments
 (0)