Skip to content

Commit e43b3fe

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent c1e6896 commit e43b3fe

File tree

8 files changed

+229
-40
lines changed

8 files changed

+229
-40
lines changed

tensordict/_lazy.py

Lines changed: 143 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import torch
3535

3636
from tensordict.memmap import MemoryMappedTensor
37+
from torch.nn.utils.rnn import pad_sequence
3738

3839
try:
3940
from functorch import dim as ftdim
@@ -516,7 +517,7 @@ def get_item_shape(self, key):
516517
return item.shape
517518
except RuntimeError as err:
518519
if re.match(
519-
r"Found more than one unique shape in the tensors|Could not run 'aten::stack' with arguments from the",
520+
r"Failed to stack tensors within a tensordict",
520521
str(err),
521522
):
522523
shape = None
@@ -1057,11 +1058,87 @@ def _maybe_get_list(self, key):
10571058
vals.append(val)
10581059
return vals
10591060

1061+
def get(
1062+
self,
1063+
key: NestedKey,
1064+
*args,
1065+
as_list: bool = False,
1066+
as_padded_tensor: bool = False,
1067+
as_nested_tensor: bool = False,
1068+
padding_side: str = "right",
1069+
layout: torch.layout = None,
1070+
padding_value: float | int | bool = 0.0,
1071+
**kwargs,
1072+
) -> CompatibleType:
1073+
"""Gets the value stored with the input key.
1074+
1075+
Args:
1076+
key (str, tuple of str): key to be queried. If tuple of str it is
1077+
equivalent to chained calls of getattr.
1078+
default: default value if the key is not found in the tensordict. Defaults to ``None``.
1079+
1080+
.. warning::
1081+
Previously, if a key was not present in the tensordict and no default
1082+
was passed, a `KeyError` was raised. From v0.7, this behaviour has been changed
1083+
and a `None` value is returned instead (in accordance with the behavior of ``dict.get``).
1084+
To adopt the old behavior, set the environment variable `export TD_GET_DEFAULTS_TO_NONE='0'` or call
1085+
:func:`~tensordict.set_get_defaults_to_none(False)`.
1086+
1087+
Keyword Args:
1088+
as_list (bool, optional): if ``True``, ragged tensors will be returned as list.
1089+
Exclusive with `as_padded_tensor` and `as_nested_tensor`.
1090+
Defaults to ``False``.
1091+
as_padded_tensor (bool, optional): if ``True``, ragged tensors will be returned as padded tensors.
1092+
The padding value can be controlled via the `padding_value` keyword argument, and the padding
1093+
side via the `padding_side` argument.
1094+
Exclusive with `as_list` and `as_nested_tensor`.
1095+
Defaults to ``False``.
1096+
as_nested_tensor (bool, optional): if ``True``, ragged tensors will be returned as nested tensors.
1097+
Exclusive with `as_list` and `as_padded_tensor`.
1098+
The layout can be controlled via the `layout` keyword argument.
1099+
Defaults to ``False``.
1100+
layout (torch.layout, optional): the layout when `as_nested_tensor=True`.
1101+
padding_side (str): The side of padding. Must be `"left"` or `"right"`. Defaults to `"right"`.
1102+
padding_value (scalar or bool, optional): The padding value. Defaults to 0.0.
1103+
1104+
Examples:
1105+
>>> from tensordict import TensorDict, lazy_stack
1106+
>>> import torch
1107+
>>> td = lazy_stack([
1108+
... TensorDict({"x": torch.ones(1,)}),
1109+
... TensorDict({"x": torch.ones(2,) * 2}),
1110+
... ])
1111+
>>> td.get("x", as_nested_tensor=True)
1112+
NestedTensor(size=(2, j1), offsets=tensor([0, 1, 3]), contiguous=True)
1113+
>>> td.get("x", as_padded_tensor=True)
1114+
tensor([[1., 0.],
1115+
[2., 2.]])
1116+
1117+
"""
1118+
return super().get(
1119+
key,
1120+
*args,
1121+
as_list=as_list,
1122+
as_padded_tensor=as_padded_tensor,
1123+
as_nested_tensor=as_nested_tensor,
1124+
padding_side=padding_side,
1125+
layout=layout,
1126+
padding_value=padding_value,
1127+
**kwargs,
1128+
)
1129+
10601130
@cache # noqa: B019
10611131
def _get_str(
10621132
self,
10631133
key: NestedKey,
10641134
default: Any = NO_DEFAULT,
1135+
*,
1136+
as_list: bool = False,
1137+
as_padded_tensor: bool = False,
1138+
as_nested_tensor: bool = False,
1139+
padding_side: str = "right",
1140+
layout: torch.layout = None,
1141+
padding_value: float | int | bool = 0.0,
10651142
) -> CompatibleType:
10661143
# we can handle the case where the key is a tuple of length 1
10671144
tensors = []
@@ -1076,7 +1153,15 @@ def _get_str(
10761153
return default
10771154
try:
10781155
out = self.lazy_stack(
1079-
tensors, self.stack_dim, stack_dim_name=self._td_dim_name
1156+
tensors,
1157+
self.stack_dim,
1158+
stack_dim_name=self._td_dim_name,
1159+
as_list=as_list,
1160+
as_padded_tensor=as_padded_tensor,
1161+
as_nested_tensor=as_nested_tensor,
1162+
padding_side=padding_side,
1163+
layout=layout,
1164+
padding_value=padding_value,
10801165
)
10811166
if _is_tensor_collection(type(out)):
10821167
if isinstance(out, LazyStackedTensorDict):
@@ -1118,8 +1203,8 @@ def _get_str(
11181203
else:
11191204
raise err
11201205

1121-
def _get_tuple(self, key, default):
1122-
first = self._get_str(key[0], None)
1206+
def _get_tuple(self, key, default, **kwargs):
1207+
first = self._get_str(key[0], None, **kwargs)
11231208
if first is None:
11241209
return self._default_get(key[0], default)
11251210
if len(key) == 1:
@@ -1130,7 +1215,7 @@ def _get_tuple(self, key, default):
11301215
raise ValueError(f"Got too many keys for a KJT: {key}.")
11311216
return first[key[-1]]
11321217
else:
1133-
return first._get_tuple(key[1:], default=default)
1218+
return first._get_tuple(key[1:], default=default, **kwargs)
11341219
except AttributeError as err:
11351220
if "has no attribute" in str(err):
11361221
raise ValueError(
@@ -1148,6 +1233,12 @@ def lazy_stack(
11481233
out: T | None = None,
11491234
stack_dim_name: str | None = None,
11501235
strict_shape: bool = False,
1236+
as_list: bool = False,
1237+
as_padded_tensor: bool = False,
1238+
as_nested_tensor: bool = False,
1239+
padding_side: str = "right",
1240+
layout: torch.layout | None = None,
1241+
padding_value: float | int | bool = 0.0,
11511242
) -> T: # noqa: D417
11521243
"""Stacks tensordicts in a LazyStackedTensorDict.
11531244
@@ -1164,13 +1255,55 @@ def lazy_stack(
11641255
stack_dim_name (str, optional): a name for the stacked dimension.
11651256
strict_shape (bool, optional): if ``True``, every tensordict's shapes must match.
11661257
Defaults to ``False``.
1258+
as_list (bool, optional): if ``True``, ragged tensors will be returned as list.
1259+
Exclusive with `as_padded_tensor` and `as_nested_tensor`.
1260+
Defaults to ``False``.
1261+
as_padded_tensor (bool, optional): if ``True``, ragged tensors will be returned as padded tensors.
1262+
The padding value can be controlled via the `padding_value` keyword argument, and the padding
1263+
side via the `padding_side` argument.
1264+
Exclusive with `as_list` and `as_nested_tensor`.
1265+
Defaults to ``False``.
1266+
as_nested_tensor (bool, optional): if ``True``, ragged tensors will be returned as nested tensors.
1267+
Exclusive with `as_list` and `as_padded_tensor`.
1268+
The layout can be controlled via the `layout` keyword argument.
1269+
Defaults to ``False``.
1270+
layout (torch.layout, optional): the layout when `as_nested_tensor=True`.
1271+
padding_side (str): The side of padding. Must be `"left"` or `"right"`. Defaults to `"right"`.
1272+
padding_value (scalar or bool, optional): The padding value. Defaults to 0.0.
11671273
11681274
"""
11691275
if not items:
11701276
raise RuntimeError("items cannot be empty")
11711277

11721278
if all(isinstance(item, torch.Tensor) for item in items):
1173-
return torch.stack(items, dim=dim, out=out)
1279+
# This must be implemented here and not in _get_str because we want to leverage this check
1280+
special_return = sum((as_list, as_padded_tensor, as_nested_tensor))
1281+
if special_return > 1:
1282+
raise TypeError(
1283+
"as_list, as_padded_tensor and as_nested_tensor are exclusive."
1284+
)
1285+
elif special_return:
1286+
if as_padded_tensor:
1287+
return pad_sequence(
1288+
items,
1289+
padding_value=padding_value,
1290+
padding_side=padding_side,
1291+
batch_first=True,
1292+
)
1293+
if as_nested_tensor:
1294+
if layout is None:
1295+
layout = torch.jagged
1296+
return torch.nested.as_nested_tensor(items, layout=layout)
1297+
if as_list:
1298+
return items
1299+
try:
1300+
return torch.stack(items, dim=dim, out=out)
1301+
except RuntimeError as err:
1302+
raise RuntimeError(
1303+
"Failed to stack tensors within a tensordict. You can use nested tensors, "
1304+
"padded tensors or return lists via specialized keyword arguments. "
1305+
"Check the TensorDict.lazy_stack documentation!"
1306+
) from err
11741307
if all(is_non_tensor(tensordict) for tensordict in items):
11751308
# Non-tensor data (Data or Stack) are stacked using NonTensorStack
11761309
# If the content is identical (not equal but same id) this does not
@@ -3521,14 +3654,14 @@ def _rename_subtds(self, names):
35213654
def _change_batch_size(self, new_size: torch.Size) -> None:
35223655
self._batch_size = new_size
35233656

3524-
def _get_str(self, key, default):
3525-
tensor = self._source._get_str(key, default)
3657+
def _get_str(self, key, default, **kwargs):
3658+
tensor = self._source._get_str(key, default, **kwargs)
35263659
if tensor is default:
35273660
return tensor
35283661
return self._transform_value(tensor)
35293662

3530-
def _get_tuple(self, key, default):
3531-
tensor = self._source._get_tuple(key, default)
3663+
def _get_tuple(self, key, default, **kwargs):
3664+
tensor = self._source._get_tuple(key, default, **kwargs)
35323665
if tensor is default:
35333666
return tensor
35343667
return self._transform_value(tensor)

tensordict/_td.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2695,19 +2695,19 @@ def _stack_onto_at_(
26952695
# )
26962696
return self
26972697

2698-
def _get_str(self, key, default):
2698+
def _get_str(self, key, default, **kwargs):
26992699
first_key = key
27002700
out = self._tensordict.get(first_key)
27012701
if out is None:
27022702
return self._default_get(first_key, default)
27032703
return out
27042704

2705-
def _get_tuple(self, key, default):
2706-
first = self._get_str(key[0], default)
2705+
def _get_tuple(self, key, default, **kwargs):
2706+
first = self._get_str(key[0], default, **kwargs)
27072707
if len(key) == 1 or first is default:
27082708
return first
27092709
try:
2710-
return first._get_tuple(key[1:], default=default)
2710+
return first._get_tuple(key[1:], default=default, **kwargs)
27112711
except AttributeError as err:
27122712
if "has no attribute" in str(err):
27132713
raise ValueError(
@@ -3823,16 +3823,16 @@ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
38233823
return out._source
38243824
return out
38253825

3826-
def _get_str(self, key, default):
3826+
def _get_str(self, key, default, **kwargs):
38273827
if key in self.keys() and _is_tensor_collection(self.entry_class(key)):
3828-
data = self._source._get_str(key, NO_DEFAULT)
3828+
data = self._source._get_str(key, NO_DEFAULT, **kwargs)
38293829
if _pass_through(data):
38303830
return data[self.idx]
38313831
return _SubTensorDict(data, self.idx)
3832-
return self._source._get_at_str(key, self.idx, default=default)
3832+
return self._source._get_at_str(key, self.idx, default=default, **kwargs)
38333833

3834-
def _get_tuple(self, key, default):
3835-
return self._source._get_at_tuple(key, self.idx, default=default)
3834+
def _get_tuple(self, key, default, **kwargs):
3835+
return self._source._get_at_tuple(key, self.idx, default=default, **kwargs)
38363836

38373837
@lock_blocked
38383838
def update(

tensordict/base.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6436,6 +6436,9 @@ def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType:
64366436
To adopt the old behavior, set the environment variable `export TD_GET_DEFAULTS_TO_NONE='0'` or call
64376437
:func:`~tensordict.set_get_defaults_to_none(False)`.
64386438

6439+
.. note:: Keyword arguments can be passed to :meth:`~.get` when dealing with ragged tensors.
6440+
See :meth:`~tensordict.LazyStackedTensorDict.get` for a complete overview.
6441+
64396442
Examples:
64406443
>>> td = TensorDict({"x": 1}, batch_size=[])
64416444
>>> td.get("x")
@@ -6449,26 +6452,28 @@ def get(self, key: NestedKey, *args, **kwargs) -> CompatibleType:
64496452
# Find what the default is
64506453
if args:
64516454
default = args[0]
6452-
if len(args) > 1 or kwargs:
6453-
raise TypeError("only one (keyword) argument is allowed.")
6454-
elif kwargs:
6455+
if len(args) > 1:
6456+
raise TypeError("Only one arg is allowed in TD.get.")
6457+
elif "default" in kwargs:
6458+
raise TypeError("'default' arg was passed twice.")
6459+
elif "default" in kwargs:
64556460
default = kwargs.pop("default")
6456-
if args or kwargs:
6457-
raise TypeError("only one (keyword) argument is allowed.")
6461+
if args:
6462+
raise TypeError("'default' arg was passed twice.")
64586463
elif _GET_DEFAULTS_TO_NONE:
64596464
default = None
64606465
else:
64616466
default = NO_DEFAULT
6462-
return self._get_tuple(key, default=default)
6467+
return self._get_tuple(key, default=default, **kwargs)
64636468

64646469
@abc.abstractmethod
6465-
def _get_str(self, key, default): ...
6470+
def _get_str(self, key, default, **kwargs): ...
64666471

64676472
@abc.abstractmethod
6468-
def _get_tuple(self, key, default): ...
6473+
def _get_tuple(self, key, default, **kwargs): ...
64696474

6470-
def _get_tuple_maybe_non_tensor(self, key, default):
6471-
result = self._get_tuple(key, default)
6475+
def _get_tuple_maybe_non_tensor(self, key, default, **kwargs):
6476+
result = self._get_tuple(key, default, **kwargs)
64726477
if _pass_through(result):
64736478
# Only lazy stacks of non tensors are actually tensordict instances
64746479
if isinstance(result, TensorDictBase):

tensordict/nn/common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,9 @@ class TensorDictModule(TensorDictModuleBase):
851851
method_kwargs (Dict[str, Any], optional): additional keyword arguments to be passed to the module's method being called.
852852
strict (bool, optional): if ``True``, the module will raise an exception if any of the inputs is missing from
853853
the input tensordict. Otherwise, a `None` value will be used as placeholder. Defaults to ``False``.
854+
get_kwargs (dict[str, Any], optional): additional keyword arguments to be passed to the :meth:`~tensordict.TensorDictBase.get`
855+
method. This is particularly useful when dealing with ragged tensors (see :meth:`~tensordict.LazyStackedTensorDict.get`).
856+
Defaults to ``{}``.
854857
855858
Embedding a neural network in a TensorDictModule only requires to specify the input
856859
and output keys. TensorDictModule support functional and regular :obj:`nn.Module`
@@ -1018,6 +1021,7 @@ def __init__(
10181021
method: str | None = None,
10191022
method_kwargs: dict | None = None,
10201023
strict: bool = False,
1024+
get_kwargs: dict | None = None,
10211025
) -> None:
10221026
super().__init__()
10231027

@@ -1097,6 +1101,7 @@ def __init__(
10971101
self.inplace = inplace
10981102
self.method = method
10991103
self.method_kwargs = method_kwargs if method_kwargs is not None else {}
1104+
self._get_kwargs = get_kwargs if get_kwargs is not None else {}
11001105

11011106
@property
11021107
def is_functional(self) -> bool:
@@ -1180,7 +1185,9 @@ def forward(
11801185
else:
11811186
tensors = tuple(
11821187
tensordict._get_tuple_maybe_non_tensor(
1183-
_unravel_key_to_tuple(in_key), default
1188+
_unravel_key_to_tuple(in_key),
1189+
default,
1190+
**self._get_kwargs,
11841191
)
11851192
for in_key in self.in_keys
11861193
)
@@ -1223,7 +1230,7 @@ def forward(
12231230
import inspect
12241231

12251232
module = inspect.getsource(module)
1226-
except OSError:
1233+
except Exception:
12271234
# then we can't print the source code
12281235
pass
12291236
module = indent(str(module), 4 * " ")

tensordict/nn/probabilistic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,14 @@ def __init__(
984984
self._ordered_dict = True
985985
else:
986986
modules = modules_list = list(modules[0])
987+
elif len(modules) == 1 and isinstance(modules[0], dict):
988+
modules = [collections.OrderedDict(modules[0])]
989+
return self.__init__(
990+
*modules,
991+
partial_tolerant=partial_tolerant,
992+
return_composite=return_composite,
993+
inplace=inplace,
994+
)
987995
elif not return_composite and not isinstance(
988996
modules[-1],
989997
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),

tensordict/persistent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def _process_array(self, key, array):
354354
return out
355355

356356
@cache # noqa: B019
357-
def _get_str(self, key: NestedKey, default):
357+
def _get_str(self, key: NestedKey, default, **kwargs):
358358
key = _unravel_key_to_tuple(key)
359359
array = self._get_array(key, default)
360360
if array is default:

0 commit comments

Comments
 (0)