@@ -786,6 +786,42 @@ def step(s):
786786 compiled_result = step_c (s )
787787 assert_close (eager_result , compiled_result )
788788
789+ def test_td_replace_no_recompile (self , mode ):
790+ """replace() with many distinct kwarg patterns must not recompile."""
791+
792+ class State (TensorClass ["nocast" ]):
793+ x : torch .Tensor
794+ y : torch .Tensor
795+ z : torch .Tensor
796+ w : torch .Tensor
797+ v : torch .Tensor
798+
799+ def step (s : State ) -> State :
800+ s = s .replace (x = s .x + 1 )
801+ s = s .replace (y = s .y + 2 )
802+ s = s .replace (z = s .z + 3 )
803+ s = s .replace (x = s .x * 0.9 , y = s .y * 0.9 )
804+ s = s .replace (w = s .w + s .x )
805+ s = s .replace (v = s .v - 1 , w = s .w + 1 )
806+ s = s .replace (x = s .x + s .v , y = s .y + s .w , z = s .z + 0.1 )
807+ s = s .replace (v = torch .zeros_like (s .v ))
808+ s = s .replace (w = torch .ones_like (s .w ))
809+ s = s .replace (x = s .x + s .y + s .z )
810+ return s
811+
812+ s = State (
813+ x = torch .randn (4 ),
814+ y = torch .randn (4 ),
815+ z = torch .randn (4 ),
816+ w = torch .randn (4 ),
817+ v = torch .randn (4 ),
818+ batch_size = [4 ],
819+ )
820+ step_c = torch .compile (step , fullgraph = True , mode = mode )
821+ eager_result = step (s )
822+ compiled_result = step_c (s )
823+ assert_close (eager_result , compiled_result )
824+
789825 @pytest .mark .skipif (
790826 TORCH_VERSION < version .parse ("2.6.0" ),
791827 reason = "while_loop requires torch>=2.6" ,
0 commit comments