@@ -81,16 +81,17 @@ def _reshape_cos_sin_to_2d(
8181 if len (shape ) < 2 :
8282 logger .debug ("RoPE %s has fewer than 2 dims: shape=%s" , name , shape )
8383 return None
84- for dim in shape [: - 2 ]:
85- if dim != 1 :
86- logger .debug (
87- "RoPE %s has non-unit leading dim : shape=%s" ,
88- name ,
89- shape ,
90- )
91- return None
84+ non_unit = [ d for d in shape if d != 1 ]
85+ if len ( non_unit ) != 2 :
86+ logger .debug (
87+ "RoPE %s cannot be squeezed to 2D : shape=%s" ,
88+ name ,
89+ shape ,
90+ )
91+ return None
9292
93- s , d = cos_shape [- 2 ], cos_shape [- 1 ]
93+ cos_non_unit = [d for d in cos_shape if d != 1 ]
94+ s , d = cos_non_unit [0 ], cos_non_unit [1 ]
9495 with graph .inserting_before (insert_before ):
9596 cos_2d = graph .call_function (
9697 torch .ops .aten .view .default ,
@@ -122,6 +123,7 @@ def _trace_through_views(node: Node) -> Node:
122123 "contiguous" ,
123124 "expand" ,
124125 "to" ,
126+ "type_as" ,
125127 "float" ,
126128 "half" ,
127129 "bfloat16" ,
@@ -690,9 +692,205 @@ def _try_match(x_mul: Node, rot_mul: Node) -> Optional[RoPEMatch]:
690692 return _try_match (right , left )
691693
692694
695+ def _parse_stride2_index (idx ) -> Optional [int ]:
696+ """Parse ``(..., slice(start, None, 2))`` index tuples.
697+
698+ Returns the slice ``start`` value (defaulting to 0), or None if the
699+ index does not match the expected stride-2 pattern.
700+ """
701+ if not isinstance (idx , tuple ) or len (idx ) < 1 :
702+ return None
703+ last = idx [- 1 ]
704+ if not isinstance (last , slice ) or last .step != 2 :
705+ return None
706+ for entry in idx [:- 1 ]:
707+ if entry is not Ellipsis and entry != slice (None ):
708+ return None
709+ return last .start if last .start is not None else 0
710+
711+
712+ def _get_setitem_stride2_slice (node : Node ) -> Optional [Tuple [int , Node ]]:
713+ """Check if node is ``operator.setitem(tensor, (..., slice(start, None, 2)), value)``.
714+
715+ Returns ``(start, value_node)`` or None.
716+ """
717+ if not _is_op (node , operator .setitem ):
718+ return None
719+ if len (node .args ) < 3 :
720+ return None
721+ value = node .args [2 ]
722+ if not isinstance (value , Node ):
723+ return None
724+ start = _parse_stride2_index (node .args [1 ])
725+ if start is None :
726+ return None
727+ return start , value
728+
729+
730+ def _get_getitem_stride2_slice (node : Node ) -> Optional [Tuple [int , Node ]]:
731+ """Check if node is ``operator.getitem(tensor, (..., slice(start, None, 2)))``.
732+
733+ Returns ``(start, source_tensor_node)`` or None.
734+ """
735+ if not _is_op (node , operator .getitem ):
736+ return None
737+ if len (node .args ) < 2 :
738+ return None
739+ source = node .args [0 ]
740+ if not isinstance (source , Node ):
741+ return None
742+ start = _parse_stride2_index (node .args [1 ])
743+ if start is None :
744+ return None
745+ return start , source
746+
747+
748+ def _find_freq_slice_arg (mul_node : Node ) -> Optional [Tuple [Node , Node ]]:
749+ """Find which arg of a mul is a stride-2 getitem slice (freq), return (getitem_node, other_arg).
750+
751+ Handles commutativity: checks both orderings.
752+ """
753+ if not _is_op (mul_node , torch .ops .aten .mul .Tensor , operator .mul ):
754+ return None
755+ a , b = mul_node .args [0 ], mul_node .args [1 ]
756+ for freq_candidate , other in [(a , b ), (b , a )]:
757+ if not isinstance (freq_candidate , Node ):
758+ continue
759+ traced = _trace_through_views (freq_candidate )
760+ gi = _get_getitem_stride2_slice (traced )
761+ if gi is not None :
762+ return traced , other
763+ return None
764+
765+
766+ def _get_freq_slice_start (node : Node ) -> Optional [int ]:
767+ """Get the start value from a stride-2 frequency slice node."""
768+ gi = _get_getitem_stride2_slice (node )
769+ if gi is not None :
770+ return gi [0 ]
771+ return None
772+
773+
774+ def _get_freq_slice_source (node : Node ) -> Node :
775+ """Get the source tensor from a stride-2 frequency slice node."""
776+ gi = _get_getitem_stride2_slice (node )
777+ if gi is not None :
778+ return _trace_through_views (gi [1 ])
779+ return node
780+
781+
782+ def _detect_wan_rope (node : Node ) -> Optional [RoPEMatch ]:
783+ """Detect Wan-style indexed-write RoPE pattern (Pattern C).
784+
785+ In the pre-grad FX graph (where the fusion pass runs), the pattern is:
786+
787+ out = torch.empty_like(hidden_states)
788+ ...
789+ setitem(out, (..., slice(0, None, 2)), sub) # even positions
790+ ...
791+ setitem(out, (..., slice(1, None, 2)), add) # odd positions
792+ type_as(out, hidden_states)
793+ transpose(out, 1, 2)
794+
795+ The detection starts from the ``empty_like`` node (reached after
796+ ``_unwrap_transpose`` + ``_trace_through_views`` strips the transpose and
797+ type_as). We then scan the users of ``empty_like`` for the two stride-2
798+ setitem writes.
799+
800+ This is mathematically identical to interleaved RoPE (pairs (2i, 2i+1)),
801+ so we return rope_interleaved=True.
802+ """
803+ # The node should be empty_like(hidden_states)
804+ if not _is_op (
805+ node ,
806+ torch .ops .aten .empty_like .default ,
807+ torch .empty_like ,
808+ ):
809+ return None
810+
811+ if not node .args or not isinstance (node .args [0 ], Node ):
812+ return None
813+ pre_rope_input = node .args [0 ]
814+
815+ # Find the two stride-2 setitem users: one with start=0 (even), one with start=1 (odd)
816+ even_val = None
817+ odd_val = None
818+ for user in node .users :
819+ si = _get_setitem_stride2_slice (user )
820+ if si is None :
821+ continue
822+ start , value = si
823+ if start == 0 and even_val is None :
824+ even_val = value
825+ elif start == 1 and odd_val is None :
826+ odd_val = value
827+
828+ if even_val is None or odd_val is None :
829+ return None
830+
831+ # even_val = sub(mul(x1, cos), mul(x2, sin))
832+ if not _is_op (even_val , torch .ops .aten .sub .Tensor , operator .sub ):
833+ return None
834+ if len (even_val .args ) < 2 :
835+ return None
836+ even_left , even_right = even_val .args [0 ], even_val .args [1 ]
837+ if not isinstance (even_left , Node ) or not isinstance (even_right , Node ):
838+ return None
839+ if not _is_op (even_left , torch .ops .aten .mul .Tensor , operator .mul ):
840+ return None
841+ if not _is_op (even_right , torch .ops .aten .mul .Tensor , operator .mul ):
842+ return None
843+
844+ # odd_val = add(mul(x1, sin), mul(x2, cos))
845+ if not _is_op (odd_val , torch .ops .aten .add .Tensor , operator .add ):
846+ return None
847+ if len (odd_val .args ) < 2 :
848+ return None
849+ odd_left , odd_right = odd_val .args [0 ], odd_val .args [1 ]
850+ if not isinstance (odd_left , Node ) or not isinstance (odd_right , Node ):
851+ return None
852+ if not _is_op (odd_left , torch .ops .aten .mul .Tensor , operator .mul ):
853+ return None
854+ if not _is_op (odd_right , torch .ops .aten .mul .Tensor , operator .mul ):
855+ return None
856+
857+ # From even_val's mul nodes, identify cos/sin (stride-2 slices) vs x1/x2
858+ even_left_match = _find_freq_slice_arg (even_left )
859+ even_right_match = _find_freq_slice_arg (even_right )
860+ if even_left_match is None or even_right_match is None :
861+ return None
862+
863+ cos_slice_node , _x1 = even_left_match
864+ sin_slice_node , _x2 = even_right_match
865+
866+ # Verify cos slice has start=0 (even positions) and sin slice has start=1 (odd positions)
867+ cos_start = _get_freq_slice_start (cos_slice_node )
868+ sin_start = _get_freq_slice_start (sin_slice_node )
869+ if cos_start != 0 or sin_start != 1 :
870+ # Try swapping
871+ if sin_start == 0 and cos_start == 1 :
872+ cos_slice_node , sin_slice_node = sin_slice_node , cos_slice_node
873+ else :
874+ return None
875+
876+ # Unwrap stride-2 slices to get original cos/sin tensors
877+ cos_original = _get_freq_slice_source (cos_slice_node )
878+ sin_original = _get_freq_slice_source (sin_slice_node )
879+
880+ return RoPEMatch (
881+ pre_rope_input = pre_rope_input ,
882+ cos_node = cos_original ,
883+ sin_node = sin_original ,
884+ rope_interleaved = True ,
885+ )
886+
887+
693888def _detect_rope (node : Node ) -> Optional [RoPEMatch ]:
694889 """Detect any supported RoPE variant at a given node."""
695- return _detect_neox_rope (node )
890+ result = _detect_neox_rope (node )
891+ if result is not None :
892+ return result
893+ return _detect_wan_rope (node )
696894
697895
698896# Graph Surgery
@@ -755,6 +953,7 @@ def rope_sdpa_fusion_pass(
755953 Supported patterns:
756954 - Pattern A (RoPE -> transpose -> FP8 SDPA): FLUX-style
757955 - Pattern B (transpose -> RoPE -> FP8 SDPA): HuggingFace-style
956+ - Pattern C (indexed-write RoPE -> transpose -> FP8 SDPA): Wan-style
758957
759958 Note: KV caching must be disabled before compilation.
760959 DynamicCache.update() inserts torch.cat nodes that break pattern matching.
0 commit comments