
Commit 2ec82b3

Added new API for low precision fp8 attention using FA3 (#3857)
## Summary - Added RoPE fusion compile path for FA3 FP8 low-precision attention (fuse_rope=True) - New elementary block: fp8_fa3_rope_sdpa — fused RoPE + FP8 quantization + low-precision SDPA - New Triton kernel for fused RoPE + QKV quantization with layout transpose ([B,S,H,D] → [B,H,S,D]) - RoPE fusion method: Custom Inductor backend that traces the FX graph, detects RoPE + SDPA patterns (NeoX half-split and FLUX interleaved formats), and replaces them with fp8_fa3_rope_sdpa custom ops. Falls back to fp8_fa3_sdpa for SDPA nodes without RoPE. - Causal mask detection: Pre-flight forward pass identifies HuggingFace-style materialized causal masks so the fusion pass can strip them and use is_causal=True instead. - Added compiled model wrapper (_FP8FlashAttentionCompiledWrapper) with @torch._dynamo.disable to prevent re-tracing. - Added RoPE SDPA numerical accuracy tests and fuse_rope parametrization on model-level tests. ### New Files - shared_utils/fusion_utils.py: Shared FX graph fusion pass — RoPE pattern detection, SDPA detection, transpose unwrapping, parameterized graph surgery - shared_utils/custom_ops.py: Factory functions to register backend-specific custom ops with register_fake, and helpers to build fusion passes and compile functions - fp8_fa3/fusion_pass.py: FA3-specific custom op registration, rope_sdpa_fusion_pass, and compile_with_fp8_fusion entry point - quantization/triton_rope_qkv_quantization.py: Fused RoPE + QKV FP8 quantization Triton kernel ### Modified Files - shared_utils/attention.py: Added _fp8_rope_sdpa shared implementation - shared_utils/wrapper.py: Added _FP8FlashAttentionCompiledWrapper - shared_utils/setup.py: Added compile path routing via compile_fn parameter, moved detect_causal_mask to fusion_utils.py - quantization/quantization.py: Added _fp8_rope_sdpa_quantize - fp8_fa3/attention.py: Added fp8_fa3_rope_sdpa elementary block - fp8_fa3/setup.py: Passes compile_with_fp8_fusion as compile_fn - test_fp8_attention.py: Added TestFP8RopeSDPANumericalAccuracy, fuse_rope parametrization on model test ## Test Plan `python -m pytest test/prototype/attention/test_fp8_attention.py -v` ## Example Usage ```python from torchao.prototype.attention import ( AttentionBackend, LowPrecisionAttentionConfig, apply_low_precision_attention, ) model = MyModel() # Compile path with RoPE fusion config = LowPrecisionAttentionConfig( backend=AttentionBackend.FP8_FA3, fuse_rope=True, ) model = apply_low_precision_attention(model, config) # Flash activation is handled internally by the wrapper output = model(inputs) ``` --- ## Results #### Single-Layer Results Results directly comparing FA3 SDPA versus FA3 fp8 SDPA (including quantization time): <img width="645" height="373" alt="image" src="https://github.com/user-attachments/assets/f64cba6d-4ac9-41b7-b0f7-bf93c67f6c13" /> #### Llama3 Model Results Results comparing Llama3 model with FA3 SDPA versus Llama3 using the FA3 fp8 wrapper. Uses RoPE fusion. Perplexity: 6.19 -> 6.24 <img width="634" height="285" alt="image" src="https://github.com/user-attachments/assets/7e454bdc-b0d2-43e6-8293-32012abde5f8" />
1 parent aad1018 commit 2ec82b3

12 files changed: 2169 additions & 32 deletions


test/prototype/attention/test_fp8_attention.py

Lines changed: 124 additions & 1 deletion
@@ -30,7 +30,29 @@
     AttentionBackend,
     apply_low_precision_attention,
 )
-from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
+from torchao.prototype.attention.fp8_fa3.attention import (
+    fp8_fa3_rope_sdpa,
+    fp8_fa3_sdpa,
+)
+
+
+def _rope_cos_sin(S, D, device):
+    freqs = 1.0 / (10000.0 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
+    angles = torch.outer(torch.arange(S, dtype=torch.float32), freqs)
+    cos_half = torch.cos(angles)
+    sin_half = torch.sin(angles)
+    cos = torch.cat([cos_half, cos_half], dim=-1).to(device)
+    sin = torch.cat([sin_half, sin_half], dim=-1).to(device)
+    return cos, sin
+
+
+def _apply_rope(x, cos, sin):
+    """NeoX rotate-half RoPE. x: [B, S, H, D], cos/sin: [S, D]."""
+    D_HALF = x.shape[-1] // 2
+    rotate = torch.cat([-x[..., D_HALF:], x[..., :D_HALF]], dim=-1)
+    return (
+        x * cos.unsqueeze(0).unsqueeze(2) + rotate * sin.unsqueeze(0).unsqueeze(2)
+    ).to(x.dtype)
 
 
 class SimpleAttentionModel(nn.Module):
@@ -52,6 +74,30 @@ def forward(self, x):
         return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))
 
 
+class SimpleRoPEAttentionModel(nn.Module):
+    """Applies RoPE to Q and K immediately before SDPA (Pattern A: RoPE → transpose → SDPA)."""
+
+    def __init__(self, embed_dim, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+
+    def forward(self, x, cos, sin):
+        B, S, _ = x.shape
+        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim)
+        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim)
+        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim)
+        q = _apply_rope(q, cos, sin).transpose(1, 2)
+        k = _apply_rope(k, cos, sin).transpose(1, 2)
+        v = v.transpose(1, 2)
+        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))
+
+
 @common_utils.instantiate_parametrized_tests
 class TestFP8FA3Attention(TestCase):
     @unittest.skipUnless(
@@ -83,6 +129,41 @@ def test_sdpa_accuracy(self, shape, dtype):
             f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
         )
 
+    @unittest.skipUnless(
+        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
+        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
+    )
+    @common_utils.parametrize("shape", [(2, 1024, 8, 64), (1, 1024, 16, 128)])
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_rope_sdpa_accuracy(self, shape, dtype):
+        B, S, H, D = shape
+        q = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
+        k = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
+        v = torch.randn(B, S, H, D, device="cuda", dtype=dtype)
+        cos, sin = _rope_cos_sin(S, D, "cuda")
+
+        with torch.no_grad():
+            out_ref = F.scaled_dot_product_attention(
+                _apply_rope(q, cos, sin).transpose(1, 2),
+                _apply_rope(k, cos, sin).transpose(1, 2),
+                v.transpose(1, 2),
+                is_causal=False,
+            )
+
+        activate_flash_attention_impl("FA3")
+        try:
+            with torch.no_grad():
+                out_fp8 = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
+        finally:
+            restore_flash_attention_impl()
+
+        sqnr = compute_error(out_ref, out_fp8)
+        self.assertGreater(
+            sqnr.item(),
+            25.0,
+            f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
+        )
+
     @unittest.skipUnless(
         torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
         "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
@@ -122,6 +203,48 @@ def test_monkey_patch_model(self, dtype):
             f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
         )
 
+    @unittest.skipUnless(
+        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
+        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
+    )
+    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_rope_fusion_model(self, dtype):
+        embed_dim, num_heads = 512, 8
+        model = (
+            SimpleRoPEAttentionModel(embed_dim, num_heads)
+            .to(device="cuda", dtype=dtype)
+            .eval()
+        )
+        S = 128
+        x = torch.randn(2, S, embed_dim, device="cuda", dtype=dtype)
+        cos, sin = _rope_cos_sin(S, embed_dim // num_heads, "cuda")
+
+        with torch.no_grad():
+            out_ref = model(x, cos, sin)
+
+        fp8_model = (
+            SimpleRoPEAttentionModel(embed_dim, num_heads)
+            .to(device="cuda", dtype=dtype)
+            .eval()
+        )
+        fp8_model.load_state_dict(model.state_dict())
+        fp8_model = apply_low_precision_attention(
+            fp8_model,
+            backend=AttentionBackend.FP8_FA3,
+            fuse_rope_using_torch_compile=True,
+        )
+        fp8_model = torch.compile(fp8_model, backend=fp8_model.compile_backend)
+
+        with torch.no_grad():
+            out_fp8 = fp8_model(x, cos, sin)
+
+        sqnr = compute_error(out_ref, out_fp8)
+        self.assertGreater(
+            sqnr.item(),
+            20.0,
+            f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
+        )
+
 
 if __name__ == "__main__":
     run_tests()

torchao/prototype/attention/api.py

Lines changed: 7 additions & 0 deletions
@@ -67,6 +67,13 @@ def apply_low_precision_attention(
     Must be called before ``torch.compile``. KV caching should be
     disabled before calling (e.g., ``config.use_cache = False`` for
     HuggingFace models).
+
+    When ``fuse_rope_using_torch_compile=True``, the returned wrapper
+    exposes a ``compile_backend`` attribute. You must compile with it to get
+    the RoPE fusion::
+
+        model = apply_low_precision_attention(model, fuse_rope_using_torch_compile=True)
+        model = torch.compile(model, backend=model.compile_backend)
     """
     if isinstance(model, _LowPrecisionAttentionWrapper):
         raise RuntimeError(

torchao/prototype/attention/fp8_fa3/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -4,12 +4,18 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""FP8 attention using FA3 backend."""
+"""
+FP8 attention using FA3 backend.
+"""
 
-from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
+from torchao.prototype.attention.fp8_fa3.attention import (
+    fp8_fa3_rope_sdpa,
+    fp8_fa3_sdpa,
+)
 from torchao.prototype.attention.quantization import _fp8_sdpa_quantize
 
 __all__ = [
     "fp8_fa3_sdpa",
+    "fp8_fa3_rope_sdpa",
     "_fp8_sdpa_quantize",
 ]

torchao/prototype/attention/fp8_fa3/attention.py

Lines changed: 16 additions & 3 deletions
@@ -4,19 +4,32 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""FP8 SDPA using FA3 backend.
+"""
+FP8 SDPA using FA3 backend.
+
+When using these functions directly (not through apply_low_precision_attention),
+you must activate FA3 yourself::
 
-Thin wrapper around ``shared_utils/attention.py``. When using directly,
-activate the FA3 flash attention implementation before calling.
+    activate_flash_attention_impl("FA3")
+    try:
+        out = fp8_fa3_sdpa(q, k, v, is_causal=True)
+    finally:
+        restore_flash_attention_impl()
 """
 
 from functools import partial
 
 from torchao.prototype.attention.shared_utils.attention import (
+    _fp8_rope_sdpa,
     _fp8_sdpa,
 )
 
 fp8_fa3_sdpa = partial(_fp8_sdpa, backend_name="FA3")
 fp8_fa3_sdpa.__doc__ = _fp8_sdpa.__doc__
 fp8_fa3_sdpa.__name__ = "fp8_fa3_sdpa"
 fp8_fa3_sdpa.__qualname__ = "fp8_fa3_sdpa"
+
+fp8_fa3_rope_sdpa = partial(_fp8_rope_sdpa, backend_name="FA3")
+fp8_fa3_rope_sdpa.__doc__ = _fp8_rope_sdpa.__doc__
+fp8_fa3_rope_sdpa.__name__ = "fp8_fa3_rope_sdpa"
+fp8_fa3_rope_sdpa.__qualname__ = "fp8_fa3_rope_sdpa"
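When calling the new op directly (outside `apply_low_precision_attention`), it follows the same activate/restore contract as `fp8_fa3_sdpa`. A minimal sketch mirroring the new test, assuming `activate_flash_attention_impl` and `restore_flash_attention_impl` are imported as in the test file:

```python
# q, k, v: [B, S, H, D]; cos/sin: rotary tables as built by the test helpers.
activate_flash_attention_impl("FA3")
try:
    out = fp8_fa3_rope_sdpa(q, k, v, cos, sin, is_causal=False)
finally:
    restore_flash_attention_impl()
```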
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torchao.prototype.attention.fp8_fa3.attention import (
+    fp8_fa3_rope_sdpa,
+    fp8_fa3_sdpa,
+)
+from torchao.prototype.attention.shared_utils.custom_ops import (
+    make_backend_fn,
+    register_fp8_attention_ops,
+)
+
+_ops = register_fp8_attention_ops(
+    backend_name="fa3",
+    rope_sdpa_fn=fp8_fa3_rope_sdpa,
+    sdpa_fn=fp8_fa3_sdpa,
+)
+
+make_fp8_backend = make_backend_fn(_ops, backend_name="FA3", flash_impl_name="FA3")

torchao/prototype/attention/quantization/__init__.py

Lines changed: 6 additions & 4 deletions
@@ -4,12 +4,14 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""Shared FP8 quantization kernels for low-precision attention."""
-
-from torchao.prototype.attention.quantization.quantization import (
-    _fp8_sdpa_quantize,
+from torchao.prototype.attention.quantization.triton_qkv_quantization import (
+    triton_fp8_sdpa_quantize as _fp8_sdpa_quantize,
+)
+from torchao.prototype.attention.quantization.triton_rope_qkv_quantization import (
+    triton_fp8_rope_sdpa_quantize as _fp8_rope_sdpa_quantize,
 )
 
 __all__ = [
     "_fp8_sdpa_quantize",
+    "_fp8_rope_sdpa_quantize",
 ]
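For intuition about what the aliased `_fp8_sdpa_quantize` / `_fp8_rope_sdpa_quantize` names compute, here is an illustrative per-tensor FP8 (e4m3) quantization in plain PyTorch. This is only a sketch of the general scheme, not the Triton kernels' exact scaling or layout handling:

```python
import torch

def fp8_quantize_per_tensor(t: torch.Tensor):
    # Scale so the tensor's max magnitude maps to the top of the e4m3 range,
    # then clamp and cast; the dequant scale is returned alongside.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = t.abs().amax().clamp(min=1e-12) / finfo.max
    t_fp8 = (t / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return t_fp8, scale
```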
