Skip to content

Commit 075de7c

Browse files
authored
[BugFix] Fix replace() recompiles under torch.compile (#1605)
1 parent 1f1686c commit 075de7c

2 files changed

Lines changed: 43 additions & 0 deletions

File tree

tensordict/base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8381,6 +8381,13 @@ def replace(self, *args, **kwargs):
83818381
and the kwargs are empty, ``self`` is returned.
83828382

83838383
"""
8384+
if is_compiling() and not args:
8385+
if not kwargs:
8386+
return self
8387+
result = self.copy()
8388+
for k, v in kwargs.items():
8389+
result[k] = v
8390+
return result
83848391
if args:
83858392
if len(args) > 1:
83868393
raise RuntimeError(

test/test_compile.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,42 @@ def step(s):
786786
compiled_result = step_c(s)
787787
assert_close(eager_result, compiled_result)
788788

789+
def test_td_replace_no_recompile(self, mode):
790+
"""replace() with many distinct kwarg patterns must not recompile."""
791+
792+
class State(TensorClass["nocast"]):
793+
x: torch.Tensor
794+
y: torch.Tensor
795+
z: torch.Tensor
796+
w: torch.Tensor
797+
v: torch.Tensor
798+
799+
def step(s: State) -> State:
800+
s = s.replace(x=s.x + 1)
801+
s = s.replace(y=s.y + 2)
802+
s = s.replace(z=s.z + 3)
803+
s = s.replace(x=s.x * 0.9, y=s.y * 0.9)
804+
s = s.replace(w=s.w + s.x)
805+
s = s.replace(v=s.v - 1, w=s.w + 1)
806+
s = s.replace(x=s.x + s.v, y=s.y + s.w, z=s.z + 0.1)
807+
s = s.replace(v=torch.zeros_like(s.v))
808+
s = s.replace(w=torch.ones_like(s.w))
809+
s = s.replace(x=s.x + s.y + s.z)
810+
return s
811+
812+
s = State(
813+
x=torch.randn(4),
814+
y=torch.randn(4),
815+
z=torch.randn(4),
816+
w=torch.randn(4),
817+
v=torch.randn(4),
818+
batch_size=[4],
819+
)
820+
step_c = torch.compile(step, fullgraph=True, mode=mode)
821+
eager_result = step(s)
822+
compiled_result = step_c(s)
823+
assert_close(eager_result, compiled_result)
824+
789825
@pytest.mark.skipif(
790826
TORCH_VERSION < version.parse("2.6.0"),
791827
reason="while_loop requires torch>=2.6",

0 commit comments

Comments
 (0)