
Commit eec4f53

[Compile] Reduce guards and recompiles for TensorDict under torch.compile
Four optimizations to make tensordict more compile-friendly:

1. Per-field property descriptors for tensor_only TensorClasses: bypass the generic __getattr__ dispatch, reducing Dynamo guards on attribute access.
2. Lighter clone(recurse=False): uses a dict.update() fast path under is_compiling() to avoid the dict comprehension and _clone_value overhead.
3. Lighter UnbatchedTensor.clone(): bypasses TensorClass.__init__ tracing under compile via __new__ plus direct attribute setting.
4. allow_in_graph wrapper for _foreach_copy_ in update_(): treats the bulk copy as a single graph node, reducing per-tensor guards.

Includes guard-count tests in TestGuardCount that verify no recompilation after warm-up for each optimized path; a usage sketch of the targeted paths follows below.

Made-with: Cursor
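For orientation, here is a minimal sketch of the kind of workload these changes target (names are illustrative, not taken from the diff): a compiled step that hits the shallow-clone fast path (2) and the bulk in-place update (4).

```python
# Illustrative only: a compiled function exercising clone(recurse=False)
# and update_(), two of the paths optimized in this commit.
import torch
from tensordict import TensorDict

def step(td, src):
    out = td.clone(recurse=False)  # dict.update() fast path under compile
    out.update_(src)               # _foreach_copy_ as a single graph node
    return out["x"] + 1

td = TensorDict({"x": torch.zeros(4)}, batch_size=[4])
src = TensorDict({"x": torch.ones(4)}, batch_size=[4])
step_c = torch.compile(step, fullgraph=True)
step_c(td, src)  # first call compiles; later calls should reuse the cache
```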
1 parent c7d9cc0 commit eec4f53

6 files changed

Lines changed: 316 additions & 2 deletions


benchmarks/compile/compile_td_test.py

Lines changed: 116 additions & 0 deletions
```diff
@@ -400,6 +400,122 @@ def test_compile_replace(mode, variant, benchmark):
     benchmark(func, s)
 
 
+
+# ── Attribute-access benchmarks ──────────────────────────────────────────
+
+
+@tensorclass(tensor_only=True)
+class BigTC20:
+    f0: torch.Tensor
+    f1: torch.Tensor
+    f2: torch.Tensor
+    f3: torch.Tensor
+    f4: torch.Tensor
+    f5: torch.Tensor
+    f6: torch.Tensor
+    f7: torch.Tensor
+    f8: torch.Tensor
+    f9: torch.Tensor
+    f10: torch.Tensor
+    f11: torch.Tensor
+    f12: torch.Tensor
+    f13: torch.Tensor
+    f14: torch.Tensor
+    f15: torch.Tensor
+    f16: torch.Tensor
+    f17: torch.Tensor
+    f18: torch.Tensor
+    f19: torch.Tensor
+
+
+def _get_big_tc20():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    kwargs = {f"f{i}": torch.randn(4, device=device) for i in range(20)}
+    return BigTC20(**kwargs, batch_size=[4], device=device)
+
+
+def tc_getattr_sum(tc):
+    total = tc.f0
+    for i in range(1, 20):
+        total = total + getattr(tc, f"f{i}")
+    return total
+
+
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+def test_compile_tc_getattr_20(mode, benchmark):
+    func = tc_getattr_sum
+    if mode == "compile":
+        func = torch.compile(func, fullgraph=True, mode="reduce-overhead")
+    tc = _get_big_tc20()
+    func(tc)
+    func(tc)
+    benchmark(func, tc)
+
+
+# ── Shallow clone benchmarks ────────────────────────────────────────────
+
+def clone_shallow(td):
+    return td.clone(recurse=False)
+
+
+def _get_flat_td_n(n):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    return TensorDict(
+        {f"k{i}": torch.randn(4, device=device) for i in range(n)},
+        batch_size=[4],
+        device=device,
+    )
+
+
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+@pytest.mark.parametrize("n_fields", [20, 40, 80])
+def test_compile_clone_shallow(mode, n_fields, benchmark):
+    td = _get_flat_td_n(n_fields)
+    func = clone_shallow
+    if mode == "compile":
+        func = torch.compile(func, fullgraph=True, mode="reduce-overhead")
+    func(td)
+    func(td)
+    benchmark(func, td)
+
+
+# ── update_ benchmarks ──────────────────────────────────────────────────
+
+def update_inplace(td, src):
+    td.update_(src)
+    return td
+
+
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+def test_compile_update_inplace(mode, benchmark):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    td = TensorDict(
+        {f"k{i}": torch.randn(4, device=device) for i in range(20)},
+        batch_size=[4],
+        device=device,
+    )
+    src = TensorDict(
+        {f"k{i}": torch.ones(4, device=device) for i in range(20)},
+        batch_size=[4],
+        device=device,
+    )
+    func = update_inplace
+    if mode == "compile":
+        func = torch.compile(func, fullgraph=True, mode="reduce-overhead")
+    func(td, src)
+    func(td, src)
+    benchmark(func, td, src)
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
```
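The module's own __main__ block runs everything through pytest. A narrower run of just the new benchmarks might look like the snippet below; the -k expression is an assumption, not part of the diff, and the benchmark fixture used above requires pytest-benchmark.

```python
# Hypothetical narrower invocation of the new benchmarks only.
import pytest

pytest.main(
    [
        "benchmarks/compile/compile_td_test.py",
        "-k",
        "tc_getattr or clone_shallow or update_inplace",
        "--capture",
        "no",
    ]
)
```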

tensordict/_td.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -3397,6 +3397,11 @@ def _clone(self, recurse: bool = True) -> Self:
         if recurse and self.device is not None:
             return self._clone_recurse()
 
+        if not recurse and is_compiling():
+            result = TensorDict(batch_size=self.batch_size, device=self.device)
+            result._tensordict.update(self._tensordict)
+            return result
+
         result = self._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
```
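The fast path replaces the per-value _clone_value comprehension with a single dict.update() over the internal storage, which is sound because a shallow clone shares its leaves anyway. A small eager-mode sketch of the invariant being preserved (illustrative, not part of the diff):

```python
# A shallow clone shares leaf tensors but owns its own key -> value map.
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
shallow = td.clone(recurse=False)
assert shallow["a"] is td["a"]    # leaves are shared, not copied
shallow["b"] = torch.ones(3)      # new keys do not leak back to the source
assert "b" not in td.keys()
```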

tensordict/_unbatched.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -9,9 +9,15 @@
 from typing import Any, Callable, TYPE_CHECKING
 
 import torch
+from tensordict._td import TensorDict
 from tensordict._tensorcollection import TensorCollection
 from tensordict.base import TensorDictBase
 
+try:
+    from torch.compiler import is_compiling
+except ImportError:
+    from torch._dynamo import is_compiling
+
 from tensordict.tensorclass import (
     _arg_to_tensordict,
     _from_tensordict_with_copy,
@@ -438,6 +444,13 @@ def flatten(self, start_dim: int = 0, end_dim=-1): ...
     def clone(self, recurse: bool = True):
         """Clones the UnbatchedTensor, preserving the batch_size."""
         data = self.data.clone() if recurse else self.data
+        if is_compiling():
+            result = UnbatchedTensor.__new__(UnbatchedTensor)
+            td = TensorDict(source={"data": data}, batch_size=[])
+            td._batch_size = self.batch_size
+            object.__setattr__(result, "_tensordict", td)
+            object.__setattr__(result, "_non_tensordict", {})
+            return result
         result = type(self)(data=data)
         result.batch_size = self.batch_size
         return result
```
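The __new__ plus object.__setattr__ combination is a general trick for skipping __init__ logic that Dynamo would otherwise have to trace. A standalone sketch of the pattern, with illustrative names rather than tensordict internals:

```python
# Allocate with __new__ and set fields directly, bypassing __init__.
class Box:
    __slots__ = ("payload",)

    def __init__(self, payload):
        # imagine validation or bookkeeping here that is costly to trace
        self.payload = payload

def fast_copy(box: Box) -> Box:
    out = Box.__new__(Box)                       # skip __init__ entirely
    object.__setattr__(out, "payload", box.payload)
    return out

b = fast_copy(Box(42))
assert b.payload == 42
```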

tensordict/base.py

Lines changed: 16 additions & 1 deletion
```diff
@@ -127,6 +127,16 @@
 except ImportError:
     _foreach_copy_ = None
 
+try:
+    from torch.compiler import allow_in_graph as _allow_in_graph
+except (ImportError, AttributeError):
+    _allow_in_graph = None
+
+if _foreach_copy_ is not None and _allow_in_graph is not None:
+    _foreach_copy_compiled = _allow_in_graph(_foreach_copy_)
+else:
+    _foreach_copy_compiled = _foreach_copy_
+
 try:
     from torch.nn.parameter import Buffer
 except ImportError:
@@ -8295,7 +8305,12 @@ def inplace_update(name, source, dest):
                     if len(other_val) != len(vals):
                         vals = dict(zip(keys, vals))
                         vals = [vals[k] for k in new_keys]
-                    _foreach_copy_(vals, other_val, non_blocking=non_blocking)
+                    copy_fn = (
+                        _foreach_copy_compiled
+                        if is_compiling()
+                        else _foreach_copy_
+                    )
+                    copy_fn(vals, other_val, non_blocking=non_blocking)
                     return self
             named = True
 
```
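allow_in_graph tells Dynamo to treat a callable as one opaque graph node instead of tracing into its Python body, so the per-tensor loop inside the foreach dispatch never generates guards. A minimal sketch of the wrapping mechanics, assuming torch.compiler.allow_in_graph and torch._foreach_copy_ are available (the diff above guards both behind try/except):

```python
import torch

def bulk_copy_(dests, srcs):
    # one fused op over the whole list of tensors
    torch._foreach_copy_(dests, srcs)

# Wrapped, the helper appears in the graph as a single call node.
bulk_copy_opaque = torch.compiler.allow_in_graph(bulk_copy_)

@torch.compile
def refresh(dests, srcs):
    bulk_copy_opaque(dests, srcs)
    return dests[0] + 0

dests = [torch.zeros(2) for _ in range(3)]
srcs = [torch.ones(2) for _ in range(3)]
refresh(dests, srcs)
assert all(d.eq(1).all() for d in dests)
```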

tensordict/tensorclass.py

Lines changed: 28 additions & 1 deletion
```diff
@@ -995,6 +995,27 @@ def __torch_function__(
             delattr(cls, field.name)
         except AttributeError:
             pass
+
+    if tensor_only:
+        for field in cls.fields():
+            name = field.name
+
+            def _make_prop(key):
+                def _getter(self):
+                    out = self._tensordict._get_str(key, _UNSET)
+                    if out is _UNSET:
+                        out = self._non_tensordict.get(key, _UNSET)
+                        if out is _UNSET:
+                            raise AttributeError(key)
+                        return out
+                    if _is_unbatched(out):
+                        return out.data
+                    return out
+
+                return property(_getter)
+
+            setattr(cls, name, _make_prop(name))
+
     _get_type_hints(cls, tensor_only=tensor_only)
     # Detect user-defined __setattr__ that must be called during init.
     # After dataclass(), frozen=True adds a guard __setattr__, which is not
@@ -1903,7 +1924,13 @@ def _setattr_tensor_only(self, key: str, value: Any) -> None:  # noqa: D417
         or "_non_tensordict" not in __dict__
         or (
             not self._shadow
-            and (key in SET_ATTRIBUTES or key in type(self).__dict__)
+            and (
+                key in SET_ATTRIBUTES
+                or (
+                    key in type(self).__dict__
+                    and key not in self.__expected_keys__
+                )
+            )
         )
     ):
         return self.__setattr_parent__(key, value)
```
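Generating one property per field gives Dynamo a concrete descriptor to specialize on instead of the generic __getattr__ path. The _make_prop factory matters: it captures each key by value, avoiding Python's late-binding-in-loops pitfall. A standalone sketch of the technique with illustrative names:

```python
# Per-field property generation via a closure factory.
class Record:
    def __init__(self, **fields):
        self._storage = dict(fields)

def _make_prop(key):
    def _getter(self):
        return self._storage[key]
    return property(_getter)  # `key` is bound per call, not per loop

for name in ("x", "y"):
    setattr(Record, name, _make_prop(name))

r = Record(x=1, y=2)
assert (r.x, r.y) == (1, 2)
```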

test/test_compile.py

Lines changed: 138 additions & 0 deletions
```diff
@@ -37,10 +37,12 @@
     TensorDictSequential as Seq,
 )
 
+from tensordict._unbatched import UnbatchedTensor
 from tensordict.nn.functional_modules import _exclude_td_from_pytree
 
 from tensordict.tensorclass import TensorClass
 
+from torch._dynamo.testing import CompileCounterWithBackend
 from torch.utils._pytree import SUPPORTED_NODES, tree_map
 
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
@@ -1623,6 +1625,142 @@ def fn(a):
         torch.testing.assert_close(result, inp * 2)
 
 
+def _count_compiles(fn, *args):
+    """Compile fn, run it twice, return (frame_count_first, frame_count_second).
+
+    Uses CompileCounterWithBackend("eager") so the function actually executes.
+    """
+    torch._dynamo.reset_code_caches()
+    cnt = CompileCounterWithBackend("eager")
+    compiled = torch.compile(fn, backend=cnt)
+    compiled(*args)
+    first = cnt.frame_count
+    compiled(*args)
+    second = cnt.frame_count
+    return first, second
+
+
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
+class TestGuardCount:
+    """Tests that verify compile guard/recompile counts for optimized paths."""
+
+    def test_clone_recurse_false_no_recompile(self):
+        def fn(td):
+            c = td.clone(recurse=False)
+            return c["a"] + 1
+
+        td = TensorDict(
+            {"a": torch.randn(4), **{f"key_{i}": torch.randn(4) for i in range(20)}},
+            batch_size=[4],
+        )
+        first, second = _count_compiles(fn, td)
+        assert first == 1, f"Expected 1 compile frame, got {first}"
+        assert second == 1, f"Recompilation detected: {second} frames"
+
+    def test_tc_getattr_no_recompile(self):
+        class BigTC(TensorClass["nocast"]):
+            a: torch.Tensor
+            b: torch.Tensor
+            c: torch.Tensor
+            d: torch.Tensor
+            e: torch.Tensor
+
+        def fn(tc):
+            return tc.a + tc.b + tc.c + tc.d + tc.e
+
+        tc = BigTC(
+            a=torch.randn(4),
+            b=torch.randn(4),
+            c=torch.randn(4),
+            d=torch.randn(4),
+            e=torch.randn(4),
+            batch_size=[4],
+        )
+        first, second = _count_compiles(fn, tc)
+        assert first == 1, f"Expected 1 compile frame, got {first}"
+        assert second == 1, f"Recompilation detected: {second} frames"
+
+    def test_replace_no_recompile(self):
+        class State(TensorClass["nocast"]):
+            x: torch.Tensor
+            y: torch.Tensor
+            z: torch.Tensor
+
+        def fn(s):
+            s = s.replace(x=s.x + 1)
+            s = s.replace(y=s.y + 2)
+            s = s.replace(x=s.x + s.y, z=s.z + 1)
+            return s
+
+        s = State(
+            x=torch.randn(4),
+            y=torch.randn(4),
+            z=torch.randn(4),
+            batch_size=[4],
+        )
+        first, second = _count_compiles(fn, s)
+        assert first == 1, f"Expected 1 compile frame, got {first}"
+        assert second == 1, f"Recompilation detected: {second} frames"
+
+    def test_update_inplace_no_recompile(self):
+        def fn(td, src):
+            td.update_(src)
+            return td["a"] + 0
+
+        td = TensorDict(
+            {"a": torch.randn(4), "b": torch.randn(4)},
+            batch_size=[4],
+        )
+        src = TensorDict(
+            {"a": torch.ones(4), "b": torch.ones(4)},
+            batch_size=[4],
+        )
+        first, second = _count_compiles(fn, td, src)
+        assert first == 1, f"Expected 1 compile frame, got {first}"
+        assert second == 1, f"Recompilation detected: {second} frames"
+
+    def test_unbatched_clone_no_recompile(self):
+        def fn(td):
+            c = td.clone()
+            return c["a"] + 0
+
+        td = TensorDict(
+            {
+                "a": torch.randn(4, 3),
+                "unbatched": UnbatchedTensor(data=torch.randn(5)),
+            },
+            batch_size=[4],
+        )
+        first, second = _count_compiles(fn, td)
+        assert first == 1, f"Expected 1 compile frame, got {first}"
+        assert second == 1, f"Recompilation detected: {second} frames"
+
+    def test_unbatched_clone_preserves_semantics(self):
+        """Cloning an UnbatchedTensor must produce independent data."""
+        torch._dynamo.reset_code_caches()
+
+        def fn(td):
+            cloned = td.clone()
+            return cloned
+
+        td = TensorDict(
+            {
+                "a": torch.randn(4, 3),
+                "unbatched": UnbatchedTensor(data=torch.randn(5)),
+            },
+            batch_size=[4],
+        )
+        fn_c = torch.compile(fn, fullgraph=True)
+        result = fn_c(td)
+        ut_orig = td.get("unbatched")
+        ut_clone = result.get("unbatched")
+        assert ut_clone.data.data_ptr() != ut_orig.data.data_ptr(), (
+            "clone() must produce independent data"
+        )
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
```
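CompileCounterWithBackend wraps a real backend and counts compiled frames, so a stable frame_count across calls is what these tests use as a proxy for "no recompilation". A minimal sketch of that mechanism (the behavior on the third call depends on dynamic-shape settings and torch version, hence no assert there):

```python
import torch
from torch._dynamo.testing import CompileCounterWithBackend

cnt = CompileCounterWithBackend("eager")  # real backend, outputs stay correct

@torch.compile(backend=cnt)
def f(x):
    return x * 2

f(torch.ones(2))
assert cnt.frame_count == 1   # first call compiles one frame
f(torch.ones(2))
assert cnt.frame_count == 1   # guards hold: cache hit, no recompile
f(torch.ones(3))              # new shape may trigger a recompile...
print(cnt.frame_count)        # ...depending on dynamic-shape settings
```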
