Skip to content

Commit f11eff8

Browse files
add rope fusion detection tests (#4183)
## Summary - Added some rope fusion detection tests for the 3 rope fusion patterns we have now
1 parent ce07646 commit f11eff8

1 file changed

Lines changed: 159 additions & 0 deletions

File tree

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Unit tests for RoPE fusion pattern detection in the pre-grad custom pass.

These tests verify that rope_sdpa_fusion_pass correctly identifies each
RoPE pattern in the FX graph, independent of any GPU kernel or hardware.
"""

import contextlib
import io
import unittest
from functools import partial

import torch
import torch._inductor.config as inductor_config
import torch.nn as nn

from torchao.prototype.attention.shared_utils.custom_ops import (
    register_fp8_attention_ops,
)
from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass

# Register test-only custom ops with dummy implementations.
29+
def _dummy_rope_sdpa(q, k, v, cos, sin, **kwargs):
30+
B, S, H, D = q.shape
31+
return torch.zeros(B, H, S, D, dtype=q.dtype, device=q.device)


def _dummy_sdpa(q, k, v, **kwargs):
35+
return torch.zeros_like(q)


# Register the dummy kernels under a test-only namespace at import time so the
# fusion pass has concrete custom ops (rope_sdpa_op / fp8_sdpa_op) to target.
_ops = register_fp8_attention_ops("test_fusion", _dummy_rope_sdpa, _dummy_sdpa)


class PatternANeoXRoPE(nn.Module):
    """Pattern A: NeoX rotate_half RoPE -> transpose -> FP8 SDPA.

    q/k/v arrive in BSHD layout; RoPE is applied first, then all three
    tensors are transposed to BHSD for the SDPA call.
    """

    def forward(self, q, k, v, cos, sin):
        half = q.shape[-1] // 2
        # Broadcast cos/sin from [S, D] to [1, S, 1, D] for BSHD tensors.
        cos_b = cos.unsqueeze(0).unsqueeze(2)
        sin_b = sin.unsqueeze(0).unsqueeze(2)

        def rotate_half(x):
            # NeoX-style rotation: negate the upper half and swap halves.
            return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

        q_embed = q * cos_b + rotate_half(q) * sin_b
        k_embed = k * cos_b + rotate_half(k) * sin_b

        return _ops.fp8_sdpa_op(
            q_embed.transpose(1, 2),
            k_embed.transpose(1, 2),
            v.transpose(1, 2),
            is_causal=True,
        )


class PatternBHuggingFaceRoPE(nn.Module):
    """Pattern B: transpose -> NeoX RoPE -> FP8 SDPA (HuggingFace-style).

    Unlike Pattern A, the BSHD -> BHSD transpose happens *before* RoPE,
    so cos/sin broadcast over dim 1 (heads) instead of dim 2.
    """

    def forward(self, q, k, v, cos, sin):
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        half = q.shape[-1] // 2
        cos_b = cos.unsqueeze(0).unsqueeze(1)  # [1, 1, S, D] for BHSD layout
        sin_b = sin.unsqueeze(0).unsqueeze(1)

        def rotate_half(x):
            # NeoX-style rotation: negate the upper half and swap halves.
            return torch.cat([-x[..., half:], x[..., :half]], dim=-1)

        q = q * cos_b + rotate_half(q) * sin_b
        k = k * cos_b + rotate_half(k) * sin_b

        return _ops.fp8_sdpa_op(q, k, v, is_causal=True)


class PatternCWanRoPE(nn.Module):
    """Pattern C: Wan-style indexed-write RoPE -> transpose -> FP8 SDPA.

    The rotation is expressed as strided slice assignments into an
    empty_like buffer (even/odd interleaved complex rotation), which is
    the exact graph shape the fusion pass matches.
    """

    def forward(self, q, k, v, freqs):
        def apply_rope(x):
            # Inline slice expressions are kept (no common-subexpression
            # hoisting) so the traced FX graph matches the Wan pattern.
            rotated = torch.empty_like(x)
            rotated[..., 0::2] = (
                freqs[..., 0::2] * x[..., 0::2] - freqs[..., 1::2] * x[..., 1::2]
            )
            rotated[..., 1::2] = (
                freqs[..., 1::2] * x[..., 0::2] + freqs[..., 0::2] * x[..., 1::2]
            )
            return rotated.type_as(x)

        q_rot = apply_rope(q)
        k_rot = apply_rope(k)

        return _ops.fp8_sdpa_op(
            q_rot.transpose(1, 2),
            k_rot.transpose(1, 2),
            v.transpose(1, 2),
            is_causal=True,
        )


class TestRoPEFusionDetection(unittest.TestCase):
    """Each test compiles one pattern module and asserts the pass fused it."""

    def setUp(self):
        torch._dynamo.reset()
        # Save the active pre-grad pass so tearDown restores global state.
        self._saved_pass = inductor_config.pre_grad_custom_pass

    def tearDown(self):
        inductor_config.pre_grad_custom_pass = self._saved_pass
        torch._dynamo.reset()

    def _compile_and_capture(self, model, *args):
        """Compile model with fusion pass, return captured stdout."""
        inductor_config.pre_grad_custom_pass = partial(
            rope_sdpa_fusion_pass,
            rope_sdpa_op=_ops.rope_sdpa_op,
            fp8_sdpa_op=_ops.fp8_sdpa_op,
            backend_name="TEST",
        )
        compiled_model = torch.compile(model)
        captured = io.StringIO()
        with torch.no_grad(), contextlib.redirect_stdout(captured):
            compiled_model(*args)
        return captured.getvalue()

    def _assert_fused(self, model, *extra_args):
        """Create BSHD inputs, run fusion pass, assert 1 node was fused."""
        batch, seq, heads, dim = 1, 32, 4, 64
        qkv = [torch.randn(batch, seq, heads, dim) for _ in range(3)]
        log = self._compile_and_capture(model, *qkv, *extra_args)
        self.assertIn("1 fused with RoPE", log)

    def test_pattern_a_neox_rope(self):
        seq, dim = 32, 64
        self._assert_fused(
            PatternANeoXRoPE(), torch.randn(seq, dim), torch.randn(seq, dim)
        )

    def test_pattern_b_huggingface_rope(self):
        seq, dim = 32, 64
        self._assert_fused(
            PatternBHuggingFaceRoPE(), torch.randn(seq, dim), torch.randn(seq, dim)
        )

    def test_pattern_c_wan_rope(self):
        seq, dim = 32, 64
        self._assert_fused(PatternCWanRoPE(), torch.randn(1, seq, 1, dim))


# Allow running this test file directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)