@@ -1889,7 +1889,8 @@ class MyDataNested:
18891889 assert squeeze_tc .y .X .shape == torch .Size ([4 , 5 ])
18901890 assert squeeze_tc .z == squeeze_tc .y .z == z
18911891
1892- @pytest .mark .parametrize ("lazy" , [True , False ])
1892+ @set_capture_non_tensor_stack (False )
1893+ @pytest .mark .parametrize ("lazy" , [True , False , "maybe" ])
18931894 def test_stack (self , lazy ):
18941895 @tensorclass
18951896 class MyDataNested :
@@ -1898,23 +1899,40 @@ class MyDataNested:
18981899 y : "MyDataNested" = None
18991900
19001901 X = torch .ones (3 , 4 , 5 )
1902+ if lazy :
1903+ Xb = torch .randn (3 , 4 , 4 )
1904+ else :
1905+ Xb = X .clone ()
19011906 z = "test_tensorclass"
19021907 batch_size = [3 , 4 ]
19031908 data_nest = MyDataNested (X = X , z = z , batch_size = batch_size )
1909+ data_nest_b = MyDataNested (X = Xb , z = z , batch_size = batch_size )
19041910 data1 = MyDataNested (X = X , y = data_nest , z = z , batch_size = batch_size )
1905- data2 = MyDataNested (X = X , y = data_nest , z = z , batch_size = batch_size )
1911+ data2 = MyDataNested (X = Xb , y = data_nest_b , z = z , batch_size = batch_size )
19061912
1907- if lazy :
1913+ if lazy is True :
19081914 stacked_tc = LazyStackedTensorDict .lazy_stack ([data1 , data2 ], 0 )
1915+ elif lazy == "maybe" :
1916+ stacked_tc = LazyStackedTensorDict .maybe_dense_stack ([data1 , data2 ], 0 )
19091917 else :
19101918 with set_capture_non_tensor_stack (True ):
19111919 stacked_tc = torch .stack ([data1 , data2 ], 0 )
19121920 assert type (stacked_tc ) is type (data1 )
19131921 assert isinstance (stacked_tc .y , type (data1 .y ))
1914- assert stacked_tc .X .shape == torch .Size ([2 , 3 , 4 , 5 ])
1915- assert stacked_tc .y .X .shape == torch .Size ([2 , 3 , 4 , 5 ])
1916- assert (stacked_tc .X == 1 ).all ()
1917- assert (stacked_tc .y .X == 1 ).all ()
1922+ if not lazy :
1923+ assert stacked_tc .X .shape == torch .Size ([2 , 3 , 4 , 5 ])
1924+ assert stacked_tc .y .X .shape == torch .Size ([2 , 3 , 4 , 5 ])
1925+
1926+ assert (stacked_tc .X == 1 ).all ()
1927+ assert (stacked_tc .y .X == 1 ).all ()
1928+ else :
1929+ assert stacked_tc [0 ].X .shape == torch .Size ([3 , 4 , 5 ])
1930+ assert stacked_tc [0 ].y .X .shape == torch .Size ([3 , 4 , 5 ])
1931+ assert stacked_tc [1 ].X .shape == torch .Size ([3 , 4 , 4 ])
1932+ assert stacked_tc [1 ].y .X .shape == torch .Size ([3 , 4 , 4 ])
1933+ assert (stacked_tc [0 ].X == 1 ).all ()
1934+ assert (stacked_tc [0 ].y .X == 1 ).all ()
1935+
19181936 if lazy_legacy () or lazy :
19191937 assert isinstance (stacked_tc ._tensordict , LazyStackedTensorDict )
19201938 assert isinstance (stacked_tc .y ._tensordict , LazyStackedTensorDict )
0 commit comments