Skip to content

Commit 69cd230

Browse files
authored
[Feature] Accept **kwargs in update() and update_() (#1677)
1 parent 9792dff commit 69cd230

6 files changed

Lines changed: 334 additions & 20 deletions

File tree

tensordict/_lazy.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3092,19 +3092,41 @@ def expand(self, *args: int, inplace: bool = False) -> Self:
30923092
@lock_blocked
30933093
def update(
30943094
self,
3095-
input_dict_or_td: T,
3095+
input_dict_or_td: T | None = None,
30963096
clone: bool = False,
3097+
inplace: bool = False,
30973098
*,
30983099
keys_to_update: Sequence[NestedKey] | None = None,
30993100
non_blocking: bool = False,
31003101
is_leaf: Callable[[Type], bool] | None = None,
31013102
update_batch_size: bool = False,
3103+
ignore_lock: bool = False,
31023104
**kwargs: Any,
31033105
) -> Self:
31043106
# This implementation of update is compatible with exclusive keys
31053107
# as well as vmapped lazy stacks.
31063108
# We iterate over the tensordicts rather than iterating over the keys,
31073109
# which requires stacking and unbinding but is also not robust to missing keys.
3110+
if kwargs:
3111+
if input_dict_or_td is None:
3112+
input_dict_or_td = kwargs
3113+
elif isinstance(input_dict_or_td, dict):
3114+
input_dict_or_td = {**input_dict_or_td, **kwargs}
3115+
else:
3116+
self.update(
3117+
input_dict_or_td,
3118+
clone=clone,
3119+
inplace=inplace,
3120+
keys_to_update=keys_to_update,
3121+
non_blocking=non_blocking,
3122+
is_leaf=is_leaf,
3123+
update_batch_size=update_batch_size,
3124+
ignore_lock=ignore_lock,
3125+
)
3126+
input_dict_or_td = kwargs
3127+
kwargs = {}
3128+
elif input_dict_or_td is None:
3129+
return self
31083130
if input_dict_or_td is self:
31093131
# no op
31103132
return self
@@ -3157,11 +3179,12 @@ def update(
31573179
td_dest.update(
31583180
td_source,
31593181
clone=clone,
3182+
inplace=inplace,
31603183
keys_to_update=keys_to_update,
31613184
non_blocking=non_blocking,
31623185
is_leaf=is_leaf,
31633186
update_batch_size=update_batch_size,
3164-
**kwargs,
3187+
ignore_lock=ignore_lock,
31653188
)
31663189
return self
31673190

@@ -3185,11 +3208,12 @@ def update(
31853208
return self.update(
31863209
input_dict_or_td.to_lazystack(self.stack_dim),
31873210
clone=clone,
3211+
inplace=inplace,
31883212
keys_to_update=keys_to_update,
31893213
non_blocking=non_blocking,
31903214
is_leaf=is_leaf,
31913215
update_batch_size=update_batch_size,
3192-
**kwargs,
3216+
ignore_lock=ignore_lock,
31933217
)
31943218
# if the batch-size does not permit unbinding, let's first try to reset the batch-size.
31953219
input_dict_or_td = input_dict_or_td.copy()
@@ -3209,10 +3233,12 @@ def update(
32093233
td_dest.update(
32103234
td_source,
32113235
clone=clone,
3236+
inplace=inplace,
32123237
keys_to_update=keys_to_update,
3238+
non_blocking=non_blocking,
32133239
is_leaf=is_leaf,
32143240
update_batch_size=update_batch_size,
3215-
**kwargs,
3241+
ignore_lock=ignore_lock,
32163242
)
32173243
if self.hook_out is not None:
32183244
self_upd = self.hook_out(self_upd)
@@ -3222,12 +3248,30 @@ def update(
32223248

32233249
def update_(
32243250
self,
3225-
input_dict_or_td: dict[str, CompatibleType] | TensorDictBase,
3251+
input_dict_or_td: dict[str, CompatibleType] | TensorDictBase | None = None,
32263252
clone: bool = False,
32273253
*,
32283254
non_blocking: bool = False,
32293255
**kwargs: Any,
32303256
) -> Self:
3257+
# Extract reserved kwargs that aren't user key-value pairs.
3258+
keys_to_update = kwargs.pop("keys_to_update", None)
3259+
if kwargs:
3260+
if input_dict_or_td is None:
3261+
input_dict_or_td = kwargs
3262+
elif isinstance(input_dict_or_td, dict):
3263+
input_dict_or_td = {**input_dict_or_td, **kwargs}
3264+
else:
3265+
self.update_(
3266+
input_dict_or_td,
3267+
clone=clone,
3268+
non_blocking=non_blocking,
3269+
keys_to_update=keys_to_update,
3270+
)
3271+
input_dict_or_td = kwargs
3272+
kwargs = {}
3273+
elif input_dict_or_td is None:
3274+
return self
32313275
if input_dict_or_td is self:
32323276
# no op
32333277
return self
@@ -3244,7 +3288,12 @@ def update_(
32443288
for td_dest, td_source in _zip_strict(
32453289
self.tensordicts, input_dict_or_td.unbind(self.stack_dim)
32463290
):
3247-
td_dest.update_(td_source, clone=clone, non_blocking=non_blocking, **kwargs)
3291+
td_dest.update_(
3292+
td_source,
3293+
clone=clone,
3294+
non_blocking=non_blocking,
3295+
keys_to_update=keys_to_update,
3296+
)
32483297
return self
32493298

32503299
def update_at_(

tensordict/_td.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4017,7 +4017,7 @@ def _get_tuple(self, key, default, **kwargs):
40174017
@lock_blocked
40184018
def update(
40194019
self,
4020-
input_dict_or_td: dict[str, CompatibleType] | TensorCollection,
4020+
input_dict_or_td: dict[str, CompatibleType] | TensorCollection | None = None,
40214021
clone: bool = False,
40224022
inplace: bool = False,
40234023
*,
@@ -4028,6 +4028,25 @@ def update(
40284028
ignore_lock: bool = False,
40294029
**kwargs,
40304030
) -> _SubTensorDict:
4031+
if kwargs:
4032+
if input_dict_or_td is None:
4033+
input_dict_or_td = kwargs
4034+
elif isinstance(input_dict_or_td, dict):
4035+
input_dict_or_td = {**input_dict_or_td, **kwargs}
4036+
else:
4037+
self.update(
4038+
input_dict_or_td,
4039+
clone=clone,
4040+
inplace=inplace,
4041+
non_blocking=non_blocking,
4042+
keys_to_update=keys_to_update,
4043+
is_leaf=is_leaf,
4044+
update_batch_size=update_batch_size,
4045+
ignore_lock=ignore_lock,
4046+
)
4047+
input_dict_or_td = kwargs
4048+
elif input_dict_or_td is None:
4049+
return self
40314050
if input_dict_or_td is self:
40324051
# no op
40334052
return self
@@ -4109,12 +4128,28 @@ def update(
41094128

41104129
def update_(
41114130
self,
4112-
input_dict: dict[str, CompatibleType] | TensorCollection,
4131+
input_dict: dict[str, CompatibleType] | TensorCollection | None = None,
41134132
clone: bool = False,
41144133
*,
41154134
non_blocking: bool = False,
41164135
keys_to_update: Sequence[NestedKey] | None = None,
4136+
**kwargs,
41174137
) -> _SubTensorDict:
4138+
if kwargs:
4139+
if input_dict is None:
4140+
input_dict = kwargs
4141+
elif isinstance(input_dict, dict):
4142+
input_dict = {**input_dict, **kwargs}
4143+
else:
4144+
self.update_(
4145+
input_dict,
4146+
clone=clone,
4147+
non_blocking=non_blocking,
4148+
keys_to_update=keys_to_update,
4149+
)
4150+
input_dict = kwargs
4151+
elif input_dict is None:
4152+
return self
41184153
return self.update_at_(
41194154
input_dict,
41204155
idx=self.idx,

tensordict/base.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8120,7 +8120,7 @@ def get_item_shape(self, key: NestedKey):
81208120
@lock_blocked
81218121
def update(
81228122
self,
8123-
input_dict_or_td: dict[str, CompatibleType] | T,
8123+
input_dict_or_td: dict[str, CompatibleType] | T | None = None,
81248124
clone: bool = False,
81258125
inplace: bool = False,
81268126
*,
@@ -8129,15 +8129,16 @@ def update(
81298129
is_leaf: Callable[[Type], bool] | None = None,
81308130
update_batch_size: bool = False,
81318131
ignore_lock: bool = False,
8132+
**kwargs,
81328133
) -> Self:
81338134
"""Updates the TensorDict with values from either a dictionary or another TensorDict.
81348135

81358136
.. warning:: `update` will corrupt the data if called within a try/except block. Do not user this method within
81368137
such blocks hoping to catch and patch errors that occur during the execution.
81378138

81388139
Args:
8139-
input_dict_or_td (TensorDictBase or dict): input data to be written
8140-
in self.
8140+
input_dict_or_td (TensorDictBase or dict, optional): input data to be written
8141+
in self. If ``None``, only the keyword arguments are used.
81418142
clone (bool, optional): whether the tensors in the input (
81428143
tensor) dict should be cloned before being set.
81438144
Defaults to ``False``.
@@ -8175,11 +8176,21 @@ def update(
81758176
ignore_lock (bool, optional): if ``True``, any tensordict can be updated regardless of its locked status.
81768177
Defaults to `False`.
81778178

8179+
**kwargs: additional key-value pairs to write into ``self`` as top-level entries. When both
8180+
``input_dict_or_td`` and ``kwargs`` are provided, kwargs win on key conflict.
8181+
81788182
.. note:: When updating a :class:`~tensordict.LazyStackedTensorDict` with N elements with another
81798183
:class:`~tensordict.LazyStackedTensorDict` with M elements, with M > N, along the stack dimension,
81808184
the ``update`` method will append copies of the extra tensordicts to the dest (self) lazy stack.
81818185
This allows users to rely on ``update`` to increment lazy stacks progressively.
81828186

8187+
.. note:: Keyword arguments are treated as top-level keys only. Nested structures still work when
8188+
the kwarg value is itself a dict or tensor collection (e.g. ``td.update(outer={"inner": val})``).
8189+
For deeper tuple keys, use the positional dict form: ``td.update({("a", "b"): val})``.
8190+
Keys whose names collide with reserved parameters (``clone``, ``inplace``, ``non_blocking``,
8191+
``keys_to_update``, ``is_leaf``, ``update_batch_size``, ``ignore_lock``) must be passed through
8192+
the positional dict form.
8193+
81838194
Returns:
81848195
self
81858196

@@ -8193,8 +8204,30 @@ def update(
81938204
>>> other_td = other_td.clone().zero_()
81948205
>>> td.update(other_td)
81958206
>>> assert td['a'] is not other_td['a']
8207+
>>> # keyword form for top-level entries
8208+
>>> td.update(monkey=torch.zeros(3))
8209+
>>> assert (td["monkey"] == 0).all()
81968210

81978211
"""
8212+
if kwargs:
8213+
if input_dict_or_td is None:
8214+
input_dict_or_td = kwargs
8215+
elif isinstance(input_dict_or_td, dict):
8216+
input_dict_or_td = {**input_dict_or_td, **kwargs}
8217+
else:
8218+
self.update(
8219+
input_dict_or_td,
8220+
clone=clone,
8221+
inplace=inplace,
8222+
non_blocking=non_blocking,
8223+
keys_to_update=keys_to_update,
8224+
is_leaf=is_leaf,
8225+
update_batch_size=update_batch_size,
8226+
ignore_lock=ignore_lock,
8227+
)
8228+
input_dict_or_td = kwargs
8229+
elif input_dict_or_td is None:
8230+
return self
81988231
batch_size_changed = False
81998232
if input_dict_or_td is self:
82008233
# no op
@@ -8367,19 +8400,20 @@ def update(
83678400

83688401
def update_(
83698402
self,
8370-
input_dict_or_td: dict[str, CompatibleType] | T,
8403+
input_dict_or_td: dict[str, CompatibleType] | T | None = None,
83718404
clone: bool = False,
83728405
*,
83738406
non_blocking: bool = False,
83748407
keys_to_update: Sequence[NestedKey] | None = None,
8408+
**kwargs,
83758409
) -> Self:
83768410
"""Updates the TensorDict in-place with values from either a dictionary or another TensorDict.
83778411

83788412
Unlike :meth:`~.update`, this function will throw an error if the key is unknown to ``self``.
83798413

83808414
Args:
8381-
input_dict_or_td (TensorDictBase or dict): input data to be written
8382-
in self.
8415+
input_dict_or_td (TensorDictBase or dict, optional): input data to be written
8416+
in self. If ``None``, only the keyword arguments are used.
83838417
clone (bool, optional): whether the tensors in the input (
83848418
tensor) dict should be cloned before being set. Defaults to ``False``.
83858419

@@ -8391,6 +8425,14 @@ def update_(
83918425
non_blocking (bool, optional): if ``True`` and this copy is between
83928426
different devices, the copy may occur asynchronously with respect
83938427
to the host.
8428+
**kwargs: additional key-value pairs to write into ``self`` as top-level entries. As with the
8429+
positional form, any key that does not already exist in ``self`` raises :class:`KeyError`.
8430+
When both ``input_dict_or_td`` and ``kwargs`` are provided, kwargs win on key conflict.
8431+
8432+
.. note:: Keyword arguments are treated as top-level keys only. Nested structures still work when
8433+
the kwarg value is itself a dict or tensor collection. Keys whose names collide with reserved
8434+
parameters (``clone``, ``non_blocking``, ``keys_to_update``) must be passed through the positional
8435+
dict form.
83948436

83958437
Returns:
83968438
self
@@ -8404,8 +8446,26 @@ def update_(
84048446
>>> assert td['a'] is not other_td['a']
84058447
>>> assert (td['a'] == other_td['a']).all()
84068448
>>> assert (td['a'] == 0).all()
8449+
>>> # keyword form for top-level entries (key must already exist)
8450+
>>> td.update_(a=torch.ones(3))
8451+
>>> assert (td["a"] == 1).all()
84078452

84088453
"""
8454+
if kwargs:
8455+
if input_dict_or_td is None:
8456+
input_dict_or_td = kwargs
8457+
elif isinstance(input_dict_or_td, dict):
8458+
input_dict_or_td = {**input_dict_or_td, **kwargs}
8459+
else:
8460+
self.update_(
8461+
input_dict_or_td,
8462+
clone=clone,
8463+
non_blocking=non_blocking,
8464+
keys_to_update=keys_to_update,
8465+
)
8466+
input_dict_or_td = kwargs
8467+
elif input_dict_or_td is None:
8468+
return self
84098469
if input_dict_or_td is self:
84108470
# no op
84118471
return self

0 commit comments

Comments
 (0)