
Commit c32dea9

Add FP8 FA3 low-precision attention with monkey-patch SDPA path (#3959)
## Summary

- Added new folder for low-precision attention APIs in torchao/prototype/attention
- New API for FP8 FA3 low-precision attention with two components:
  - Elementary block: fp8_fa3_sdpa — a direct drop-in replacement for F.scaled_dot_product_attention that users can integrate into their model manually. Performs per-head FP8 quantization of Q, K, V followed by low-precision SDPA.
  - Simple wrapper: apply_low_precision_attention — wraps any model to automatically replace all SDPA calls with the FP8 variant. No torch.compile required.
- New Triton kernel for fused QKV FP8 quantization (3-phase: absmax reduction, scale computation, quantize)
- Causal mask detection: a pre-flight forward pass identifies HuggingFace-style materialized causal masks so the wrapper can strip them and use is_causal=True instead.
- Flash attention activation is handled internally by the wrapper — no manual activate_flash_attention_impl / restore_flash_attention_impl calls needed.
- Added new test folder for low-precision attention APIs in test/prototype/attention

### Folder Breakdown

- torchao/prototype/attention: new folder for low-precision attention APIs
  - __init__.py: public exports (apply_low_precision_attention, AttentionBackend)
  - api.py: user-facing entry point that validates the requested backend and dispatches to the matching setup
  - config.py: AttentionBackend enum and LowPrecisionAttentionConfig dataclass
  - utils.py: hardware capability checks, backend availability detection
  - shared_utils/: shared infrastructure used by backend implementations
    - attention.py: shared _fp8_sdpa implementation (quantize + SDPA)
    - wrapper.py: _FP8FlashAttentionMonkeyPatchWrapper — replaces F.scaled_dot_product_attention during forward, manages flash activation internally
    - setup.py: setup_fp8_backend — builds the wrapper with causal mask detection
  - fp8_fa3/: FA3-specific backend
    - attention.py: fp8_fa3_sdpa elementary block
    - setup.py: thin wrapper calling setup_fp8_backend with FA3 parameters
  - quantization/: shared FP8 quantization kernels
    - quantization.py: _fp8_sdpa_quantize — calls fused Triton kernels for per-head Q, K, V quantization
    - triton_qkv_quantization.py: fused QKV FP8 quantization Triton kernel
- test/prototype/attention: new folder for low-precision attention API tests
  - test_fp8_attention.py: numerical accuracy tests (eager SDPA) and model-level API tests (simple wrapper)

## Test Plan

`python -m pytest test/prototype/attention/test_fp8_attention.py -v`

## Example Usage

```python
from torchao.prototype.attention import (
    AttentionBackend,
    apply_low_precision_attention,
)

model = MyModel()

# Simple SDPA replacement — no torch.compile needed
model = apply_low_precision_attention(model, backend=AttentionBackend.FP8_FA3)

# Flash activation is handled internally by the wrapper
output = model(inputs)
```

---

## Results

#### Single-Layer Results

Results directly comparing FA3 SDPA versus FA3 FP8 SDPA (including quantization time):

<img width="645" height="373" alt="image" src="https://github.com/user-attachments/assets/f64cba6d-4ac9-41b7-b0f7-bf93c67f6c13" />

#### Llama3 Model Results

Results comparing the Llama3 model with FA3 SDPA versus Llama3 using the FA3 FP8 wrapper. Does not use RoPE fusion. Perplexity: 6.19 -> 6.25

<img width="368" height="171" alt="image" src="https://github.com/user-attachments/assets/f7534b8b-6914-4e6b-b568-55e38bdfcb2b" />
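For context, here is a minimal sketch of the monkey-patch pattern the summary describes: temporarily rebinding `F.scaled_dot_product_attention` around a forward call. The class name `_SDPAPatch` is hypothetical; the real wrapper in `shared_utils/wrapper.py` (not shown in this excerpt) also manages flash activation and causal-mask stripping.

```python
import torch.nn.functional as F


class _SDPAPatch:
    """Illustrative context manager (hypothetical): swap SDPA for a
    replacement during a forward pass, then restore the original."""

    def __init__(self, replacement):
        self.replacement = replacement
        self._orig = None

    def __enter__(self):
        self._orig = F.scaled_dot_product_attention
        F.scaled_dot_product_attention = self.replacement
        return self

    def __exit__(self, exc_type, exc, tb):
        F.scaled_dot_product_attention = self._orig
        return False
```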
1 parent 67e5358 commit c32dea9

15 files changed

Lines changed: 1081 additions & 0 deletions

test/prototype/attention/__init__.py

Whitespace-only changes.
test/prototype/attention/test_fp8_attention.py

Lines changed: 127 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Tests for FP8 low-precision attention (FA3 backend on Hopper)."""

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.quantization.utils import compute_error
from torchao.utils import torch_version_at_least

if torch_version_at_least("2.11.0"):
    from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper

    if _is_hopper() and _is_fa3_available():
        from torch.nn.attention import (
            activate_flash_attention_impl,
            restore_flash_attention_impl,
        )

        from torchao.prototype.attention import (
            AttentionBackend,
            apply_low_precision_attention,
        )
        from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa


class SimpleAttentionModel(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))


@common_utils.instantiate_parametrized_tests
class TestFP8FA3Attention(TestCase):
    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("shape", [(2, 8, 1024, 64), (1, 16, 1024, 128)])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_sdpa_accuracy(self, shape, dtype):
        B, H, S, D = shape
        q = torch.randn(B, H, S, D, device="cuda", dtype=dtype)
        k = torch.randn(B, H, S, D, device="cuda", dtype=dtype)
        v = torch.randn(B, H, S, D, device="cuda", dtype=dtype)

        with torch.no_grad():
            out_ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        activate_flash_attention_impl("FA3")
        try:
            with torch.no_grad():
                out_fp8 = fp8_fa3_sdpa(q, k, v, is_causal=False)
        finally:
            restore_flash_attention_impl()

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            25.0,
            f"SQNR {sqnr.item():.2f} dB below 25 dB for shape={shape}, dtype={dtype}",
        )

    @unittest.skipUnless(
        torch_version_at_least("2.11.0") and _is_hopper() and _is_fa3_available(),
        "Requires PyTorch >= 2.11, Hopper GPU, and FA3",
    )
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_monkey_patch_model(self, dtype):
        embed_dim, num_heads = 512, 8
        model = (
            SimpleAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        x = torch.randn(2, 128, embed_dim, device="cuda", dtype=dtype)

        with torch.no_grad():
            out_ref = model(x)

        fp8_model = (
            SimpleAttentionModel(embed_dim, num_heads)
            .to(device="cuda", dtype=dtype)
            .eval()
        )
        fp8_model.load_state_dict(model.state_dict())
        fp8_model = apply_low_precision_attention(
            fp8_model,
            backend=AttentionBackend.FP8_FA3,
            fuse_rope_using_torch_compile=False,
        )

        with torch.no_grad():
            out_fp8 = fp8_model(x)

        sqnr = compute_error(out_ref, out_fp8)
        self.assertGreater(
            sqnr.item(),
            20.0,
            f"SQNR {sqnr.item():.2f} dB below 20 dB for dtype={dtype}",
        )


if __name__ == "__main__":
    run_tests()
torchao/prototype/attention/__init__.py

Lines changed: 21 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
Low-precision attention for inference.

Only supports forward pass — backward is not supported by the underlying backends.
"""

from torchao.prototype.attention.api import (
    AttentionBackend,
    apply_low_precision_attention,
)

__all__ = [
    "AttentionBackend",
    "apply_low_precision_attention",
]
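Since the package docstring limits support to the forward pass, call sites run the wrapped model under `torch.no_grad()`, as this commit's tests do. A minimal sketch (`wrapped_model` and `inputs` are placeholders):

```python
import torch

with torch.no_grad():  # forward only; backward is unsupported by these backends
    output = wrapped_model(inputs)
```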

torchao/prototype/attention/api.py

Lines changed: 88 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""User-facing API for low-precision attention."""

from enum import Enum
from typing import Optional

import torch
import torch._dynamo
import torch.nn as nn

from torchao.prototype.attention.utils import _is_fa3_available, _is_hopper
from torchao.utils import torch_version_at_least

if torch_version_at_least("2.11.0"):
    from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend
    from torchao.prototype.attention.shared_utils.wrapper import (
        _LowPrecisionAttentionWrapper,
    )
else:
    raise ImportError("Low-precision attention requires PyTorch 2.11+.")


class AttentionBackend(str, Enum):
    """Backend kernel for computing attention."""

    FP8_FA3 = "FP8_FA3"  # Requires SM90+ (Hopper)


def _get_available_backend() -> AttentionBackend:
    if not torch.cuda.is_available():
        raise RuntimeError("Low-precision attention requires CUDA.")
    capability = torch.cuda.get_device_capability()
    if _is_hopper() and _is_fa3_available():
        return AttentionBackend.FP8_FA3
    raise RuntimeError(f"No compatible backend for SM{capability[0]}{capability[1]}.")


def _check_backend_available(backend: AttentionBackend) -> None:
    if not torch.cuda.is_available():
        raise RuntimeError(f"{backend} backend requires CUDA.")
    capability = torch.cuda.get_device_capability()
    if backend == AttentionBackend.FP8_FA3:
        if not _is_hopper():
            raise RuntimeError(
                f"FP8_FA3 requires Hopper (SM 9.x), got SM{capability[0]}{capability[1]}."
            )
        if not _is_fa3_available():
            raise RuntimeError(
                "FP8_FA3 requires the flash-attn package with FA3 support."
            )
    else:
        raise ValueError(f"Unknown backend: {backend}")


def apply_low_precision_attention(
    model: nn.Module,
    backend: Optional[AttentionBackend] = None,
    fuse_rope_using_torch_compile: bool = False,
) -> nn.Module:
    """Apply low-precision attention to a model.

    Must be called before ``torch.compile``. KV caching should be
    disabled before calling (e.g., ``config.use_cache = False`` for
    HuggingFace models).
    """
    if isinstance(model, _LowPrecisionAttentionWrapper):
        raise RuntimeError(
            "apply_low_precision_attention has already been applied to this module."
        )
    if isinstance(model, torch._dynamo.OptimizedModule):
        raise RuntimeError(
            "apply_low_precision_attention must be called before torch.compile."
        )

    if backend is None:
        backend = _get_available_backend()
    else:
        _check_backend_available(backend)

    if backend == AttentionBackend.FP8_FA3:
        return setup_fp8_backend(model, "FA3", fuse_rope_using_torch_compile)

    raise ValueError(f"Unknown backend: {backend}")
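Since `backend` defaults to `None`, the entry point can also auto-select via `_get_available_backend`. A minimal call sketch against the signature above (`model` is a placeholder):

```python
from torchao.prototype.attention import apply_low_precision_attention

# backend=None: picks FP8_FA3 on a Hopper GPU with FA3 available,
# otherwise raises RuntimeError.
model = apply_low_precision_attention(model)
```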
torchao/prototype/attention/fp8_fa3/__init__.py

Lines changed: 15 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""FP8 attention using FA3 backend."""

from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa
from torchao.prototype.attention.quantization import _fp8_sdpa_quantize

__all__ = [
    "fp8_fa3_sdpa",
    "_fp8_sdpa_quantize",
]
torchao/prototype/attention/fp8_fa3/attention.py

Lines changed: 22 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""FP8 SDPA using FA3 backend.

Thin wrapper around ``shared_utils/attention.py``. When using directly,
activate the FA3 flash attention implementation before calling.
"""

from functools import partial

from torchao.prototype.attention.shared_utils.attention import (
    _fp8_sdpa,
)

fp8_fa3_sdpa = partial(_fp8_sdpa, backend_name="FA3")
fp8_fa3_sdpa.__doc__ = _fp8_sdpa.__doc__
fp8_fa3_sdpa.__name__ = "fp8_fa3_sdpa"
fp8_fa3_sdpa.__qualname__ = "fp8_fa3_sdpa"
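Per the module docstring, direct use of `fp8_fa3_sdpa` requires activating the FA3 implementation first. The tests in this commit follow exactly this pattern; a sketch assuming `q`, `k`, `v` are 4D CUDA tensors:

```python
import torch
from torch.nn.attention import (
    activate_flash_attention_impl,
    restore_flash_attention_impl,
)

activate_flash_attention_impl("FA3")
try:
    with torch.no_grad():
        out = fp8_fa3_sdpa(q, k, v, is_causal=True)
finally:
    restore_flash_attention_impl()
```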
torchao/prototype/attention/fp8_fa3/setup.py

Lines changed: 27 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""FP8 FA3 backend setup."""

import torch.nn as nn

from torchao.prototype.attention.config import LowPrecisionAttentionConfig
from torchao.prototype.attention.shared_utils.setup import setup_fp8_backend


def setup_fp8_fa3(
    model: nn.Module,
    config: LowPrecisionAttentionConfig,
) -> nn.Module:
    """Set up FP8 FA3 attention on *model* and wrap it."""
    from torchao.prototype.attention.fp8_fa3.attention import fp8_fa3_sdpa

    return setup_fp8_backend(
        model,
        config,
        flash_impl_name="FA3",
        sdpa_fn=fp8_fa3_sdpa,
    )
torchao/prototype/attention/quantization/__init__.py

Lines changed: 15 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""Shared FP8 quantization kernels for low-precision attention."""

from torchao.prototype.attention.quantization.quantization import (
    _fp8_sdpa_quantize,
)

__all__ = [
    "_fp8_sdpa_quantize",
]
torchao/prototype/attention/quantization/quantization.py

Lines changed: 50 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
FP8 quantization for attention inputs.
"""

from typing import Tuple

import torch


def _fp8_sdpa_quantize(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Quantize Q, K, V to FP8 with per-head scaling."""
    if q.dim() != 4:
        raise ValueError(f"Expected 4D tensor for q, got {q.dim()}D")
    if k.dim() != 4:
        raise ValueError(f"Expected 4D tensor for k, got {k.dim()}D")
    if v.dim() != 4:
        raise ValueError(f"Expected 4D tensor for v, got {v.dim()}D")
    if k.shape != v.shape:
        raise ValueError(f"K and V shape mismatch: {k.shape} vs {v.shape}")
    if q.shape[0] != k.shape[0]:
        raise ValueError(f"Batch size mismatch: {q.shape[0]} vs {k.shape[0]}")
    if q.shape[1] % k.shape[1] != 0:
        raise ValueError(
            f"Q head count ({q.shape[1]}) must be a multiple of K head count ({k.shape[1]})"
        )
    if q.shape[3] != k.shape[3]:
        raise ValueError(f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}")

    from torchao.prototype.attention.quantization.triton_qkv_quantization import (
        triton_fp8_sdpa_quantize,
    )

    return triton_fp8_sdpa_quantize(q, k, v)
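The fused Triton kernel itself is not shown in this excerpt. As a rough reference for what per-head FP8 quantization computes on one tensor, here is a pure-PyTorch sketch of the kernel's three phases; the `float8_e4m3fn` dtype, the epsilon, and the exact scale convention are assumptions, not taken from the diff:

```python
import torch


def _per_head_fp8_quantize_ref(t: torch.Tensor):
    """Reference per-head quantization for a (B, H, S, D) tensor.

    Mirrors the kernel's three phases: absmax reduction over each head,
    scale computation, then quantization. Dtype and eps are assumed.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Phase 1: per-head absmax over each (S, D) slice
    absmax = t.abs().amax(dim=(-2, -1), keepdim=True).float()
    # Phase 2: per-head scale (guard against all-zero heads)
    scale = (absmax / fp8_max).clamp_min(1e-12)
    # Phase 3: scale and cast to FP8
    t_fp8 = (t.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return t_fp8, scale
```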
