File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -1003,6 +1003,7 @@ def __init__(
10031003 * modules ,
10041004 partial_tolerant = partial_tolerant ,
10051005 return_composite = return_composite ,
1006+ selected_out_keys = selected_out_keys ,
10061007 inplace = inplace ,
10071008 )
10081009 else :
@@ -1033,9 +1034,19 @@ def __init__(
10331034 )
10341035 for m in modules_list
10351036 ):
1036- return TensorDictSequential (
1037- modules , partial_tolerant = partial_tolerant , inplace = inplace
1037+ # No probabilistic modules - initialize as a regular TensorDictSequential
1038+ TensorDictSequential .__init__ (
1039+ self ,
1040+ * modules ,
1041+ partial_tolerant = partial_tolerant ,
1042+ selected_out_keys = selected_out_keys ,
1043+ inplace = inplace ,
10381044 )
1045+ self ._requires_sample = False
1046+ self .__dict__ ["_det_part" ] = None
1047+ # Use return_composite=True so forward() just iterates through modules
1048+ self .return_composite = True
1049+ return
10391050
10401051 # if the modules not including the final probabilistic module return the sampled
10411052 # key we won't be sampling it again, in that case
You can’t perform that action at this time.
0 commit comments