@@ -8120,7 +8120,7 @@ def get_item_shape(self, key: NestedKey):
81208120 @lock_blocked
81218121 def update(
81228122 self,
8123- input_dict_or_td: dict[str, CompatibleType] | T,
8123+ input_dict_or_td: dict[str, CompatibleType] | T | None = None ,
81248124 clone: bool = False,
81258125 inplace: bool = False,
81268126 *,
@@ -8129,15 +8129,16 @@ def update(
81298129 is_leaf: Callable[[Type], bool] | None = None,
81308130 update_batch_size: bool = False,
81318131 ignore_lock: bool = False,
8132+ **kwargs,
81328133 ) -> Self:
81338134 """Updates the TensorDict with values from either a dictionary or another TensorDict.
81348135
81358136 .. warning:: `update` will corrupt the data if called within a try/except block. Do not user this method within
81368137 such blocks hoping to catch and patch errors that occur during the execution.
81378138
81388139 Args:
8139- input_dict_or_td (TensorDictBase or dict): input data to be written
8140- in self.
8140+ input_dict_or_td (TensorDictBase or dict, optional ): input data to be written
8141+ in self. If ``None``, only the keyword arguments are used.
81418142 clone (bool, optional): whether the tensors in the input (
81428143 tensor) dict should be cloned before being set.
81438144 Defaults to ``False``.
@@ -8175,11 +8176,21 @@ def update(
81758176 ignore_lock (bool, optional): if ``True``, any tensordict can be updated regardless of its locked status.
81768177 Defaults to `False`.
81778178
8179+ **kwargs: additional key-value pairs to write into ``self`` as top-level entries. When both
8180+ ``input_dict_or_td`` and ``kwargs`` are provided, kwargs win on key conflict.
8181+
81788182 .. note:: When updating a :class:`~tensordict.LazyStackedTensorDict` with N elements with another
81798183 :class:`~tensordict.LazyStackedTensorDict` with M elements, with M > N, along the stack dimension,
81808184 the ``update`` method will append copies of the extra tensordicts to the dest (self) lazy stack.
81818185 This allows users to rely on ``update`` to increment lazy stacks progressively.
81828186
8187+ .. note:: Keyword arguments are treated as top-level keys only. Nested structures still work when
8188+ the kwarg value is itself a dict or tensor collection (e.g. ``td.update(outer={"inner": val})``).
8189+ For deeper tuple keys, use the positional dict form: ``td.update({("a", "b"): val})``.
8190+ Keys whose names collide with reserved parameters (``clone``, ``inplace``, ``non_blocking``,
8191+ ``keys_to_update``, ``is_leaf``, ``update_batch_size``, ``ignore_lock``) must be passed through
8192+ the positional dict form.
8193+
81838194 Returns:
81848195 self
81858196
@@ -8193,8 +8204,30 @@ def update(
81938204 >>> other_td = other_td.clone().zero_()
81948205 >>> td.update(other_td)
81958206 >>> assert td['a'] is not other_td['a']
8207+ >>> # keyword form for top-level entries
8208+ >>> td.update(monkey=torch.zeros(3))
8209+ >>> assert (td["monkey"] == 0).all()
81968210
81978211 """
8212+ if kwargs:
8213+ if input_dict_or_td is None:
8214+ input_dict_or_td = kwargs
8215+ elif isinstance(input_dict_or_td, dict):
8216+ input_dict_or_td = {**input_dict_or_td, **kwargs}
8217+ else:
8218+ self.update(
8219+ input_dict_or_td,
8220+ clone=clone,
8221+ inplace=inplace,
8222+ non_blocking=non_blocking,
8223+ keys_to_update=keys_to_update,
8224+ is_leaf=is_leaf,
8225+ update_batch_size=update_batch_size,
8226+ ignore_lock=ignore_lock,
8227+ )
8228+ input_dict_or_td = kwargs
8229+ elif input_dict_or_td is None:
8230+ return self
81988231 batch_size_changed = False
81998232 if input_dict_or_td is self:
82008233 # no op
@@ -8367,19 +8400,20 @@ def update(
83678400
83688401 def update_(
83698402 self,
8370- input_dict_or_td: dict[str, CompatibleType] | T,
8403+ input_dict_or_td: dict[str, CompatibleType] | T | None = None ,
83718404 clone: bool = False,
83728405 *,
83738406 non_blocking: bool = False,
83748407 keys_to_update: Sequence[NestedKey] | None = None,
8408+ **kwargs,
83758409 ) -> Self:
83768410 """Updates the TensorDict in-place with values from either a dictionary or another TensorDict.
83778411
83788412 Unlike :meth:`~.update`, this function will throw an error if the key is unknown to ``self``.
83798413
83808414 Args:
8381- input_dict_or_td (TensorDictBase or dict): input data to be written
8382- in self.
8415+ input_dict_or_td (TensorDictBase or dict, optional ): input data to be written
8416+ in self. If ``None``, only the keyword arguments are used.
83838417 clone (bool, optional): whether the tensors in the input (
83848418 tensor) dict should be cloned before being set. Defaults to ``False``.
83858419
@@ -8391,6 +8425,14 @@ def update_(
83918425 non_blocking (bool, optional): if ``True`` and this copy is between
83928426 different devices, the copy may occur asynchronously with respect
83938427 to the host.
8428+ **kwargs: additional key-value pairs to write into ``self`` as top-level entries. As with the
8429+ positional form, any key that does not already exist in ``self`` raises :class:`KeyError`.
8430+ When both ``input_dict_or_td`` and ``kwargs`` are provided, kwargs win on key conflict.
8431+
8432+ .. note:: Keyword arguments are treated as top-level keys only. Nested structures still work when
8433+ the kwarg value is itself a dict or tensor collection. Keys whose names collide with reserved
8434+ parameters (``clone``, ``non_blocking``, ``keys_to_update``) must be passed through the positional
8435+ dict form.
83948436
83958437 Returns:
83968438 self
@@ -8404,8 +8446,26 @@ def update_(
84048446 >>> assert td['a'] is not other_td['a']
84058447 >>> assert (td['a'] == other_td['a']).all()
84068448 >>> assert (td['a'] == 0).all()
8449+ >>> # keyword form for top-level entries (key must already exist)
8450+ >>> td.update_(a=torch.ones(3))
8451+ >>> assert (td["a"] == 1).all()
84078452
84088453 """
8454+ if kwargs:
8455+ if input_dict_or_td is None:
8456+ input_dict_or_td = kwargs
8457+ elif isinstance(input_dict_or_td, dict):
8458+ input_dict_or_td = {**input_dict_or_td, **kwargs}
8459+ else:
8460+ self.update_(
8461+ input_dict_or_td,
8462+ clone=clone,
8463+ non_blocking=non_blocking,
8464+ keys_to_update=keys_to_update,
8465+ )
8466+ input_dict_or_td = kwargs
8467+ elif input_dict_or_td is None:
8468+ return self
84098469 if input_dict_or_td is self:
84108470 # no op
84118471 return self
0 commit comments