|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Unit tests for RoPE fusion pattern detection in the pre-grad custom pass.

These tests verify that rope_sdpa_fusion_pass correctly identifies each
RoPE pattern in the FX graph, independent of any GPU kernel or hardware.
"""
| 12 | + |
| 13 | +import contextlib |
| 14 | +import io |
| 15 | +import unittest |
| 16 | +from functools import partial |
| 17 | + |
| 18 | +import torch |
| 19 | +import torch._inductor.config as inductor_config |
| 20 | +import torch.nn as nn |
| 21 | + |
| 22 | +from torchao.prototype.attention.shared_utils.custom_ops import ( |
| 23 | + register_fp8_attention_ops, |
| 24 | +) |
| 25 | +from torchao.prototype.attention.shared_utils.fusion_utils import rope_sdpa_fusion_pass |
| 26 | + |
| 27 | + |
| 28 | +# Register test-only custom ops with dummy implementations. |
| 29 | +def _dummy_rope_sdpa(q, k, v, cos, sin, **kwargs): |
| 30 | + B, S, H, D = q.shape |
| 31 | + return torch.zeros(B, H, S, D, dtype=q.dtype, device=q.device) |
| 32 | + |
| 33 | + |
| 34 | +def _dummy_sdpa(q, k, v, **kwargs): |
| 35 | + return torch.zeros_like(q) |
| 36 | + |
| 37 | + |
# Module-level op namespace shared by every test below; backed by the dummy
# implementations above so no GPU kernel or real hardware is required.
_ops = register_fp8_attention_ops("test_fusion", _dummy_rope_sdpa, _dummy_sdpa)
| 39 | + |
| 40 | + |
class PatternANeoXRoPE(nn.Module):
    """Pattern A: NeoX rotate_half RoPE -> transpose -> FP8 SDPA.

    NOTE(review): the fusion pass matches this exact op sequence in the
    traced FX graph — do not rewrite these ops into mathematically
    equivalent forms, or detection may break.
    """

    def forward(self, q, k, v, cos, sin):
        # q/k/v are BSHD (see _assert_fused); cos/sin are [S, D].
        d_half = q.shape[-1] // 2
        # Broadcast cos/sin to [1, S, 1, D] so they align with BSHD q/k.
        cos = cos.unsqueeze(0).unsqueeze(2)
        sin = sin.unsqueeze(0).unsqueeze(2)

        # rotate_half: concat(-second_half, first_half) along the last dim.
        q_rot = torch.cat([-q[..., d_half:], q[..., :d_half]], dim=-1)
        q = q * cos + q_rot * sin

        k_rot = torch.cat([-k[..., d_half:], k[..., :d_half]], dim=-1)
        k = k * cos + k_rot * sin

        # SDPA consumes BHSD, hence the transposes at the call site.
        return _ops.fp8_sdpa_op(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
| 58 | + |
| 59 | + |
class PatternBHuggingFaceRoPE(nn.Module):
    """Pattern B: transpose -> NeoX RoPE -> FP8 SDPA (HuggingFace-style).

    Same rotate_half math as Pattern A, but the BSHD -> BHSD transpose
    happens *before* RoPE is applied.

    NOTE(review): the fusion pass matches this exact op sequence in the
    traced FX graph — do not rewrite these ops into equivalent forms.
    """

    def forward(self, q, k, v, cos, sin):
        # Move to BHSD first, then rotate (HF ordering).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        d_half = q.shape[-1] // 2
        cos = cos.unsqueeze(0).unsqueeze(1)  # [1, 1, S, D] for BHSD layout
        sin = sin.unsqueeze(0).unsqueeze(1)

        # rotate_half: concat(-second_half, first_half) along the last dim.
        q_rot = torch.cat([-q[..., d_half:], q[..., :d_half]], dim=-1)
        q = q * cos + q_rot * sin

        k_rot = torch.cat([-k[..., d_half:], k[..., :d_half]], dim=-1)
        k = k * cos + k_rot * sin

        # Already BHSD, so no transpose is needed here.
        return _ops.fp8_sdpa_op(q, k, v, is_causal=True)
| 79 | + |
| 80 | + |
class PatternCWanRoPE(nn.Module):
    """Pattern C: Wan-style indexed-write RoPE -> transpose -> FP8 SDPA.

    Rotates interleaved even/odd feature pairs by writing into an empty
    buffer with strided index assignment, instead of the cat/rotate_half
    formulation used by Patterns A and B.

    NOTE(review): the fusion pass matches this exact op sequence in the
    traced FX graph — do not rewrite these ops into equivalent forms.
    """

    def forward(self, q, k, v, freqs):
        # Treat (even, odd) feature pairs as complex numbers rotated by
        # freqs: out_even = f_even*q_even - f_odd*q_odd
        #        out_odd  = f_odd*q_even + f_even*q_odd
        out_q = torch.empty_like(q)
        out_q[..., 0::2] = (
            freqs[..., 0::2] * q[..., 0::2] - freqs[..., 1::2] * q[..., 1::2]
        )
        out_q[..., 1::2] = (
            freqs[..., 1::2] * q[..., 0::2] + freqs[..., 0::2] * q[..., 1::2]
        )
        # Cast back to q's dtype — presumably freqs may be higher precision
        # in the real model; verify against callers. TODO confirm.
        out_q = out_q.type_as(q)

        out_k = torch.empty_like(k)
        out_k[..., 0::2] = (
            freqs[..., 0::2] * k[..., 0::2] - freqs[..., 1::2] * k[..., 1::2]
        )
        out_k[..., 1::2] = (
            freqs[..., 1::2] * k[..., 0::2] + freqs[..., 0::2] * k[..., 1::2]
        )
        out_k = out_k.type_as(k)

        # v is not rotated; all three are transposed BSHD -> BHSD for SDPA.
        return _ops.fp8_sdpa_op(
            out_q.transpose(1, 2),
            out_k.transpose(1, 2),
            v.transpose(1, 2),
            is_causal=True,
        )
| 109 | + |
| 110 | + |
class TestRoPEFusionDetection(unittest.TestCase):
    """Verifies the pre-grad pass fuses each RoPE pattern with FP8 SDPA."""

    def setUp(self):
        torch._dynamo.reset()
        # Remember whichever pass is currently installed so tearDown can
        # restore it and tests stay isolated from each other.
        self._saved_pass = inductor_config.pre_grad_custom_pass

    def tearDown(self):
        inductor_config.pre_grad_custom_pass = self._saved_pass
        torch._dynamo.reset()

    def _run_fusion_pass(self, model, *args):
        """Compile ``model`` with the fusion pass installed; return stdout."""
        inductor_config.pre_grad_custom_pass = partial(
            rope_sdpa_fusion_pass,
            rope_sdpa_op=_ops.rope_sdpa_op,
            fp8_sdpa_op=_ops.fp8_sdpa_op,
            backend_name="TEST",
        )
        captured = io.StringIO()
        # The pass reports fusions via print(), so capture stdout around
        # the compiled call.
        with torch.no_grad():
            with contextlib.redirect_stdout(captured):
                torch.compile(model)(*args)
        return captured.getvalue()

    def _assert_fused(self, model, *extra_args):
        """Run ``model`` on fresh BSHD inputs; assert exactly one fusion."""
        B, S, H, D = 1, 32, 4, 64
        q, k, v = (torch.randn(B, S, H, D) for _ in range(3))
        log = self._run_fusion_pass(model, q, k, v, *extra_args)
        self.assertIn("1 fused with RoPE", log)

    def test_pattern_a_neox_rope(self):
        S, D = 32, 64
        self._assert_fused(PatternANeoXRoPE(), torch.randn(S, D), torch.randn(S, D))

    def test_pattern_b_huggingface_rope(self):
        S, D = 32, 64
        self._assert_fused(
            PatternBHuggingFaceRoPE(), torch.randn(S, D), torch.randn(S, D)
        )

    def test_pattern_c_wan_rope(self):
        S, D = 32, 64
        self._assert_fused(PatternCWanRoPE(), torch.randn(1, S, 1, D))
| 156 | + |
| 157 | + |
if __name__ == "__main__":
    # Allow running this test file directly with `python <file>`.
    unittest.main()