
Commit 1da0e41

Merge branch 'main' into cpu-multi-isa-dispatch

Author: Copilot
Parents: 8e9f4c9 + 92dcc96

14 files changed
Lines changed: 637 additions & 78 deletions


benchmarks/prototype/attention/benchmark_sdpa.py
Lines changed: 6 additions & 1 deletion

@@ -56,14 +56,19 @@ def _activate_backend(backend: str):
     "fa3_fp8_hadamard_v": "V_ONLY",
 }
 
+_SDPA_BACKEND = {
+    "fa2": SDPBackend.FLASH_ATTENTION,
+    "fa3": SDPBackend.FLASH_ATTENTION,
+}
+
 
 def _run_attention(backend: str, q, k, v, is_causal: bool):
     """Run a single attention call for the given backend."""
     if backend in _HADAMARD_MODE:
         return fp8_fa3_sdpa(
             q, k, v, is_causal=is_causal, hadamard=_HADAMARD_MODE[backend]
         )
-    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+    with sdpa_kernel(_SDPA_BACKEND[backend]):
         return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
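The change above replaces a hard-coded `SDPBackend.FLASH_ATTENTION` with a per-backend lookup table. For context, a minimal sketch of how `sdpa_kernel` scopes backend selection for `scaled_dot_product_attention`; the shapes, dtype, and device below are illustrative assumptions, not values taken from the benchmark:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative inputs (assumed): batch=1, heads=8, seq=128, head_dim=64.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Inside the context manager, SDPA may only use the listed backend;
# it errors out if that backend cannot handle the inputs.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)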

benchmarks/prototype/attention/eval_flux_model.py
Lines changed: 38 additions & 7 deletions

@@ -26,8 +26,10 @@
 from diffusers import FluxPipeline
 from PIL import Image
 from torch.nn.attention import (
+    SDPBackend,
     activate_flash_attention_impl,
     restore_flash_attention_impl,
+    sdpa_kernel,
 )
 
 from torchao.prototype.attention import (
@@ -37,8 +39,16 @@
 )
 
 BACKENDS = {
-    "fa2": {"flash_impl": None, "fp8": False},
-    "fa3": {"flash_impl": "FA3", "fp8": False},
+    "fa2": {
+        "flash_impl": None,
+        "fp8": False,
+        "sdpa_backend": SDPBackend.FLASH_ATTENTION,
+    },
+    "fa3": {
+        "flash_impl": "FA3",
+        "fp8": False,
+        "sdpa_backend": SDPBackend.FLASH_ATTENTION,
+    },
     "fa3_fp8": {
         "flash_impl": "FA3",
         "fp8": True,
@@ -76,7 +86,7 @@ def setup_backend(
     compile_flag,
     orig_transformer,
 ):
-    """Set up a backend and return the flash_impl name."""
+    """Set up a backend and return (flash_impl, sdpa_backend)."""
     cfg = BACKENDS[backend_name]
     pipe.transformer = orig_transformer
 
@@ -90,12 +100,12 @@ def setup_backend(
         if compile_flag:
             print(f"Compiling transformer with torch.compile ({backend_name})...")
             pipe.transformer = torch.compile(pipe.transformer)
-        return cfg["flash_impl"]
+        return cfg["flash_impl"], None
     else:
         if compile_flag:
             print(f"Compiling transformer with torch.compile ({backend_name})...")
             pipe.transformer = torch.compile(pipe.transformer)
-        return cfg["flash_impl"]
+        return cfg["flash_impl"], cfg.get("sdpa_backend")
 
 
 def pil_to_lpips_tensor(img: Image.Image, device: str) -> torch.Tensor:
@@ -124,10 +134,25 @@ def generate_image(
     height: int = 2048,
     width: int = 2048,
     flash_impl: Optional[str] = None,
+    sdpa_backend: Optional[SDPBackend] = None,
 ) -> Image.Image:
     """Generate an image from a prompt with deterministic seed."""
     generator = torch.Generator(device=device).manual_seed(seed)
 
+    # For BF16 backends, force the correct SDPA backend on the transformer
+    # only (not the VAE, whose head_dim=512 exceeds flash/cuDNN limits and
+    # needs the math backend). FP8 backends call their ops directly and
+    # don't need this.
+    orig_forward = None
+    if sdpa_backend is not None:
+        orig_forward = pipe.transformer.forward
+
+        def _forced_backend_forward(*args, **kwargs):
+            with sdpa_kernel(sdpa_backend):
+                return orig_forward(*args, **kwargs)
+
+        pipe.transformer.forward = _forced_backend_forward
+
     if flash_impl:
         activate_flash_attention_impl(flash_impl)
     try:
@@ -140,6 +165,8 @@ def generate_image(
             generator=generator,
         ).images[0]
     finally:
+        if orig_forward is not None:
+            pipe.transformer.forward = orig_forward
         if flash_impl:
             restore_flash_attention_impl()
 
@@ -211,7 +238,7 @@ def run_benchmark(
     print(f"Phase 1: Generating images ({baseline_backend})")
     print("-" * 80)
 
-    baseline_flash_impl = setup_backend(
+    baseline_flash_impl, baseline_sdpa = setup_backend(
         pipe,
         baseline_backend,
         compile,
@@ -230,6 +257,7 @@ def run_benchmark(
             height=height,
             width=width,
             flash_impl=baseline_flash_impl,
+            sdpa_backend=baseline_sdpa,
         )
         print(f"  Warmup {i + 1}/{warmup_iters} complete")
 
@@ -251,6 +279,7 @@ def run_benchmark(
             height=height,
             width=width,
             flash_impl=baseline_flash_impl,
+            sdpa_backend=baseline_sdpa,
         )
         end_event.record()
         torch.cuda.synchronize()
@@ -274,7 +303,7 @@ def run_benchmark(
     print(f"Phase 2: Generating images ({test_backend})")
     print("-" * 80)
 
-    test_flash_impl = setup_backend(
+    test_flash_impl, test_sdpa = setup_backend(
         pipe,
         test_backend,
         compile,
@@ -292,6 +321,7 @@ def run_benchmark(
             height=height,
             width=width,
             flash_impl=test_flash_impl,
+            sdpa_backend=test_sdpa,
        )
         print(f"  Warmup {i + 1}/{warmup_iters} complete")
 
@@ -314,6 +344,7 @@ def run_benchmark(
             height=height,
             width=width,
             flash_impl=test_flash_impl,
+            sdpa_backend=test_sdpa,
         )
         end_event.record()
         torch.cuda.synchronize()
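The `generate_image` change uses a wrap-and-restore pattern: it monkey-patches the transformer's `forward` so the forced SDPA backend applies only to that submodule, then restores the original in `finally`. A self-contained sketch of the same pattern, using a hypothetical toy module (the module, backend choice, and shapes are assumptions for illustration):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

class ToyAttention(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

module = ToyAttention()
orig_forward = module.forward

def _forced(*args, **kwargs):
    # Every call through this module now runs under the forced backend.
    with sdpa_kernel(SDPBackend.MATH):
        return orig_forward(*args, **kwargs)

module.forward = _forced
try:
    q = torch.randn(1, 2, 16, 32)
    out = module(q, q, q)  # dispatches to _forced via the instance attribute
finally:
    module.forward = orig_forward  # always restore, as the diff does

Setting an instance attribute shadows the bound method, and `nn.Module.__call__` looks up `self.forward`, so the wrapper takes effect without touching other modules in the pipeline.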

benchmarks/prototype/attention/eval_llama3_model.py
Lines changed: 48 additions & 29 deletions

@@ -16,6 +16,7 @@
 
 import argparse
 import gc
+from contextlib import nullcontext
 
 import torch
 import torch._dynamo
@@ -24,8 +25,10 @@
 from lm_eval.models.huggingface import HFLM
 from torch._inductor.compile_fx import compile_fx
 from torch.nn.attention import (
+    SDPBackend,
     activate_flash_attention_impl,
     restore_flash_attention_impl,
+    sdpa_kernel,
 )
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -46,11 +49,13 @@
         "flash_impl": None,
         "fp8": False,
         "label": "FA2 BF16",
+        "sdpa_backend": SDPBackend.FLASH_ATTENTION,
     },
     "fa3": {
         "flash_impl": "FA3",
         "fp8": False,
         "label": "FA3 BF16",
+        "sdpa_backend": SDPBackend.FLASH_ATTENTION,
     },
     "fa3_fp8": {
         "flash_impl": "FA3",
@@ -120,8 +125,9 @@ def mask_strip_backend(gm, example_inputs):
 
 
 def setup_backend(orig_model, backend_name, compile_flag):
-    """Set up a backend and return (model, flash_impl)."""
+    """Set up a backend and return (model, flash_impl, sdpa_backend)."""
     cfg = BACKENDS[backend_name]
+    sdpa_backend = cfg.get("sdpa_backend")
 
     if cfg["fp8"]:
         print(f"  Applying low-precision FP8 attention ({backend_name})...")
@@ -136,7 +142,7 @@ def setup_backend(orig_model, backend_name, compile_flag):
         if compile_flag:
             print(f"  Compiling model with torch.compile ({backend_name})...")
             model = torch.compile(model)
-        return model, cfg["flash_impl"]
+        return model, cfg["flash_impl"], sdpa_backend
     else:
         if compile_flag:
             print(f"  Compiling model with torch.compile ({backend_name})...")
@@ -146,22 +152,24 @@ def setup_backend(orig_model, backend_name, compile_flag):
             model = _compile_with_mask_strip(
                 orig_model, flash_impl_name=cfg["flash_impl"]
             )
-            return model, cfg["flash_impl"]
+            return model, cfg["flash_impl"], sdpa_backend
         # Restore use_cache in case a prior setup disabled it.
         orig_model.config.use_cache = True
-        return orig_model, cfg["flash_impl"]
+        return orig_model, cfg["flash_impl"], sdpa_backend
 
 
-def evaluate_perplexity(model, tokenizer, flash_impl) -> float:
+def evaluate_perplexity(model, tokenizer, flash_impl, sdpa_backend=None) -> float:
     # Evaluate perplexity on WikiText-2 using lm_eval.
     if flash_impl:
         activate_flash_attention_impl(flash_impl)
+    ctx = sdpa_kernel(sdpa_backend) if sdpa_backend is not None else nullcontext()
    try:
-        results = evaluator.simple_evaluate(
-            HFLM(pretrained=model, tokenizer=tokenizer),
-            tasks=["wikitext"],
-            batch_size=1,
-        )
+        with ctx:
+            results = evaluator.simple_evaluate(
+                HFLM(pretrained=model, tokenizer=tokenizer),
+                tasks=["wikitext"],
+                batch_size=1,
+            )
     finally:
         if flash_impl:
             restore_flash_attention_impl()
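The `ctx = sdpa_kernel(...) if ... else nullcontext()` line is a conditional context-manager idiom: when no backend is forced, `nullcontext()` is a no-op stand-in, so the call site keeps a single `with` block. A minimal sketch of the idiom; `run_once` and the input shapes are hypothetical:

from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

def run_once(q, k, v, sdpa_backend=None):
    # Build the context up front; nullcontext() does nothing on enter/exit.
    ctx = sdpa_kernel(sdpa_backend) if sdpa_backend is not None else nullcontext()
    with ctx:
        return F.scaled_dot_product_attention(q, k, v)

q = torch.randn(1, 2, 16, 32)
out_default = run_once(q, q, q)                # backend chosen by SDPA itself
out_math = run_once(q, q, q, SDPBackend.MATH)  # forced math backend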
@@ -178,26 +186,33 @@ def benchmark_runtime(
     flash_impl,
     num_warmup,
     num_iters,
+    sdpa_backend=None,
 ) -> float:
     """Benchmark forward-pass latency at a given sequence length. Returns median ms."""
     input_ids = torch.randint(0, vocab_size, (1, seq_len), device=device)
 
     if flash_impl:
         activate_flash_attention_impl(flash_impl)
+    ctx = sdpa_kernel(sdpa_backend) if sdpa_backend is not None else nullcontext()
     try:
-        # Warmup
-        for _ in range(num_warmup):
-            model(input_ids)
-        torch.cuda.synchronize()
-
-        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
-        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
-
-        for i in range(num_iters):
-            start_events[i].record()
-            model(input_ids)
-            end_events[i].record()
-        torch.cuda.synchronize()
+        with ctx:
+            # Warmup
+            for _ in range(num_warmup):
+                model(input_ids)
+            torch.cuda.synchronize()
+
+            start_events = [
+                torch.cuda.Event(enable_timing=True) for _ in range(num_iters)
+            ]
+            end_events = [
+                torch.cuda.Event(enable_timing=True) for _ in range(num_iters)
+            ]
+
+            for i in range(num_iters):
+                start_events[i].record()
+                model(input_ids)
+                end_events[i].record()
+            torch.cuda.synchronize()
     finally:
         if flash_impl:
             restore_flash_attention_impl()
@@ -258,22 +273,24 @@ def run_benchmark(
 
     # --- Baseline perplexity ---
     print(f"\n  Computing perplexity with {baseline_label}...")
-    baseline_model, baseline_flash = setup_backend(
+    baseline_model, baseline_flash, baseline_sdpa = setup_backend(
         orig_model,
         baseline_backend,
         compile,
     )
-    baseline_ppl = evaluate_perplexity(baseline_model, tokenizer, baseline_flash)
+    baseline_ppl = evaluate_perplexity(
+        baseline_model, tokenizer, baseline_flash, baseline_sdpa
+    )
     print(f"  {baseline_label} perplexity: {baseline_ppl:.2f}")
 
     # --- Test perplexity ---
     print(f"\n  Computing perplexity with {test_label}...")
-    test_model, test_flash = setup_backend(
+    test_model, test_flash, test_sdpa = setup_backend(
         orig_model,
         test_backend,
         compile,
     )
-    test_ppl = evaluate_perplexity(test_model, tokenizer, test_flash)
+    test_ppl = evaluate_perplexity(test_model, tokenizer, test_flash, test_sdpa)
     print(f"  {test_label} perplexity: {test_ppl:.2f}")
 
     print(f"\n  Delta: {test_ppl - baseline_ppl:+.2f}")
@@ -289,7 +306,7 @@ def run_benchmark(
 
     # --- Baseline runtime (all sequence lengths) ---
     print(f"\n  Running baseline ({baseline_label})...")
-    baseline_model, baseline_flash = setup_backend(
+    baseline_model, baseline_flash, baseline_sdpa = setup_backend(
         orig_model,
         baseline_backend,
         compile,
@@ -305,6 +322,7 @@ def run_benchmark(
             baseline_flash,
             num_warmup,
             num_runtime_iters,
+            sdpa_backend=baseline_sdpa,
         )
         baseline_runtimes[S] = ms
         print(f"    seq_len={S:>6}: {ms:.1f} ms")
@@ -319,7 +337,7 @@ def run_benchmark(
 
     # --- Test runtime (all sequence lengths) ---
     print(f"\n  Running test ({test_label})...")
-    test_model, test_flash = setup_backend(
+    test_model, test_flash, test_sdpa = setup_backend(
         orig_model,
         test_backend,
         compile,
@@ -335,6 +353,7 @@ def run_benchmark(
             test_flash,
             num_warmup,
             num_runtime_iters,
+            sdpa_backend=test_sdpa,
         )
         test_runtimes[S] = ms
         print(f"    seq_len={S:>6}: {ms:.1f} ms")

test/prototype/moe_training/test_kernels.py
Lines changed: 1 addition & 0 deletions

@@ -69,6 +69,7 @@ def _is_sm_10x() -> bool:
 
 
 @pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
+@skip_if_rocm("jagged rowwise scales kernel vs torch reference mismatch on ROCm")
 def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
     # Tests case where rowwise scales are computed for multiple distinct subtensors,
     # with end boundary of each group is determine by their end column indexes (offsets).
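`skip_if_rocm` marks the test as skipped on ROCm builds, where the kernel output currently mismatches the torch reference. For illustration only, a plain-pytest approximation of such a decorator; this is an assumption, not torchao's actual helper:

import functools

import pytest
import torch

def skip_if_rocm(reason: str):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # torch.version.hip is non-None only on ROCm builds of PyTorch.
            if torch.version.hip is not None:
                pytest.skip(reason)
            return fn(*args, **kwargs)
        return wrapper
    return decorator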
