@@ -218,7 +218,10 @@ def __init__(
218218 ** {key : val for key , val in _zip_strict (modules [0 ], modules_vals )}
219219 )
220220 super ().__init__ (
221- module = nn .ModuleDict (modules ), in_keys = in_keys , out_keys = out_keys
221+ module = nn .ModuleDict (modules ),
222+ in_keys = in_keys ,
223+ out_keys = out_keys ,
224+ inplace = inplace ,
222225 )
223226 elif len (modules ) == 1 and isinstance (
224227 modules [0 ], collections .abc .MutableSequence
@@ -227,20 +230,27 @@ def __init__(
227230 in_keys , out_keys = self ._compute_in_and_out_keys (modules )
228231 self ._complete_out_keys = list (out_keys )
229232 super ().__init__ (
230- module = nn .ModuleList (modules ), in_keys = in_keys , out_keys = out_keys
233+ module = nn .ModuleList (modules ),
234+ in_keys = in_keys ,
235+ out_keys = out_keys ,
236+ inplace = inplace ,
231237 )
232238 elif len (modules ) == 1 and isinstance (modules [0 ], dict ):
233239 return self .__init__ (
234240 collections .OrderedDict (modules [0 ]),
235241 partial_tolerant = partial_tolerant ,
236242 selected_out_keys = selected_out_keys ,
243+ inplace = inplace ,
237244 )
238245 else :
239246 modules = self ._convert_modules (modules )
240247 in_keys , out_keys = self ._compute_in_and_out_keys (modules )
241248 self ._complete_out_keys = list (out_keys )
242249 super ().__init__ (
243- module = nn .ModuleList (list (modules )), in_keys = in_keys , out_keys = out_keys
250+ module = nn .ModuleList (list (modules )),
251+ in_keys = in_keys ,
252+ out_keys = out_keys ,
253+ inplace = inplace ,
244254 )
245255
246256 self .inplace = inplace
@@ -628,6 +638,7 @@ def forward(
628638 )
629639 if tensordict_out is not None :
630640 result = tensordict_out
641+ print ('here! update' )
631642 result .update (tensordict_exec , keys_to_update = self .out_keys )
632643 else :
633644 result = tensordict_exec
0 commit comments