Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,8 @@ class ProbabilisticTensorDictSequential(TensorDictSequential):
Otherwise, only the last module will be used to build the distribution.
Defaults to ``True`` whenever there are more than one probabilistic modules or the last module is not probabilistic.
Errors if `return_composite` is `False` and the neither of the above conditions are met.
selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all
``out_keys`` will be written.
inplace (bool, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty
:class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the
output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules).
Expand Down Expand Up @@ -958,6 +960,7 @@ def __init__(
partial_tolerant: bool = False,
return_composite: bool | None = None,
*,
selected_out_keys: List[NestedKey] | None = None,
inplace: bool | None = None,
) -> None: ...

Expand All @@ -968,6 +971,7 @@ def __init__(
partial_tolerant: bool = False,
return_composite: bool | None = None,
*,
selected_out_keys: List[NestedKey] | None = None,
inplace: bool | None = None,
) -> None: ...

Expand All @@ -976,6 +980,7 @@ def __init__(
*modules: TensorDictModuleBase | ProbabilisticTensorDictModule,
partial_tolerant: bool = False,
return_composite: bool | None = None,
selected_out_keys: List[NestedKey] | None = None,
inplace: bool | None = None,
) -> None:
if len(modules) == 0:
Expand All @@ -998,6 +1003,7 @@ def __init__(
*modules,
partial_tolerant=partial_tolerant,
return_composite=return_composite,
selected_out_keys=selected_out_keys,
inplace=inplace,
)
else:
Expand Down Expand Up @@ -1028,9 +1034,19 @@ def __init__(
)
for m in modules_list
):
return TensorDictSequential(
modules, partial_tolerant=partial_tolerant, inplace=inplace
# No probabilistic modules - initialize as a regular TensorDictSequential
TensorDictSequential.__init__(
self,
*modules,
partial_tolerant=partial_tolerant,
selected_out_keys=selected_out_keys,
inplace=inplace,
)
self._requires_sample = False
self.__dict__["_det_part"] = None
# Use return_composite=True so forward() just iterates through modules
self.return_composite = True
return

# if the modules not including the final probabilistic module return the sampled
# key we won't be sampling it again, in that case
Expand All @@ -1048,7 +1064,12 @@ def __init__(
else:
self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1])

super().__init__(*modules, partial_tolerant=partial_tolerant, inplace=inplace)
super().__init__(
*modules,
partial_tolerant=partial_tolerant,
selected_out_keys=selected_out_keys,
inplace=inplace,
)
self.return_composite = return_composite

def __getitem__(self, index: int | slice | str) -> Self | TensorDictModuleBase:
Expand Down
92 changes: 92 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,98 @@ def test_key_exclusion_constructor_exec(self):
assert "foo1" not in out
assert out["key2"] == 1

def test_prob_key_exclusion_constructor(self):
module1 = TensorDictModule(
nn.Linear(3, 4), in_keys=["key1", "key2"], out_keys=["foo1"]
)
module2 = TensorDictModule(
nn.Linear(3, 4), in_keys=["key1", "key3"], out_keys=["key1"]
)
prob_module = ProbabilisticTensorDictModule(
in_keys=["foo1", "key3"],
out_keys=["key2"],
distribution_class=Normal,
)
seq = TensorDictSequential(
module1, module2, prob_module, selected_out_keys=["key2"]
)
assert set(seq.in_keys) == set(unravel_key_list(("key1", "key2", "key3")))
assert seq.out_keys == ["key2"]

def test_prob_key_exclusion_constructor_exec(self):
module1 = TensorDictModule(
lambda x, y: x + y, in_keys=["key1", "key2"], out_keys=["foo1"]
)
module2 = TensorDictModule(
lambda x, y: x + y, in_keys=["key1", "key3"], out_keys=["key1"]
)
prob_module = ProbabilisticTensorDictModule(
in_keys={"loc": "foo1", "scale": "key3"},
out_keys=["key2"],
distribution_class=Normal,
)
seq = ProbabilisticTensorDictSequential(
module1, module2, prob_module, selected_out_keys=["key2"]
)
assert set(seq.in_keys) == set(unravel_key_list(("key1", "key2", "key3")))
assert seq.out_keys == ["key2"]
td = TensorDict(key1=0, key2=0, key3=1)
out = seq(td)
assert out is td
assert "key1" in out
assert "key2" in out
assert "key3" in out
assert "foo1" not in out

def test_prob_seq_no_prob_modules(self):
"""Test ProbabilisticTensorDictSequential with no probabilistic modules.

When no probabilistic modules are passed, it should initialize as a regular
TensorDictSequential and still work correctly, including with selected_out_keys.
"""
module1 = TensorDictModule(
lambda x, y: x + y, in_keys=["key1", "key2"], out_keys=["foo1"]
)
module2 = TensorDictModule(
lambda x, y: x + y, in_keys=["foo1", "key3"], out_keys=["result"]
)
# No probabilistic modules - should still work
seq = ProbabilisticTensorDictSequential(module1, module2)
assert set(seq.in_keys) == {"key1", "key2", "key3"}
assert set(seq.out_keys) == {"foo1", "result"}

# Verify internal attributes are set correctly for no-prob-modules case
assert seq._requires_sample is False
assert seq._det_part is None
assert seq.return_composite is True

td = TensorDict(key1=1, key2=2, key3=3)
out = seq(td)
assert out is td
assert out["foo1"] == 3 # 1 + 2
assert out["result"] == 6 # 3 + 3

def test_prob_seq_no_prob_modules_selected_out_keys(self):
"""Test ProbabilisticTensorDictSequential with no prob modules and selected_out_keys."""
module1 = TensorDictModule(
lambda x, y: x + y, in_keys=["key1", "key2"], out_keys=["foo1"]
)
module2 = TensorDictModule(
lambda x, y: x + y, in_keys=["foo1", "key3"], out_keys=["result"]
)
# No probabilistic modules, but with selected_out_keys
seq = ProbabilisticTensorDictSequential(
module1, module2, selected_out_keys=["result"]
)
assert set(seq.in_keys) == {"key1", "key2", "key3"}
assert seq.out_keys == ["result"]

td = TensorDict(key1=1, key2=2, key3=3)
out = seq(td)
assert out is td
assert "foo1" not in out # Should be excluded
assert out["result"] == 6 # 3 + 3

@pytest.mark.parametrize("lazy", [True, False])
def test_stateful(self, lazy):
torch.manual_seed(0)
Expand Down
Loading