Skip to content

Commit c61d045

Browse files
author
Vincent Moens
committed
[BugFix] Better prop of args in update
ghstack-source-id: 38859ac Pull Request resolved: #1290
1 parent 2df09c9 commit c61d045

6 files changed

Lines changed: 52 additions & 17 deletions

File tree

tensordict/_lazy.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,7 @@
3030

3131
import numpy as np
3232

33-
import orjson as json
3433
import torch
35-
36-
from tensordict.memmap import MemoryMappedTensor
37-
from torch.nn.utils.rnn import pad_sequence
38-
39-
try:
40-
from functorch import dim as ftdim
41-
42-
_has_funcdim = True
43-
except ImportError:
44-
from tensordict.utils import _ftdim_mock as ftdim
45-
46-
_has_funcdim = False
4734
from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
4835
from tensordict.base import (
4936
_is_leaf_nontensor,
@@ -58,6 +45,8 @@
5845
T,
5946
TensorDictBase,
6047
)
48+
49+
from tensordict.memmap import MemoryMappedTensor
6150
from tensordict.utils import (
6251
_as_context_manager,
6352
_broadcast_tensors,
@@ -90,7 +79,22 @@
9079
unravel_key_list,
9180
)
9281
from torch import Tensor
82+
from torch.nn.utils.rnn import pad_sequence
9383

84+
try:
85+
import orjson as json
86+
except ImportError:
87+
# Fallback
88+
import json
89+
90+
try:
91+
from functorch import dim as ftdim
92+
93+
_has_funcdim = True
94+
except ImportError:
95+
from tensordict.utils import _ftdim_mock as ftdim
96+
97+
_has_funcdim = False
9498

9599
_has_functorch = False
96100
try:
@@ -2943,6 +2947,7 @@ def update(
29432947
keys_to_update=keys_to_update,
29442948
non_blocking=non_blocking,
29452949
is_leaf=is_leaf,
2950+
update_batch_size=update_batch_size,
29462951
**kwargs,
29472952
)
29482953
return self

tensordict/_td.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from warnings import warn
1919

2020
import numpy as np
21-
import orjson as json
21+
2222
import torch
2323
from tensordict._nestedkey import NestedKey
2424

@@ -91,6 +91,11 @@
9191
from torch.nn.utils._named_member_accessor import swap_tensor
9292
from torch.utils._pytree import tree_map
9393

94+
try:
95+
import orjson as json
96+
except ImportError:
97+
# Fallback
98+
import json
9499
try:
95100
from functorch import dim as ftdim
96101

tensordict/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
)
4747

4848
import numpy as np
49-
import orjson as json
49+
5050
import torch
5151

5252
from tensordict._contextlib import LAST_OP_MAPS
@@ -110,6 +110,12 @@
110110
from torch.nn.parameter import Parameter, UninitializedTensorMixin
111111
from torch.utils._pytree import tree_map
112112

113+
try:
114+
import orjson as json
115+
except ImportError:
116+
# Fallback for 3.13
117+
import json
118+
113119
try:
114120
from torch.compiler import is_compiling
115121
except ImportError: # torch 2.0

tensordict/persistent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
import numpy as np
2020

21-
import orjson as json
2221
import torch
2322

2423
from tensordict._td import (
@@ -57,6 +56,12 @@
5756
)
5857
from torch import multiprocessing as mp
5958

59+
try:
60+
import orjson as json
61+
except ImportError:
62+
# Fallback for 3.13
63+
import json
64+
6065
_has_h5 = importlib.util.find_spec("h5py", None) is not None
6166

6267

tensordict/tensorclass.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from warnings import warn
4242

4343
import numpy as np
44-
import orjson as json
44+
4545
import tensordict as tensordict_lib
4646

4747
import torch
@@ -53,6 +53,7 @@
5353
from tensordict.base import (
5454
_ACCEPTED_CLASSES,
5555
_GET_DEFAULTS_TO_NONE,
56+
_is_leaf_nontensor,
5657
_is_tensor_collection,
5758
_register_tensor_class,
5859
CompatibleType,
@@ -79,6 +80,11 @@
7980
from torch.multiprocessing import Manager
8081
from torch.utils._pytree import tree_map
8182

83+
try:
84+
import orjson as json
85+
except ImportError:
86+
# Fallback for 3.13
87+
import json
8288
try:
8389
from torch.compiler import is_compiling
8490
except ImportError: # torch 2.0
@@ -1841,7 +1847,10 @@ def _update(
18411847
non_blocking: bool = False,
18421848
update_batch_size: bool = False,
18431849
ignore_lock: bool = False,
1850+
is_leaf: Callable[[Type], bool] | None = None,
18441851
):
1852+
if is_leaf is None:
1853+
is_leaf = _is_leaf_nontensor
18451854
if isinstance(input_dict_or_td, dict):
18461855
input_dict_or_td = self.from_dict(input_dict_or_td, auto_batch_size=False)
18471856

@@ -1859,6 +1868,7 @@ def _update(
18591868
non_blocking=non_blocking,
18601869
update_batch_size=update_batch_size,
18611870
ignore_lock=ignore_lock,
1871+
is_leaf=is_leaf,
18621872
)
18631873
self._non_tensordict.update(non_tensordict)
18641874
return self
@@ -1871,6 +1881,7 @@ def _update(
18711881
non_blocking=non_blocking,
18721882
update_batch_size=update_batch_size,
18731883
ignore_lock=ignore_lock,
1884+
is_leaf=is_leaf,
18741885
)
18751886
# We also need to remove things from non_tensordict
18761887
if self._non_tensordict:

tensordict/tensorclass.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,3 +1119,6 @@ class TensorClass:
11191119
def uint32(self) -> T: ...
11201120
def uint64(self) -> T: ...
11211121
def uint8(self) -> T: ...
1122+
1123+
class NonTensorData(TensorClass): ...
1124+
class NonTensorStack(TensorDictBase): ...

0 commit comments

Comments
 (0)