4141from warnings import warn
4242
4343import numpy as np
44- import orjson as json
44+
4545import tensordict as tensordict_lib
4646
4747import torch
5353from tensordict .base import (
5454 _ACCEPTED_CLASSES ,
5555 _GET_DEFAULTS_TO_NONE ,
56+ _is_leaf_nontensor ,
5657 _is_tensor_collection ,
5758 _register_tensor_class ,
5859 CompatibleType ,
7980from torch .multiprocessing import Manager
8081from torch .utils ._pytree import tree_map
8182
83+ try :
84+ import orjson as json
85+ except ImportError :
86+ # Fallback for 3.13
87+ import json
8288try :
8389 from torch .compiler import is_compiling
8490except ImportError : # torch 2.0
@@ -1841,7 +1847,10 @@ def _update(
18411847 non_blocking : bool = False ,
18421848 update_batch_size : bool = False ,
18431849 ignore_lock : bool = False ,
1850+ is_leaf : Callable [[Type ], bool ] | None = None ,
18441851):
1852+ if is_leaf is None :
1853+ is_leaf = _is_leaf_nontensor
18451854 if isinstance (input_dict_or_td , dict ):
18461855 input_dict_or_td = self .from_dict (input_dict_or_td , auto_batch_size = False )
18471856
@@ -1859,6 +1868,7 @@ def _update(
18591868 non_blocking = non_blocking ,
18601869 update_batch_size = update_batch_size ,
18611870 ignore_lock = ignore_lock ,
1871+ is_leaf = is_leaf ,
18621872 )
18631873 self ._non_tensordict .update (non_tensordict )
18641874 return self
@@ -1871,6 +1881,7 @@ def _update(
18711881 non_blocking = non_blocking ,
18721882 update_batch_size = update_batch_size ,
18731883 ignore_lock = ignore_lock ,
1884+ is_leaf = is_leaf ,
18741885 )
18751886 # We also need to remove things from non_tensordict
18761887 if self ._non_tensordict :
0 commit comments