Skip to content

Commit c1e6896

Browse files
author
Vincent Moens
committed
Update (base update)
[ghstack-poisoned]
1 parent 58ccbf5 commit c1e6896

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

tensordict/nn/probabilistic.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616

1717
from tensordict._nestedkey import NestedKey
18+
from tensordict._td import TensorDict
1819
from tensordict.base import is_tensor_collection
1920
from tensordict.nn.common import dispatch, TensorDictModuleBase
2021
from tensordict.nn.distributions import distributions_maps
@@ -811,6 +812,9 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
811812
812813
.. warning:: The behaviour of :attr:`return_composite` will change in v0.9
813814
and default to True from there on.
815+
inplace (bool, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty
816+
:class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the
817+
output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules).
814818
815819
Raises:
816820
ValueError: If the input sequence of modules is empty.
@@ -945,6 +949,8 @@ def __init__(
945949
modules: OrderedDict[str, TensorDictModuleBase | ProbabilisticTensorDictModule],
946950
partial_tolerant: bool = False,
947951
return_composite: bool | None = None,
952+
*,
953+
inplace: bool | None = None,
948954
) -> None: ...
949955

950956
@overload
@@ -953,13 +959,16 @@ def __init__(
953959
modules: List[TensorDictModuleBase | ProbabilisticTensorDictModule],
954960
partial_tolerant: bool = False,
955961
return_composite: bool | None = None,
962+
*,
963+
inplace: bool | None = None,
956964
) -> None: ...
957965

958966
def __init__(
959967
self,
960968
*modules: TensorDictModuleBase | ProbabilisticTensorDictModule,
961969
partial_tolerant: bool = False,
962970
return_composite: bool | None = None,
971+
inplace: bool | None = None,
963972
) -> None:
964973
if len(modules) == 0:
965974
raise ValueError(
@@ -1004,7 +1013,7 @@ def __init__(
10041013
else:
10051014
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])
10061015

1007-
super().__init__(*modules, partial_tolerant=partial_tolerant)
1016+
super().__init__(*modules, partial_tolerant=partial_tolerant, inplace=inplace)
10081017
self.return_composite = return_composite
10091018

10101019
def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase:
@@ -1319,6 +1328,14 @@ def forward(
13191328
tensordict_exec = self._last_module(
13201329
tensordict_exec, _requires_sample=self._requires_sample
13211330
)
1331+
1332+
if self.inplace is True:
1333+
tensordict_out = tensordict
1334+
elif self.inplace is False:
1335+
tensordict_out = TensorDict()
1336+
elif self.inplace == "empty":
1337+
tensordict_out = tensordict.empty()
1338+
13221339
if tensordict_out is not None:
13231340
result = tensordict_out
13241341
result.update(tensordict_exec, keys_to_update=self.out_keys)
@@ -1336,7 +1353,7 @@ def forward(
13361353
]
13371354
else:
13381355
keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
1339-
return tensordict.update(result, keys_to_update=keys)
1356+
return tensordict_out.update(result, keys_to_update=keys)
13401357
return result
13411358

13421359

test/test_nn.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
)
5757

5858
from torch import distributions, nn
59-
from torch.distributions import Normal
59+
from torch.distributions import Categorical, Normal
6060
from torch.utils._pytree import tree_map
6161

6262
try:
@@ -2127,6 +2127,50 @@ def test_module_buffer():
21272127

21282128

21292129
class TestProbabilisticTensorDictModule:
2130+
@set_composite_lp_aggregate(False)
2131+
@pytest.mark.parametrize("inplace", [True, False, None])
2132+
@pytest.mark.parametrize("module_inplace", [True, False])
2133+
def test_tdprobseq_inplace(self, inplace, module_inplace):
2134+
model = ProbabilisticTensorDictSequential(
2135+
TensorDictModule(
2136+
lambda x: (x + 1, x - 1),
2137+
in_keys=["input"],
2138+
out_keys=[("intermediate", "0"), ("intermediate", "1")],
2139+
inplace=module_inplace,
2140+
),
2141+
TensorDictModule(
2142+
lambda y0, y1: y0 * y1,
2143+
in_keys=[("intermediate", "0"), ("intermediate", "1")],
2144+
out_keys=["output"],
2145+
inplace=module_inplace,
2146+
),
2147+
ProbabilisticTensorDictModule(
2148+
in_keys={"logits": "output"},
2149+
out_keys=["sample"],
2150+
return_log_prob=True,
2151+
distribution_class=Categorical,
2152+
),
2153+
inplace=inplace,
2154+
)
2155+
input = TensorDict(input=torch.zeros((5,)))
2156+
output = model(input)
2157+
assert "sample_log_prob" in output
2158+
assert "sample" in output
2159+
if inplace:
2160+
assert output is input
2161+
assert "input" in output
2162+
else:
2163+
if not module_inplace or inplace is False:
2164+
# In this case, inplace=False and inplace=None have the same behavior
2165+
assert output is not input, (module_inplace, inplace)
2166+
assert "input" not in output, (module_inplace, inplace)
2167+
else:
2168+
# In this case, inplace=False and inplace=None have the same behavior
2169+
assert output is input, (module_inplace, inplace)
2170+
assert "input" in output, (module_inplace, inplace)
2171+
2172+
assert "output" in output
2173+
21302174
@pytest.mark.parametrize("return_log_prob", [True, False])
21312175
@set_composite_lp_aggregate(False)
21322176
def test_probabilistic_n_samples(self, return_log_prob):

0 commit comments

Comments
 (0)