Skip to content

Commit fdf28f6

Browse files
committed
fixes
1 parent 06f41e2 commit fdf28f6

1 file changed

Lines changed: 13 additions & 2 deletions

File tree

tensordict/nn/probabilistic.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,7 @@ def __init__(
10031003
*modules,
10041004
partial_tolerant=partial_tolerant,
10051005
return_composite=return_composite,
1006+
selected_out_keys=selected_out_keys,
10061007
inplace=inplace,
10071008
)
10081009
else:
@@ -1033,9 +1034,19 @@ def __init__(
10331034
)
10341035
for m in modules_list
10351036
):
1036-
return TensorDictSequential(
1037-
modules, partial_tolerant=partial_tolerant, inplace=inplace
1037+
# No probabilistic modules - initialize as a regular TensorDictSequential
1038+
TensorDictSequential.__init__(
1039+
self,
1040+
*modules,
1041+
partial_tolerant=partial_tolerant,
1042+
selected_out_keys=selected_out_keys,
1043+
inplace=inplace,
10381044
)
1045+
self._requires_sample = False
1046+
self.__dict__["_det_part"] = None
1047+
# Use return_composite=True so forward() just iterates through modules
1048+
self.return_composite = True
1049+
return
10391050

10401051
# if the modules not including the final probabilistic module return the sampled
10411052
# key we won't be sampling it again, in that case

0 commit comments

Comments
 (0)