1515import torch
1616
1717from tensordict ._nestedkey import NestedKey
18+ from tensordict ._td import TensorDict
1819from tensordict .base import is_tensor_collection
1920from tensordict .nn .common import dispatch , TensorDictModuleBase
2021from tensordict .nn .distributions import distributions_maps
@@ -811,6 +812,9 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
811812
812813 .. warning:: The behaviour of :attr:`return_composite` will change in v0.9
813814 and default to True from there on.
815+ inplace (bool, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty
816+ :class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the
817+ output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules).
814818
815819 Raises:
816820 ValueError: If the input sequence of modules is empty.
@@ -945,6 +949,8 @@ def __init__(
945949 modules : OrderedDict [str , TensorDictModuleBase | ProbabilisticTensorDictModule ],
946950 partial_tolerant : bool = False ,
947951 return_composite : bool | None = None ,
952+ * ,
953+ inplace : bool | None = None ,
948954 ) -> None : ...
949955
950956 @overload
@@ -953,13 +959,16 @@ def __init__(
953959 modules : List [TensorDictModuleBase | ProbabilisticTensorDictModule ],
954960 partial_tolerant : bool = False ,
955961 return_composite : bool | None = None ,
962+ * ,
963+ inplace : bool | None = None ,
956964 ) -> None : ...
957965
958966 def __init__ (
959967 self ,
960968 * modules : TensorDictModuleBase | ProbabilisticTensorDictModule ,
961969 partial_tolerant : bool = False ,
962970 return_composite : bool | None = None ,
971+ inplace : bool | None = None ,
963972 ) -> None :
964973 if len (modules ) == 0 :
965974 raise ValueError (
@@ -1004,7 +1013,7 @@ def __init__(
10041013 else :
10051014 self .__dict__ ["_det_part" ] = TensorDictSequential (* modules [:- 1 ])
10061015
1007- super ().__init__ (* modules , partial_tolerant = partial_tolerant )
1016+ super ().__init__ (* modules , partial_tolerant = partial_tolerant , inplace = inplace )
10081017 self .return_composite = return_composite
10091018
10101019 def __getitem__ (self , index : int | slice | str ) -> TensorDictModuleBase :
@@ -1319,6 +1328,14 @@ def forward(
13191328 tensordict_exec = self ._last_module (
13201329 tensordict_exec , _requires_sample = self ._requires_sample
13211330 )
1331+
1332+ if self .inplace is True :
1333+ tensordict_out = tensordict
1334+ elif self .inplace is False :
1335+ tensordict_out = TensorDict ()
1336+ elif self .inplace == "empty" :
1337+ tensordict_out = tensordict .empty ()
1338+
13221339 if tensordict_out is not None :
13231340 result = tensordict_out
13241341 result .update (tensordict_exec , keys_to_update = self .out_keys )
@@ -1336,7 +1353,7 @@ def forward(
13361353 ]
13371354 else :
13381355 keys = list (set (self .out_keys + list (tensordict .keys (True , True ))))
1339- return tensordict .update (result , keys_to_update = keys )
1356+ return tensordict_out .update (result , keys_to_update = keys )
13401357 return result
13411358
13421359
0 commit comments