 
 import argparse
 import gc
+from contextlib import nullcontext
 
 import torch
 import torch._dynamo
 from lm_eval.models.huggingface import HFLM
 from torch._inductor.compile_fx import compile_fx
 from torch.nn.attention import (
+    SDPBackend,
     activate_flash_attention_impl,
     restore_flash_attention_impl,
+    sdpa_kernel,
 )
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
         "flash_impl": None,
         "fp8": False,
         "label": "FA2 BF16",
+        "sdpa_backend": SDPBackend.FLASH_ATTENTION,
     },
     "fa3": {
         "flash_impl": "FA3",
         "fp8": False,
         "label": "FA3 BF16",
+        "sdpa_backend": SDPBackend.FLASH_ATTENTION,
     },
     "fa3_fp8": {
         "flash_impl": "FA3",
@@ -120,8 +125,9 @@ def mask_strip_backend(gm, example_inputs):
 
 
 def setup_backend(orig_model, backend_name, compile_flag):
-    """Set up a backend and return (model, flash_impl)."""
+    """Set up a backend and return (model, flash_impl, sdpa_backend)."""
     cfg = BACKENDS[backend_name]
+    sdpa_backend = cfg.get("sdpa_backend")
 
     if cfg["fp8"]:
         print(f"  Applying low-precision FP8 attention ({backend_name})...")
@@ -136,7 +142,7 @@ def setup_backend(orig_model, backend_name, compile_flag):
         if compile_flag:
             print(f"  Compiling model with torch.compile ({backend_name})...")
             model = torch.compile(model)
-        return model, cfg["flash_impl"]
+        return model, cfg["flash_impl"], sdpa_backend
     else:
         if compile_flag:
             print(f"  Compiling model with torch.compile ({backend_name})...")
@@ -146,22 +152,24 @@ def setup_backend(orig_model, backend_name, compile_flag):
             model = _compile_with_mask_strip(
                 orig_model, flash_impl_name=cfg["flash_impl"]
             )
-            return model, cfg["flash_impl"]
+            return model, cfg["flash_impl"], sdpa_backend
         # Restore use_cache in case a prior setup disabled it.
         orig_model.config.use_cache = True
-        return orig_model, cfg["flash_impl"]
+        return orig_model, cfg["flash_impl"], sdpa_backend
 
 
-def evaluate_perplexity(model, tokenizer, flash_impl) -> float:
+def evaluate_perplexity(model, tokenizer, flash_impl, sdpa_backend=None) -> float:
     # Evaluate perplexity on WikiText-2 using lm_eval.
     if flash_impl:
         activate_flash_attention_impl(flash_impl)
+    ctx = sdpa_kernel(sdpa_backend) if sdpa_backend is not None else nullcontext()
     try:
-        results = evaluator.simple_evaluate(
-            HFLM(pretrained=model, tokenizer=tokenizer),
-            tasks=["wikitext"],
-            batch_size=1,
-        )
+        with ctx:
+            results = evaluator.simple_evaluate(
+                HFLM(pretrained=model, tokenizer=tokenizer),
+                tasks=["wikitext"],
+                batch_size=1,
+            )
     finally:
         if flash_impl:
             restore_flash_attention_impl()
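The pattern above leans on `torch.nn.attention.sdpa_kernel` (available in recent PyTorch releases) to pin `scaled_dot_product_attention` to a single backend, with `contextlib.nullcontext` as a no-op fallback when a backend config carries no `sdpa_backend`. A minimal standalone sketch of that pinning pattern, separate from the benchmark script itself (the tensor shapes, bfloat16 dtype, and CUDA device are illustrative assumptions, not taken from the commit):

```python
# Minimal sketch of the sdpa_kernel / nullcontext pinning pattern; shapes,
# dtype, and the CUDA device are illustrative assumptions, not part of the
# benchmark script above.
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def attend(q, k, v, backend=None):
    # Restrict SDPA to one backend when requested; otherwise let PyTorch pick.
    ctx = sdpa_kernel(backend) if backend is not None else nullcontext()
    with ctx:
        return F.scaled_dot_product_attention(q, k, v)


q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
out = attend(q, k, v, backend=SDPBackend.FLASH_ATTENTION)
```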
@@ -178,26 +186,33 @@ def benchmark_runtime(
     flash_impl,
     num_warmup,
     num_iters,
+    sdpa_backend=None,
 ) -> float:
     """Benchmark forward-pass latency at a given sequence length. Returns median ms."""
     input_ids = torch.randint(0, vocab_size, (1, seq_len), device=device)
 
     if flash_impl:
         activate_flash_attention_impl(flash_impl)
+    ctx = sdpa_kernel(sdpa_backend) if sdpa_backend is not None else nullcontext()
     try:
-        # Warmup
-        for _ in range(num_warmup):
-            model(input_ids)
-        torch.cuda.synchronize()
-
-        start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
-        end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iters)]
-
-        for i in range(num_iters):
-            start_events[i].record()
-            model(input_ids)
-            end_events[i].record()
-        torch.cuda.synchronize()
+        with ctx:
+            # Warmup
+            for _ in range(num_warmup):
+                model(input_ids)
+            torch.cuda.synchronize()
+
+            start_events = [
+                torch.cuda.Event(enable_timing=True) for _ in range(num_iters)
+            ]
+            end_events = [
+                torch.cuda.Event(enable_timing=True) for _ in range(num_iters)
+            ]
+
+            for i in range(num_iters):
+                start_events[i].record()
+                model(input_ids)
+                end_events[i].record()
+            torch.cuda.synchronize()
     finally:
         if flash_impl:
             restore_flash_attention_impl()
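`benchmark_runtime` times each iteration with paired CUDA events inside the pinned-backend context; the hunk ends before the reduction, but turning those events into the median latency the docstring promises would look roughly like the helper below (an illustrative sketch, not the script's own code):

```python
# Illustrative reduction from paired CUDA events to a median latency in
# milliseconds; the actual script presumably does something equivalent after
# the timing loop shown above.
import statistics

import torch


def median_ms(start_events, end_events):
    torch.cuda.synchronize()  # ensure every recorded event has completed
    times_ms = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return statistics.median(times_ms)
```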
@@ -258,22 +273,24 @@ def run_benchmark(
 
     # --- Baseline perplexity ---
     print(f"\nComputing perplexity with {baseline_label}...")
-    baseline_model, baseline_flash = setup_backend(
+    baseline_model, baseline_flash, baseline_sdpa = setup_backend(
         orig_model,
         baseline_backend,
         compile,
     )
-    baseline_ppl = evaluate_perplexity(baseline_model, tokenizer, baseline_flash)
+    baseline_ppl = evaluate_perplexity(
+        baseline_model, tokenizer, baseline_flash, baseline_sdpa
+    )
     print(f"  {baseline_label} perplexity: {baseline_ppl:.2f}")
 
     # --- Test perplexity ---
     print(f"\nComputing perplexity with {test_label}...")
-    test_model, test_flash = setup_backend(
+    test_model, test_flash, test_sdpa = setup_backend(
         orig_model,
         test_backend,
         compile,
     )
-    test_ppl = evaluate_perplexity(test_model, tokenizer, test_flash)
+    test_ppl = evaluate_perplexity(test_model, tokenizer, test_flash, test_sdpa)
     print(f"  {test_label} perplexity: {test_ppl:.2f}")
 
     print(f"\nDelta: {test_ppl - baseline_ppl:+.2f}")
@@ -289,7 +306,7 @@ def run_benchmark(
 
     # --- Baseline runtime (all sequence lengths) ---
     print(f"\nRunning baseline ({baseline_label})...")
-    baseline_model, baseline_flash = setup_backend(
+    baseline_model, baseline_flash, baseline_sdpa = setup_backend(
         orig_model,
         baseline_backend,
         compile,
@@ -305,6 +322,7 @@ def run_benchmark(
             baseline_flash,
             num_warmup,
             num_runtime_iters,
+            sdpa_backend=baseline_sdpa,
         )
         baseline_runtimes[S] = ms
         print(f"  seq_len={S:>6}: {ms:.1f} ms")
@@ -319,7 +337,7 @@ def run_benchmark(
 
     # --- Test runtime (all sequence lengths) ---
     print(f"\nRunning test ({test_label})...")
-    test_model, test_flash = setup_backend(
+    test_model, test_flash, test_sdpa = setup_backend(
         orig_model,
         test_backend,
         compile,
@@ -335,6 +353,7 @@ def run_benchmark(
             test_flash,
             num_warmup,
             num_runtime_iters,
+            sdpa_backend=test_sdpa,
         )
         test_runtimes[S] = ms
         print(f"  seq_len={S:>6}: {ms:.1f} ms")
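With these changes, `setup_backend` returns a three-tuple and the third element simply threads through to the evaluation helpers. A hedged sketch of the resulting call pattern (`orig_model` and `tokenizer` stand for whatever the script loads earlier; `"fa3"` is just one of the `BACKENDS` keys shown above):

```python
# Sketch of the call pattern after this commit; orig_model and tokenizer are
# assumed to be the HF model/tokenizer the script loads elsewhere, and "fa3"
# is one of the BACKENDS entries shown above.
model, flash_impl, sdpa_backend = setup_backend(orig_model, "fa3", False)
ppl = evaluate_perplexity(model, tokenizer, flash_impl, sdpa_backend)
print(f"FA3 BF16 perplexity: {ppl:.2f}")
```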