|
37 | 37 | TensorDictSequential as Seq, |
38 | 38 | ) |
39 | 39 |
|
| 40 | +from tensordict._unbatched import UnbatchedTensor |
40 | 41 | from tensordict.nn.functional_modules import _exclude_td_from_pytree |
41 | 42 |
|
42 | 43 | from tensordict.tensorclass import TensorClass |
43 | 44 |
|
| 45 | +from torch._dynamo.testing import CompileCounterWithBackend |
44 | 46 | from torch.utils._pytree import SUPPORTED_NODES, tree_map |
45 | 47 |
|
46 | 48 | TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) |
@@ -1623,6 +1625,142 @@ def fn(a): |
1623 | 1625 | torch.testing.assert_close(result, inp * 2) |
1624 | 1626 |
|
1625 | 1627 |
|
| 1628 | +def _count_compiles(fn, *args): |
| 1629 | + """Compile fn, run it twice, return (frame_count_first, frame_count_second). |
| 1630 | +
|
| 1631 | + Uses CompileCounterWithBackend("eager") so the function actually executes. |
| 1632 | + """ |
| 1633 | + torch._dynamo.reset_code_caches() |
| 1634 | + cnt = CompileCounterWithBackend("eager") |
| 1635 | + compiled = torch.compile(fn, backend=cnt) |
| 1636 | + compiled(*args) |
| 1637 | + first = cnt.frame_count |
| 1638 | + compiled(*args) |
| 1639 | + second = cnt.frame_count |
| 1640 | + return first, second |
| 1641 | + |
| 1642 | + |
@pytest.mark.skipif(
    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
class TestGuardCount:
    """Tests that verify compile guard/recompile counts for optimized paths.

    Each ``*_no_recompile`` test compiles a small function via
    ``_count_compiles`` (eager-executing counting backend), runs it twice
    with identical inputs, and asserts exactly one compiled frame on the
    first run and no additional frames on the second — i.e. the guards
    installed by the first compilation accept the second call.
    """

    def test_clone_recurse_false_no_recompile(self):
        """A shallow clone of a wide TensorDict compiles once and is reused."""

        def fn(td):
            # recurse=False: only the dict structure is cloned, leaves are shared.
            c = td.clone(recurse=False)
            return c["a"] + 1

        # 20 extra keys make the dict wide; presumably this would amplify any
        # per-key guard explosion — TODO confirm against the guarded code path.
        td = TensorDict(
            {"a": torch.randn(4), **{f"key_{i}": torch.randn(4) for i in range(20)}},
            batch_size=[4],
        )
        first, second = _count_compiles(fn, td)
        assert first == 1, f"Expected 1 compile frame, got {first}"
        assert second == 1, f"Recompilation detected: {second} frames"

    def test_tc_getattr_no_recompile(self):
        """Repeated attribute access on a TensorClass triggers no recompiles."""

        class BigTC(TensorClass["nocast"]):
            a: torch.Tensor
            b: torch.Tensor
            c: torch.Tensor
            d: torch.Tensor
            e: torch.Tensor

        def fn(tc):
            # Five distinct field reads in one frame.
            return tc.a + tc.b + tc.c + tc.d + tc.e

        tc = BigTC(
            a=torch.randn(4),
            b=torch.randn(4),
            c=torch.randn(4),
            d=torch.randn(4),
            e=torch.randn(4),
            batch_size=[4],
        )
        first, second = _count_compiles(fn, tc)
        assert first == 1, f"Expected 1 compile frame, got {first}"
        assert second == 1, f"Recompilation detected: {second} frames"

    def test_replace_no_recompile(self):
        """Chained TensorClass.replace() calls stay within a single frame."""

        class State(TensorClass["nocast"]):
            x: torch.Tensor
            y: torch.Tensor
            z: torch.Tensor

        def fn(s):
            # Mix of single-field and multi-field replaces.
            s = s.replace(x=s.x + 1)
            s = s.replace(y=s.y + 2)
            s = s.replace(x=s.x + s.y, z=s.z + 1)
            return s

        s = State(
            x=torch.randn(4),
            y=torch.randn(4),
            z=torch.randn(4),
            batch_size=[4],
        )
        first, second = _count_compiles(fn, s)
        assert first == 1, f"Expected 1 compile frame, got {first}"
        assert second == 1, f"Recompilation detected: {second} frames"

    def test_update_inplace_no_recompile(self):
        """In-place update_() of a TensorDict does not force recompilation."""

        def fn(td, src):
            td.update_(src)
            # '+ 0' forces a read of the updated value so the frame has output.
            return td["a"] + 0

        td = TensorDict(
            {"a": torch.randn(4), "b": torch.randn(4)},
            batch_size=[4],
        )
        src = TensorDict(
            {"a": torch.ones(4), "b": torch.ones(4)},
            batch_size=[4],
        )
        first, second = _count_compiles(fn, td, src)
        assert first == 1, f"Expected 1 compile frame, got {first}"
        assert second == 1, f"Recompilation detected: {second} frames"

    def test_unbatched_clone_no_recompile(self):
        """Cloning a TensorDict containing an UnbatchedTensor compiles once."""

        def fn(td):
            c = td.clone()
            return c["a"] + 0

        # "unbatched" deliberately has a shape (5,) unrelated to batch_size=[4].
        td = TensorDict(
            {
                "a": torch.randn(4, 3),
                "unbatched": UnbatchedTensor(data=torch.randn(5)),
            },
            batch_size=[4],
        )
        first, second = _count_compiles(fn, td)
        assert first == 1, f"Expected 1 compile frame, got {first}"
        assert second == 1, f"Recompilation detected: {second} frames"

    def test_unbatched_clone_preserves_semantics(self):
        """Cloning an UnbatchedTensor must produce independent data."""
        # Reset manually here since this test bypasses _count_compiles.
        torch._dynamo.reset_code_caches()

        def fn(td):
            cloned = td.clone()
            return cloned

        td = TensorDict(
            {
                "a": torch.randn(4, 3),
                "unbatched": UnbatchedTensor(data=torch.randn(5)),
            },
            batch_size=[4],
        )
        # fullgraph=True: fail loudly if clone() causes a graph break.
        fn_c = torch.compile(fn, fullgraph=True)
        result = fn_c(td)
        ut_orig = td.get("unbatched")
        ut_clone = result.get("unbatched")
        # Distinct storage pointers prove the clone did not alias the original.
        assert ut_clone.data.data_ptr() != ut_orig.data.data_ptr(), (
            "clone() must produce independent data"
        )
| 1762 | + |
| 1763 | + |
if __name__ == "__main__":
    # Forward any unrecognized CLI arguments straight through to pytest.
    parser = argparse.ArgumentParser()
    args, unknown = parser.parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
0 commit comments