Skip to content

Commit b0c79fb

Browse files
author
Vincent Moens
committed
[BugFix] Fix tensorclass update
ghstack-source-id: 3f50604 Pull Request resolved: #1255
1 parent 55fab2a commit b0c79fb

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

tensordict/nn/sequence.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,10 @@ def __init__(
218218
**{key: val for key, val in _zip_strict(modules[0], modules_vals)}
219219
)
220220
super().__init__(
221-
module=nn.ModuleDict(modules), in_keys=in_keys, out_keys=out_keys
221+
module=nn.ModuleDict(modules),
222+
in_keys=in_keys,
223+
out_keys=out_keys,
224+
inplace=inplace,
222225
)
223226
elif len(modules) == 1 and isinstance(
224227
modules[0], collections.abc.MutableSequence
@@ -227,20 +230,27 @@ def __init__(
227230
in_keys, out_keys = self._compute_in_and_out_keys(modules)
228231
self._complete_out_keys = list(out_keys)
229232
super().__init__(
230-
module=nn.ModuleList(modules), in_keys=in_keys, out_keys=out_keys
233+
module=nn.ModuleList(modules),
234+
in_keys=in_keys,
235+
out_keys=out_keys,
236+
inplace=inplace,
231237
)
232238
elif len(modules) == 1 and isinstance(modules[0], dict):
233239
return self.__init__(
234240
collections.OrderedDict(modules[0]),
235241
partial_tolerant=partial_tolerant,
236242
selected_out_keys=selected_out_keys,
243+
inplace=inplace,
237244
)
238245
else:
239246
modules = self._convert_modules(modules)
240247
in_keys, out_keys = self._compute_in_and_out_keys(modules)
241248
self._complete_out_keys = list(out_keys)
242249
super().__init__(
243-
module=nn.ModuleList(list(modules)), in_keys=in_keys, out_keys=out_keys
250+
module=nn.ModuleList(list(modules)),
251+
in_keys=in_keys,
252+
out_keys=out_keys,
253+
inplace=inplace,
244254
)
245255

246256
self.inplace = inplace
@@ -628,6 +638,7 @@ def forward(
628638
)
629639
if tensordict_out is not None:
630640
result = tensordict_out
641+
print('here! update')
631642
result.update(tensordict_exec, keys_to_update=self.out_keys)
632643
else:
633644
result = tensordict_exec

tensordict/tensorclass.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,12 @@ def _update(
16941694
update_batch_size=update_batch_size,
16951695
ignore_lock=ignore_lock,
16961696
)
1697+
# We also need to remove things from non_tensordict
1698+
if self._non_tensordict:
1699+
keys = set(self._tensordict.keys())
1700+
ntd = {k: val for k, val in self._non_tensordict.items() if k not in keys}
1701+
self._non_tensordict.clear()
1702+
self._non_tensordict.update(ntd)
16971703
return self
16981704

16991705

0 commit comments

Comments
 (0)