3434import torch
3535
3636from tensordict .memmap import MemoryMappedTensor
37+ from torch .nn .utils .rnn import pad_sequence
3738
3839try :
3940 from functorch import dim as ftdim
@@ -516,7 +517,7 @@ def get_item_shape(self, key):
516517 return item .shape
517518 except RuntimeError as err :
518519 if re .match (
519- r"Found more than one unique shape in the tensors|Could not run 'aten::stack' with arguments from the " ,
520+ r"Failed to stack tensors within a tensordict " ,
520521 str (err ),
521522 ):
522523 shape = None
@@ -1057,11 +1058,87 @@ def _maybe_get_list(self, key):
10571058 vals .append (val )
10581059 return vals
10591060
1061+ def get (
1062+ self ,
1063+ key : NestedKey ,
1064+ * args ,
1065+ as_list : bool = False ,
1066+ as_padded_tensor : bool = False ,
1067+ as_nested_tensor : bool = False ,
1068+ padding_side : str = "right" ,
1069+ layout : torch .layout = None ,
1070+ padding_value : float | int | bool = 0.0 ,
1071+ ** kwargs ,
1072+ ) -> CompatibleType :
1073+ """Gets the value stored with the input key.
1074+
1075+ Args:
1076+ key (str, tuple of str): key to be queried. If tuple of str it is
1077+ equivalent to chained calls of getattr.
1078+ default: default value if the key is not found in the tensordict. Defaults to ``None``.
1079+
1080+ .. warning::
1081+ Previously, if a key was not present in the tensordict and no default
1082+ was passed, a `KeyError` was raised. From v0.7, this behaviour has been changed
1083+ and a `None` value is returned instead (in accordance with the what dict.get behavior).
1084+ To adopt the old behavior, set the environment variable `export TD_GET_DEFAULTS_TO_NONE='0'` or call
1085+ :func`~tensordict.set_get_defaults_to_none(False)`.
1086+
1087+ Keyword Args:
1088+ as_list (bool, optional): if ``True``, ragged tensors will be returned as list.
1089+ Exclusive with `as_padded_tensor` and `as_nested_tensor`.
1090+ Defaults to ``False``.
1091+ as_padded_tensor (bool, optional): if ``True``, ragged tensors will be returned as padded tensors.
1092+ The padding value can be controlled via the `padding_value` keyword argument, and the padding
1093+ side via the `padding_side` argument.
1094+ Exclusive with `as_list` and `as_nested_tensor`.
1095+ Defaults to ``False``.
1096+ as_nested_tensor (bool, optional): if ``True``, ragged tensors will be returned as list.
1097+ Exclusive with `as_list` and `as_padded_tensor`.
1098+ The layout can be controlled via the `torch.layout` argument.
1099+ Defaults to ``False``.
1100+ layout (torch.layout, optional): the layout when `as_nested_tensor=True`.
1101+ padding_side (str): The side of padding. Must be `"left"` or `"right"`. Defaults to `"right"`.
1102+ padding_value (scalar or bool, optional): The padding value. Defaults to 0.0.
1103+
1104+ Examples:
1105+ >>> from tensordict import TensorDict, lazy_stack
1106+ >>> import torch
1107+ >>> td = lazy_stack([
1108+ ... TensorDict({"x": torch.ones(1,)}),
1109+ ... TensorDict({"x": torch.ones(2,) * 2}),
1110+ ... ])
1111+ >>> td.get("x", as_nested_tensor=True)
1112+ NestedTensor(size=(2, j1), offsets=tensor([0, 1, 3]), contiguous=True)
1113+ >>> td.get("x", as_padded_tensor=True)
1114+ tensor([[1., 0.],
1115+ [2., 2.]])
1116+
1117+ """
1118+ return super ().get (
1119+ key ,
1120+ * args ,
1121+ as_list = as_list ,
1122+ as_padded_tensor = as_padded_tensor ,
1123+ as_nested_tensor = as_nested_tensor ,
1124+ padding_side = padding_side ,
1125+ layout = layout ,
1126+ padding_value = padding_value ,
1127+ ** kwargs ,
1128+ )
1129+
10601130 @cache # noqa: B019
10611131 def _get_str (
10621132 self ,
10631133 key : NestedKey ,
10641134 default : Any = NO_DEFAULT ,
1135+ * ,
1136+ as_list : bool = False ,
1137+ as_padded_tensor : bool = False ,
1138+ as_nested_tensor : bool = False ,
1139+ padding_side : str = "right" ,
1140+ layout : torch .layout = None ,
1141+ padding_value : float | int | bool = 0.0 ,
10651142 ) -> CompatibleType :
10661143 # we can handle the case where the key is a tuple of length 1
10671144 tensors = []
@@ -1076,7 +1153,15 @@ def _get_str(
10761153 return default
10771154 try :
10781155 out = self .lazy_stack (
1079- tensors , self .stack_dim , stack_dim_name = self ._td_dim_name
1156+ tensors ,
1157+ self .stack_dim ,
1158+ stack_dim_name = self ._td_dim_name ,
1159+ as_list = as_list ,
1160+ as_padded_tensor = as_padded_tensor ,
1161+ as_nested_tensor = as_nested_tensor ,
1162+ padding_side = padding_side ,
1163+ layout = layout ,
1164+ padding_value = padding_value ,
10801165 )
10811166 if _is_tensor_collection (type (out )):
10821167 if isinstance (out , LazyStackedTensorDict ):
@@ -1118,8 +1203,8 @@ def _get_str(
11181203 else :
11191204 raise err
11201205
1121- def _get_tuple (self , key , default ):
1122- first = self ._get_str (key [0 ], None )
1206+ def _get_tuple (self , key , default , ** kwargs ):
1207+ first = self ._get_str (key [0 ], None , ** kwargs )
11231208 if first is None :
11241209 return self ._default_get (key [0 ], default )
11251210 if len (key ) == 1 :
@@ -1130,7 +1215,7 @@ def _get_tuple(self, key, default):
11301215 raise ValueError (f"Got too many keys for a KJT: { key } ." )
11311216 return first [key [- 1 ]]
11321217 else :
1133- return first ._get_tuple (key [1 :], default = default )
1218+ return first ._get_tuple (key [1 :], default = default , ** kwargs )
11341219 except AttributeError as err :
11351220 if "has no attribute" in str (err ):
11361221 raise ValueError (
@@ -1148,6 +1233,12 @@ def lazy_stack(
11481233 out : T | None = None ,
11491234 stack_dim_name : str | None = None ,
11501235 strict_shape : bool = False ,
1236+ as_list : bool = False ,
1237+ as_padded_tensor : bool = False ,
1238+ as_nested_tensor : bool = False ,
1239+ padding_side : str = "right" ,
1240+ layout : torch .layout | None = None ,
1241+ padding_value : float | int | bool = 0.0 ,
11511242 ) -> T : # noqa: D417
11521243 """Stacks tensordicts in a LazyStackedTensorDict.
11531244
@@ -1164,13 +1255,55 @@ def lazy_stack(
11641255 stack_dim_name (str, optional): a name for the stacked dimension.
11651256 strict_shape (bool, optional): if ``True``, every tensordict's shapes must match.
11661257 Defaults to ``False``.
1258+ as_list (bool, optional): if ``True``, ragged tensors will be returned as list.
1259+ Exclusive with `as_padded_tensor` and `as_nested_tensor`.
1260+ Defaults to ``False``.
1261+ as_padded_tensor (bool, optional): if ``True``, ragged tensors will be returned as padded tensors.
1262+ The padding value can be controlled via the `padding_value` keyword argument, and the padding
1263+ side via the `padding_side` argument.
1264+ Exclusive with `as_list` and `as_nested_tensor`.
1265+ Defaults to ``False``.
1266+ as_nested_tensor (bool, optional): if ``True``, ragged tensors will be returned as list.
1267+ Exclusive with `as_list` and `as_padded_tensor`.
1268+ The layout can be controlled via the `torch.layout` argument.
1269+ Defaults to ``False``.
1270+ layout (torch.layout, optional): the layout when `as_nested_tensor=True`.
1271+ padding_side (str): The side of padding. Must be `"left"` or `"right"`. Defaults to `"right"`.
1272+ padding_value (scalar or bool, optional): The padding value. Defaults to 0.0.
11671273
11681274 """
11691275 if not items :
11701276 raise RuntimeError ("items cannot be empty" )
11711277
11721278 if all (isinstance (item , torch .Tensor ) for item in items ):
1173- return torch .stack (items , dim = dim , out = out )
1279+ # This must be implemented here and not in _get_str because we want to leverage this check
1280+ special_return = sum ((as_list , as_padded_tensor , as_nested_tensor ))
1281+ if special_return > 1 :
1282+ raise TypeError (
1283+ "as_list, as_padded_tensor and as_nested_tensor are exclusive."
1284+ )
1285+ elif special_return :
1286+ if as_padded_tensor :
1287+ return pad_sequence (
1288+ items ,
1289+ padding_value = padding_value ,
1290+ padding_side = padding_side ,
1291+ batch_first = True ,
1292+ )
1293+ if as_nested_tensor :
1294+ if layout is None :
1295+ layout = torch .jagged
1296+ return torch .nested .as_nested_tensor (items , layout = layout )
1297+ if as_list :
1298+ return items
1299+ try :
1300+ return torch .stack (items , dim = dim , out = out )
1301+ except RuntimeError as err :
1302+ raise RuntimeError (
1303+ "Failed to stack tensors within a tensordict. You can use nested tensors, "
1304+ "padded tensors or return lists via specialized keyword arguments. "
1305+ "Check the TensorDict.lazy_stack documentation!"
1306+ ) from err
11741307 if all (is_non_tensor (tensordict ) for tensordict in items ):
11751308 # Non-tensor data (Data or Stack) are stacked using NonTensorStack
11761309 # If the content is identical (not equal but same id) this does not
@@ -3521,14 +3654,14 @@ def _rename_subtds(self, names):
    def _change_batch_size(self, new_size: torch.Size) -> None:
        # Directly overwrite the cached batch size. No validation is performed
        # here; callers are responsible for passing a size consistent with the
        # stored data.
        self._batch_size = new_size
35233656
    def _get_str(self, key, default, **kwargs):
        # Delegate the lookup to the wrapped source tensordict, forwarding any
        # retrieval options (e.g. as_list / as_padded_tensor / as_nested_tensor)
        # via **kwargs.
        tensor = self._source._get_str(key, default, **kwargs)
        if tensor is default:
            # Identity check: the key was missing and the caller's default
            # sentinel came back — return it untouched rather than transforming
            # a value that was never stored.
            return tensor
        # Apply this wrapper's transformation to the retrieved value
        # (presumably the lazy op this class represents — confirm against
        # _transform_value's definition, which is outside this view).
        return self._transform_value(tensor)
35293662
    def _get_tuple(self, key, default, **kwargs):
        # Same contract as _get_str but for tuple (nested) keys: delegate to
        # the wrapped source, forwarding retrieval options via **kwargs.
        tensor = self._source._get_tuple(key, default, **kwargs)
        if tensor is default:
            # Missing key: hand the caller's default back unmodified (identity
            # check, so the default is never run through the transform).
            return tensor
        # Transform the stored value before returning it (see _transform_value,
        # defined elsewhere in this class).
        return self._transform_value(tensor)
0 commit comments