@@ -926,8 +926,7 @@ def all(self, dim: int | None = None) -> bool | TensorCollection:
 
         names = None
         if self._has_names():
-            names = copy(self.names)
-            names = [name for i, name in enumerate(names) if i != dim]
+            names = [name for i, name in enumerate(self.names) if i != dim]
 
         return TensorDict(
             source={key: value.all(dim=dim) for key, value in self.items()},
@@ -948,8 +947,7 @@ def any(self, dim: int | None = None) -> bool | TensorCollection:
 
         names = None
         if self._has_names():
-            names = copy(self.names)
-            names = [name for i, name in enumerate(names) if i != dim]
+            names = [name for i, name in enumerate(self.names) if i != dim]
 
         return TensorDict(
             source={key: value.any(dim=dim) for key, value in self.items()},
@@ -1071,7 +1069,7 @@ def reduction(val):
             return result
 
         if self._has_names():
-            names = copy(self.names)
+            names = list(self.names)
         else:
             names = None
         if not call_on_nested:
@@ -1088,11 +1086,10 @@ def reduction(val):
         elif dim is not NO_DEFAULT or keepdim:
             names = None
             if self._has_names():
-                names = copy(self.names)
                 if not keepdim and isinstance(dim, tuple):
-                    names = [name for i, name in enumerate(names) if i not in dim]
+                    names = [name for i, name in enumerate(self.names) if i not in dim]
                 else:
-                    names = [name for i, name in enumerate(names) if i != dim]
+                    names = [name for i, name in enumerate(self.names) if i != dim]
             if dim is not NO_DEFAULT:
                 kwargs["dim"] = dim
             if keepdim is not NO_DEFAULT:
@@ -1745,8 +1742,7 @@ def _unbind(self, dim: int):
         batch_size = torch.Size([s for i, s in enumerate(self.batch_size) if i != dim])
         names = None
         if self._has_names():
-            names = copy(self.names)
-            names = [name for i, name in enumerate(names) if i != dim]
+            names = [name for i, name in enumerate(self.names) if i != dim]
         # We could use any() but dynamo doesn't like generators
         for name in names:
             if name is not None:
@@ -2072,7 +2068,7 @@ def _permute(tensor):
     def _squeeze(self, dim=None):
         batch_size = self.batch_size
         if dim is None:
-            names = copy(self.names) if self._has_names() else None
+            names = list(self.names) if self._has_names() else None
             if names is not None:
                 batch_size, names = _zip_strict(
                     *[
@@ -2114,7 +2110,7 @@ def _squeeze(tensor):
             batch_size = list(batch_size)
             batch_size.pop(dim)
             batch_size = list(batch_size)
-            names = copy(self.names) if self._has_names() else None
+            names = list(self.names) if self._has_names() else None
             if names:
                 names.pop(dim)
 
@@ -2149,7 +2145,7 @@ def _unsqueeze(self, dim: int):
         batch_size.insert(newdim, 1)
         batch_size = torch.Size(batch_size)
 
-        names = copy(self.names) if self._has_names() else None
+        names = list(self.names) if self._has_names() else None
         if names:
             names.insert(newdim, None)
 
@@ -2243,7 +2239,10 @@ def from_dict_instance(
         from tensordict import TensorDict
 
         batch_size_set = torch.Size(()) if batch_size is None else batch_size
-        input_dict = copy(input_dict)
+        if is_compiling():
+            input_dict = type(input_dict)(input_dict)
+        else:
+            input_dict = copy(input_dict)
         for key, value in list(input_dict.items()):
             if isinstance(value, (dict,)):
                 cur_value = self.get(key)
0 commit comments