Commit bb7d548

[Feature] More copy() refactors for compile friendliness (#1516)
1 parent e3daeb7 commit bb7d548

4 files changed

Lines changed: 69 additions & 22 deletions
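The pattern across all four files: shallow copies previously made with copy.copy() become direct constructor calls (list(), dict(), or type(x)(x)), which torch.compile/Dynamo traces more readily than copy()'s dispatch machinery. A minimal sketch of the before/after shape, illustrative rather than the library code itself:

    from copy import copy

    names = ["batch", None, "time"]

    # Before: copy() routes through copy.copy's dispatch, which can
    # trip up torch.compile's tracer.
    shallow = copy(names)

    # After: a plain builtin constructor that traces cleanly.
    shallow = list(names)
    assert shallow == names and shallow is not names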

tensordict/_td.py

Lines changed: 13 additions & 14 deletions
@@ -926,8 +926,7 @@ def all(self, dim: int | None = None) -> bool | TensorCollection:
 
         names = None
         if self._has_names():
-            names = copy(self.names)
-            names = [name for i, name in enumerate(names) if i != dim]
+            names = [name for i, name in enumerate(self.names) if i != dim]
 
         return TensorDict(
             source={key: value.all(dim=dim) for key, value in self.items()},
@@ -948,8 +947,7 @@ def any(self, dim: int | None = None) -> bool | TensorCollection:
 
         names = None
         if self._has_names():
-            names = copy(self.names)
-            names = [name for i, name in enumerate(names) if i != dim]
+            names = [name for i, name in enumerate(self.names) if i != dim]
 
         return TensorDict(
             source={key: value.any(dim=dim) for key, value in self.items()},
@@ -1071,7 +1069,7 @@ def reduction(val):
             return result
 
         if self._has_names():
-            names = copy(self.names)
+            names = list(self.names)
         else:
             names = None
         if not call_on_nested:
@@ -1088,11 +1086,10 @@ def reduction(val):
         elif dim is not NO_DEFAULT or keepdim:
             names = None
             if self._has_names():
-                names = copy(self.names)
                 if not keepdim and isinstance(dim, tuple):
-                    names = [name for i, name in enumerate(names) if i not in dim]
+                    names = [name for i, name in enumerate(self.names) if i not in dim]
                 else:
-                    names = [name for i, name in enumerate(names) if i != dim]
+                    names = [name for i, name in enumerate(self.names) if i != dim]
             if dim is not NO_DEFAULT:
                 kwargs["dim"] = dim
             if keepdim is not NO_DEFAULT:
@@ -1745,8 +1742,7 @@ def _unbind(self, dim: int):
         batch_size = torch.Size([s for i, s in enumerate(self.batch_size) if i != dim])
         names = None
         if self._has_names():
-            names = copy(self.names)
-            names = [name for i, name in enumerate(names) if i != dim]
+            names = [name for i, name in enumerate(self.names) if i != dim]
             # We could use any() but dynamo doesn't like generators
             for name in names:
                 if name is not None:
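The retained comment in the hunk above points at a related Dynamo constraint: any() over a generator expression hands Dynamo a generator object, which it has historically struggled to trace in fullgraph mode, so the code spells out an explicit loop instead. A hedged illustration of the two equivalent spellings:

    names = [None, "time", None]

    # Generator form: concise, but a potential graph break under torch.compile.
    has_name = any(name is not None for name in names)

    # Explicit-loop form, as favored in the diff: same result, Dynamo-friendly.
    has_name = False
    for name in names:
        if name is not None:
            has_name = True
            break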
@@ -2072,7 +2068,7 @@ def _permute(tensor):
     def _squeeze(self, dim=None):
         batch_size = self.batch_size
         if dim is None:
-            names = copy(self.names) if self._has_names() else None
+            names = list(self.names) if self._has_names() else None
            if names is not None:
                 batch_size, names = _zip_strict(
                     *[
@@ -2114,7 +2110,7 @@ def _squeeze(tensor):
         batch_size = list(batch_size)
         batch_size.pop(dim)
         batch_size = list(batch_size)
-        names = copy(self.names) if self._has_names() else None
+        names = list(self.names) if self._has_names() else None
         if names:
             names.pop(dim)
 
@@ -2149,7 +2145,7 @@ def _unsqueeze(self, dim: int):
         batch_size.insert(newdim, 1)
         batch_size = torch.Size(batch_size)
 
-        names = copy(self.names) if self._has_names() else None
+        names = list(self.names) if self._has_names() else None
         if names:
             names.insert(newdim, None)
 
@@ -2243,7 +2239,10 @@ def from_dict_instance(
         from tensordict import TensorDict
 
         batch_size_set = torch.Size(()) if batch_size is None else batch_size
-        input_dict = copy(input_dict)
+        if is_compiling():
+            input_dict = type(input_dict)(input_dict)
+        else:
+            input_dict = copy(input_dict)
         for key, value in list(input_dict.items()):
             if isinstance(value, (dict,)):
                 cur_value = self.get(key)
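Where the object being copied is a user-supplied mapping, the commit cannot simply drop copy(): a subclass may rely on its own __copy__ semantics in eager mode. Both from_dict_instance here and load_state_dict in tensordict/base.py therefore branch on is_compiling() (tensordict's helper, which as far as I can tell wraps the same check as torch.compiler.is_compiling()): the traceable type(d)(d) rebuild is used only inside a compiled region. A standalone sketch of the pattern, assuming a plain dict-like input:

    from copy import copy

    import torch

    def shallow_copy_mapping(d: dict) -> dict:
        if torch.compiler.is_compiling():
            # Under compilation: rebuild through the type's constructor,
            # which Dynamo can trace without a graph break.
            return type(d)(d)
        # Eager mode: defer to copy() so subclasses with custom
        # copy semantics keep their behavior.
        return copy(d)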

tensordict/base.py

Lines changed: 6 additions & 3 deletions
@@ -5560,7 +5560,7 @@ def refine_names(self, *names) -> Self:
 
         """
         # replace ellipsis if any
-        names_copy = copy(names)
+        names_copy = list(names)
         if any(name is Ellipsis for name in names):
             ellipsis_name = [NO_DEFAULT for _ in range(self.ndim - len(names) + 1)]
             names = []
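One subtlety in this hunk: names arrives as a varargs tuple (per the signature above), and copy.copy() of a tuple is a no-op that returns the same object, so the old names_copy was never an independent copy to begin with. list(names) is both friendlier to Dynamo and an actual mutable copy. A quick check of that claim:

    from copy import copy

    names = ("a", "b")
    assert copy(names) is names        # copy() of an immutable tuple returns it unchanged
    assert list(names) is not names    # list() always builds a new, mutable sequence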
@@ -6053,7 +6053,10 @@ def load_state_dict(
             return self.update(self_flatten.unflatten_keys("."))
 
         # copy since we'll be using pop
-        state_dict = copy(state_dict)
+        if is_compiling():
+            state_dict = type(state_dict)(state_dict)
+        else:
+            state_dict = copy(state_dict)
         batch_size = state_dict.pop("__batch_size")
         device = state_dict.pop("__device", None)
 
@@ -8959,7 +8962,7 @@ def unflatten(tensor):
             unflatten, batch_size=batch_size, propagate_lock=True, call_on_nested=True
         )
         if self._has_names():
-            names = copy(self.names)
+            names = list(self.names)
             for _ in range(len(unflattened_size) - 1):
                 names.insert(dim, None)
             out.names = names

tensordict/tensorclass.py

Lines changed: 3 additions & 5 deletions
@@ -1661,7 +1661,7 @@ def _share_memory_(self):
 
 
 def _load_memmap(cls, prefix: Path, metadata: dict, *, robust_key, **kwargs):
-    non_tensordict = copy(metadata)
+    non_tensordict = dict(metadata)
     del non_tensordict["_type"]
     if os.path.exists(prefix / "other.pickle"):
         with open(prefix / "other.pickle", "rb") as pickle_file:
@@ -2689,7 +2689,7 @@ def _state_dict(
             destination=destination, prefix=prefix, keep_vars=keep_vars, flatten=flatten
         )
     }
-    state_dict["_non_tensordict"] = copy(self._non_tensordict)
+    state_dict["_non_tensordict"] = dict(self._non_tensordict)
     return state_dict
 
 
@@ -3750,9 +3750,7 @@ def _memmap_(
 
     _metadata = {}
     if prefix is not None:
-        _metadata = copy(self._metadata)
-        if _metadata is None:
-            _metadata = {}
+        _metadata = dict(self._metadata) if self._metadata is not None else {}
         _metadata["memmap_prefix"] = prefix
         _metadata["memmaped"] = memmaped
 
test/test_compile.py

Lines changed: 47 additions & 0 deletions
@@ -328,6 +328,53 @@ def exclude_keys(td: TensorDict):
         assert "a" in result_c.keys()
         assert "b" not in result_c.keys()
 
+    def test_all_any(self, mode):
+        def call_all(td: TensorDict):
+            return td.all(dim=0)
+
+        def call_any(td: TensorDict):
+            return td.any(dim=0)
+
+        call_all_c = torch.compile(call_all, fullgraph=True, mode=mode)
+        call_any_c = torch.compile(call_any, fullgraph=True, mode=mode)
+
+        data = TensorDict(
+            {"a": torch.tensor([[True, False], [True, True]])},
+            batch_size=[2, 2],
+        )
+
+        result_all = call_all(data)
+        result_all_c = call_all_c(data)
+        assert (result_all["a"] == result_all_c["a"]).all()
+        assert result_all_c.shape == torch.Size([2])
+
+        result_any = call_any(data)
+        result_any_c = call_any_c(data)
+        assert (result_any["a"] == result_any_c["a"]).all()
+        assert result_any_c.shape == torch.Size([2])
+
+    def test_squeeze_unsqueeze(self, mode):
+        def call_squeeze(td: TensorDict):
+            return td.squeeze(0)
+
+        def call_unsqueeze(td: TensorDict):
+            return td.unsqueeze(0)
+
+        call_squeeze_c = torch.compile(call_squeeze, fullgraph=True, mode=mode)
+        call_unsqueeze_c = torch.compile(call_unsqueeze, fullgraph=True, mode=mode)
+
+        data = TensorDict({"a": torch.randn(1, 3)}, batch_size=[1, 3])
+
+        result_squeeze = call_squeeze(data)
+        result_squeeze_c = call_squeeze_c(data)
+        assert result_squeeze.shape == result_squeeze_c.shape
+        assert result_squeeze_c.shape == torch.Size([3])
+
+        result_unsqueeze = call_unsqueeze(result_squeeze)
+        result_unsqueeze_c = call_unsqueeze_c(result_squeeze_c)
+        assert result_unsqueeze.shape == result_unsqueeze_c.shape
+        assert result_unsqueeze_c.shape == torch.Size([1, 3])
+
     def test_names(self, mode):
         import torch._dynamo.exc
 
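The new tests compile each helper with fullgraph=True, which turns any remaining graph break (for instance, one caused by a stray copy() call) into a hard error instead of a silent eager fallback. A minimal standalone sketch of that testing idea, assuming a recent PyTorch and a tensordict install:

    import torch
    from tensordict import TensorDict

    def call_all(td: TensorDict):
        return td.all(dim=0)

    # fullgraph=True: raise on any graph break rather than falling back to eager.
    call_all_c = torch.compile(call_all, fullgraph=True)

    data = TensorDict({"a": torch.ones(2, 2, dtype=torch.bool)}, batch_size=[2, 2])
    assert (call_all_c(data)["a"] == call_all(data)["a"]).all()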