Skip to content

Commit 6f56403

Browse files
Add WAN style rope detection pattern for low precision attention API (#4173)
## Quick Summary - Added rope fusion detection for WAN style models that use RoPE. This is for the low precision attention API, to enable rope fusion. - Previously, the fallback path would work, allowing for torch.compile to work with the monkey patched F.SDPA. However, the RoPE operation was not fused into a single kernel with the quantization to fp8. ## Results Previous results on WAN model: | Config | Median Time (s) | Speedup | | --- | --- | --- | | bf16 baseline | 100.85 | 1.00x | | fp8_attn | 67.15 | **1.50x** | | bf16 baseline + torch.compile | 81.32 | 1.00x | | fp8_attn + compile | 46.88 | **1.73x** | New results with rope fusion: | Config | Median Time (s) | Speedup | | --- | --- | --- | | fp8_attn + rope fusion + compile | 44.24 | **1.84x** |
1 parent 2c4ecb0 commit 6f56403

1 file changed

Lines changed: 209 additions & 10 deletions

File tree

torchao/prototype/attention/shared_utils/fusion_utils.py

Lines changed: 209 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,17 @@ def _reshape_cos_sin_to_2d(
8181
if len(shape) < 2:
8282
logger.debug("RoPE %s has fewer than 2 dims: shape=%s", name, shape)
8383
return None
84-
for dim in shape[:-2]:
85-
if dim != 1:
86-
logger.debug(
87-
"RoPE %s has non-unit leading dim: shape=%s",
88-
name,
89-
shape,
90-
)
91-
return None
84+
non_unit = [d for d in shape if d != 1]
85+
if len(non_unit) != 2:
86+
logger.debug(
87+
"RoPE %s cannot be squeezed to 2D: shape=%s",
88+
name,
89+
shape,
90+
)
91+
return None
9292

93-
s, d = cos_shape[-2], cos_shape[-1]
93+
cos_non_unit = [d for d in cos_shape if d != 1]
94+
s, d = cos_non_unit[0], cos_non_unit[1]
9495
with graph.inserting_before(insert_before):
9596
cos_2d = graph.call_function(
9697
torch.ops.aten.view.default,
@@ -122,6 +123,7 @@ def _trace_through_views(node: Node) -> Node:
122123
"contiguous",
123124
"expand",
124125
"to",
126+
"type_as",
125127
"float",
126128
"half",
127129
"bfloat16",
@@ -690,9 +692,205 @@ def _try_match(x_mul: Node, rot_mul: Node) -> Optional[RoPEMatch]:
690692
return _try_match(right, left)
691693

692694

695+
def _parse_stride2_index(idx) -> Optional[int]:
696+
"""Parse ``(..., slice(start, None, 2))`` index tuples.
697+
698+
Returns the slice ``start`` value (defaulting to 0), or None if the
699+
index does not match the expected stride-2 pattern.
700+
"""
701+
if not isinstance(idx, tuple) or len(idx) < 1:
702+
return None
703+
last = idx[-1]
704+
if not isinstance(last, slice) or last.step != 2:
705+
return None
706+
for entry in idx[:-1]:
707+
if entry is not Ellipsis and entry != slice(None):
708+
return None
709+
return last.start if last.start is not None else 0
710+
711+
712+
def _get_setitem_stride2_slice(node: Node) -> Optional[Tuple[int, Node]]:
713+
"""Check if node is ``operator.setitem(tensor, (..., slice(start, None, 2)), value)``.
714+
715+
Returns ``(start, value_node)`` or None.
716+
"""
717+
if not _is_op(node, operator.setitem):
718+
return None
719+
if len(node.args) < 3:
720+
return None
721+
value = node.args[2]
722+
if not isinstance(value, Node):
723+
return None
724+
start = _parse_stride2_index(node.args[1])
725+
if start is None:
726+
return None
727+
return start, value
728+
729+
730+
def _get_getitem_stride2_slice(node: Node) -> Optional[Tuple[int, Node]]:
731+
"""Check if node is ``operator.getitem(tensor, (..., slice(start, None, 2)))``.
732+
733+
Returns ``(start, source_tensor_node)`` or None.
734+
"""
735+
if not _is_op(node, operator.getitem):
736+
return None
737+
if len(node.args) < 2:
738+
return None
739+
source = node.args[0]
740+
if not isinstance(source, Node):
741+
return None
742+
start = _parse_stride2_index(node.args[1])
743+
if start is None:
744+
return None
745+
return start, source
746+
747+
748+
def _find_freq_slice_arg(mul_node: Node) -> Optional[Tuple[Node, Node]]:
749+
"""Find which arg of a mul is a stride-2 getitem slice (freq), return (getitem_node, other_arg).
750+
751+
Handles commutativity: checks both orderings.
752+
"""
753+
if not _is_op(mul_node, torch.ops.aten.mul.Tensor, operator.mul):
754+
return None
755+
a, b = mul_node.args[0], mul_node.args[1]
756+
for freq_candidate, other in [(a, b), (b, a)]:
757+
if not isinstance(freq_candidate, Node):
758+
continue
759+
traced = _trace_through_views(freq_candidate)
760+
gi = _get_getitem_stride2_slice(traced)
761+
if gi is not None:
762+
return traced, other
763+
return None
764+
765+
766+
def _get_freq_slice_start(node: Node) -> Optional[int]:
767+
"""Get the start value from a stride-2 frequency slice node."""
768+
gi = _get_getitem_stride2_slice(node)
769+
if gi is not None:
770+
return gi[0]
771+
return None
772+
773+
774+
def _get_freq_slice_source(node: Node) -> Node:
775+
"""Get the source tensor from a stride-2 frequency slice node."""
776+
gi = _get_getitem_stride2_slice(node)
777+
if gi is not None:
778+
return _trace_through_views(gi[1])
779+
return node
780+
781+
782+
def _detect_wan_rope(node: Node) -> Optional[RoPEMatch]:
783+
"""Detect Wan-style indexed-write RoPE pattern (Pattern C).
784+
785+
In the pre-grad FX graph (where the fusion pass runs), the pattern is:
786+
787+
out = torch.empty_like(hidden_states)
788+
...
789+
setitem(out, (..., slice(0, None, 2)), sub) # even positions
790+
...
791+
setitem(out, (..., slice(1, None, 2)), add) # odd positions
792+
type_as(out, hidden_states)
793+
transpose(out, 1, 2)
794+
795+
The detection starts from the ``empty_like`` node (reached after
796+
``_unwrap_transpose`` + ``_trace_through_views`` strips the transpose and
797+
type_as). We then scan the users of ``empty_like`` for the two stride-2
798+
setitem writes.
799+
800+
This is mathematically identical to interleaved RoPE (pairs (2i, 2i+1)),
801+
so we return rope_interleaved=True.
802+
"""
803+
# The node should be empty_like(hidden_states)
804+
if not _is_op(
805+
node,
806+
torch.ops.aten.empty_like.default,
807+
torch.empty_like,
808+
):
809+
return None
810+
811+
if not node.args or not isinstance(node.args[0], Node):
812+
return None
813+
pre_rope_input = node.args[0]
814+
815+
# Find the two stride-2 setitem users: one with start=0 (even), one with start=1 (odd)
816+
even_val = None
817+
odd_val = None
818+
for user in node.users:
819+
si = _get_setitem_stride2_slice(user)
820+
if si is None:
821+
continue
822+
start, value = si
823+
if start == 0 and even_val is None:
824+
even_val = value
825+
elif start == 1 and odd_val is None:
826+
odd_val = value
827+
828+
if even_val is None or odd_val is None:
829+
return None
830+
831+
# even_val = sub(mul(x1, cos), mul(x2, sin))
832+
if not _is_op(even_val, torch.ops.aten.sub.Tensor, operator.sub):
833+
return None
834+
if len(even_val.args) < 2:
835+
return None
836+
even_left, even_right = even_val.args[0], even_val.args[1]
837+
if not isinstance(even_left, Node) or not isinstance(even_right, Node):
838+
return None
839+
if not _is_op(even_left, torch.ops.aten.mul.Tensor, operator.mul):
840+
return None
841+
if not _is_op(even_right, torch.ops.aten.mul.Tensor, operator.mul):
842+
return None
843+
844+
# odd_val = add(mul(x1, sin), mul(x2, cos))
845+
if not _is_op(odd_val, torch.ops.aten.add.Tensor, operator.add):
846+
return None
847+
if len(odd_val.args) < 2:
848+
return None
849+
odd_left, odd_right = odd_val.args[0], odd_val.args[1]
850+
if not isinstance(odd_left, Node) or not isinstance(odd_right, Node):
851+
return None
852+
if not _is_op(odd_left, torch.ops.aten.mul.Tensor, operator.mul):
853+
return None
854+
if not _is_op(odd_right, torch.ops.aten.mul.Tensor, operator.mul):
855+
return None
856+
857+
# From even_val's mul nodes, identify cos/sin (stride-2 slices) vs x1/x2
858+
even_left_match = _find_freq_slice_arg(even_left)
859+
even_right_match = _find_freq_slice_arg(even_right)
860+
if even_left_match is None or even_right_match is None:
861+
return None
862+
863+
cos_slice_node, _x1 = even_left_match
864+
sin_slice_node, _x2 = even_right_match
865+
866+
# Verify cos slice has start=0 (even positions) and sin slice has start=1 (odd positions)
867+
cos_start = _get_freq_slice_start(cos_slice_node)
868+
sin_start = _get_freq_slice_start(sin_slice_node)
869+
if cos_start != 0 or sin_start != 1:
870+
# Try swapping
871+
if sin_start == 0 and cos_start == 1:
872+
cos_slice_node, sin_slice_node = sin_slice_node, cos_slice_node
873+
else:
874+
return None
875+
876+
# Unwrap stride-2 slices to get original cos/sin tensors
877+
cos_original = _get_freq_slice_source(cos_slice_node)
878+
sin_original = _get_freq_slice_source(sin_slice_node)
879+
880+
return RoPEMatch(
881+
pre_rope_input=pre_rope_input,
882+
cos_node=cos_original,
883+
sin_node=sin_original,
884+
rope_interleaved=True,
885+
)
886+
887+
693888
def _detect_rope(node: Node) -> Optional[RoPEMatch]:
694889
"""Detect any supported RoPE variant at a given node."""
695-
return _detect_neox_rope(node)
890+
result = _detect_neox_rope(node)
891+
if result is not None:
892+
return result
893+
return _detect_wan_rope(node)
696894

697895

698896
# Graph Surgery
@@ -755,6 +953,7 @@ def rope_sdpa_fusion_pass(
755953
Supported patterns:
756954
- Pattern A (RoPE -> transpose -> FP8 SDPA): FLUX-style
757955
- Pattern B (transpose -> RoPE -> FP8 SDPA): HuggingFace-style
956+
- Pattern C (indexed-write RoPE -> transpose -> FP8 SDPA): Wan-style
758957
759958
Note: KV caching must be disabled before compilation.
760959
DynamicCache.update() inserts torch.cat nodes that break pattern matching.

0 commit comments

Comments
 (0)