diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py index ccf28cad4..adb417ad1 100644 --- a/tensordict/nn/probabilistic.py +++ b/tensordict/nn/probabilistic.py @@ -820,6 +820,8 @@ class ProbabilisticTensorDictSequential(TensorDictSequential): Otherwise, only the last module will be used to build the distribution. Defaults to ``True`` whenever there are more than one probabilistic modules or the last module is not probabilistic. Errors if `return_composite` is `False` and the neither of the above conditions are met. + selected_out_keys (iterable of NestedKeys, optional): the list of out-keys to select. If not provided, all + ``out_keys`` will be written. inplace (bool, optional): if `True`, the input tensordict is modified in-place. If `False`, a new empty :class:`~tensordict.TensorDict` instance is created. If `"empty"`, `input.empty()` is used instead (ie, the output preserves type, device and batch-size). Defaults to `None` (relies on sub-modules). @@ -958,6 +960,7 @@ def __init__( partial_tolerant: bool = False, return_composite: bool | None = None, *, + selected_out_keys: List[NestedKey] | None = None, inplace: bool | None = None, ) -> None: ... @@ -968,6 +971,7 @@ def __init__( partial_tolerant: bool = False, return_composite: bool | None = None, *, + selected_out_keys: List[NestedKey] | None = None, inplace: bool | None = None, ) -> None: ... @@ -976,6 +980,7 @@ def __init__( *modules: TensorDictModuleBase | ProbabilisticTensorDictModule, partial_tolerant: bool = False, return_composite: bool | None = None, + selected_out_keys: List[NestedKey] | None = None, inplace: bool | None = None, ) -> None: if len(modules) == 0: @@ -998,6 +1003,7 @@ def __init__( *modules, partial_tolerant=partial_tolerant, return_composite=return_composite, + selected_out_keys=selected_out_keys, inplace=inplace, ) else: @@ -1028,9 +1034,19 @@ def __init__( ) for m in modules_list ): - return TensorDictSequential( - modules, partial_tolerant=partial_tolerant, inplace=inplace + # No probabilistic modules - initialize as a regular TensorDictSequential + TensorDictSequential.__init__( + self, + *modules, + partial_tolerant=partial_tolerant, + selected_out_keys=selected_out_keys, + inplace=inplace, ) + self._requires_sample = False + self.__dict__["_det_part"] = None + # Use return_composite=True so forward() just iterates through modules + self.return_composite = True + return # if the modules not including the final probabilistic module return the sampled # key we won't be sampling it again, in that case @@ -1048,7 +1064,12 @@ def __init__( else: self.__dict__["_det_part"] = TensorDictSequential(*modules[:-1]) - super().__init__(*modules, partial_tolerant=partial_tolerant, inplace=inplace) + super().__init__( + *modules, + partial_tolerant=partial_tolerant, + selected_out_keys=selected_out_keys, + inplace=inplace, + ) self.return_composite = return_composite def __getitem__(self, index: int | slice | str) -> Self | TensorDictModuleBase: diff --git a/test/test_nn.py b/test/test_nn.py index 4260624cb..a8ed4efc3 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1256,6 +1256,98 @@ def test_key_exclusion_constructor_exec(self): assert "foo1" not in out assert out["key2"] == 1 + def test_prob_key_exclusion_constructor(self): + module1 = TensorDictModule( + nn.Linear(3, 4), in_keys=["key1", "key2"], out_keys=["foo1"] + ) + module2 = TensorDictModule( + nn.Linear(3, 4), in_keys=["key1", "key3"], out_keys=["key1"] + ) + prob_module = ProbabilisticTensorDictModule( + in_keys=["foo1", "key3"], + out_keys=["key2"], + distribution_class=Normal, + ) + seq = TensorDictSequential( + module1, module2, prob_module, selected_out_keys=["key2"] + ) + assert set(seq.in_keys) == set(unravel_key_list(("key1", "key2", "key3"))) + assert seq.out_keys == ["key2"] + + def test_prob_key_exclusion_constructor_exec(self): + module1 = TensorDictModule( + lambda x, y: x + y, in_keys=["key1", "key2"], out_keys=["foo1"] + ) + module2 = TensorDictModule( + lambda x, y: x + y, in_keys=["key1", "key3"], out_keys=["key1"] + ) + prob_module = ProbabilisticTensorDictModule( + in_keys={"loc": "foo1", "scale": "key3"}, + out_keys=["key2"], + distribution_class=Normal, + ) + seq = ProbabilisticTensorDictSequential( + module1, module2, prob_module, selected_out_keys=["key2"] + ) + assert set(seq.in_keys) == set(unravel_key_list(("key1", "key2", "key3"))) + assert seq.out_keys == ["key2"] + td = TensorDict(key1=0, key2=0, key3=1) + out = seq(td) + assert out is td + assert "key1" in out + assert "key2" in out + assert "key3" in out + assert "foo1" not in out + + def test_prob_seq_no_prob_modules(self): + """Test ProbabilisticTensorDictSequential with no probabilistic modules. + + When no probabilistic modules are passed, it should initialize as a regular + TensorDictSequential and still work correctly, including with selected_out_keys. + """ + module1 = TensorDictModule( + lambda x, y: x + y, in_keys=["key1", "key2"], out_keys=["foo1"] + ) + module2 = TensorDictModule( + lambda x, y: x + y, in_keys=["foo1", "key3"], out_keys=["result"] + ) + # No probabilistic modules - should still work + seq = ProbabilisticTensorDictSequential(module1, module2) + assert set(seq.in_keys) == {"key1", "key2", "key3"} + assert set(seq.out_keys) == {"foo1", "result"} + + # Verify internal attributes are set correctly for no-prob-modules case + assert seq._requires_sample is False + assert seq._det_part is None + assert seq.return_composite is True + + td = TensorDict(key1=1, key2=2, key3=3) + out = seq(td) + assert out is td + assert out["foo1"] == 3 # 1 + 2 + assert out["result"] == 6 # 3 + 3 + + def test_prob_seq_no_prob_modules_selected_out_keys(self): + """Test ProbabilisticTensorDictSequential with no prob modules and selected_out_keys.""" + module1 = TensorDictModule( + lambda x, y: x + y, in_keys=["key1", "key2"], out_keys=["foo1"] + ) + module2 = TensorDictModule( + lambda x, y: x + y, in_keys=["foo1", "key3"], out_keys=["result"] + ) + # No probabilistic modules, but with selected_out_keys + seq = ProbabilisticTensorDictSequential( + module1, module2, selected_out_keys=["result"] + ) + assert set(seq.in_keys) == {"key1", "key2", "key3"} + assert seq.out_keys == ["result"] + + td = TensorDict(key1=1, key2=2, key3=3) + out = seq(td) + assert out is td + assert "foo1" not in out # Should be excluded + assert out["result"] == 6 # 3 + 3 + @pytest.mark.parametrize("lazy", [True, False]) def test_stateful(self, lazy): torch.manual_seed(0)