
Commit 0a229a9

clean up flux-1.schnell benchmark and add to docs
Summary:

Combining several improvements to the flux-1.schnell benchmark in a single PR:

* set `num_inference_steps` to 4 to match the default for this model
* turn on CUDA graphs via `reduce-overhead` to improve nvfp4 performance at batch size 1; this required patching the transformer blocks' forward methods, using code from jbschlosser
* add a larger batch size for additional performance metrics at larger shapes
* remove torch.compile from the VAE, since compile times are too long and the goal of this benchmark is to compare recipes and track performance improvements, not to achieve the best possible latency; perf results for a single configuration now take ~30 s to ~1 min, and an overall run on one B200 takes ~20 min
* fix the quant recipe to exclude one more layer with small shapes
* add the results to the documentation

I want to land this so we can see the e2e metrics lift from eventually landing #4031.

Test Plan:

```
./benchmarks/quantization/eval_accuracy_and_perf_of_flux.sh
```

ghstack-source-id: 8fb74eb
ghstack-comment-id: 4055040194
Pull-Request: #4072
1 parent 3c2cb8c commit 0a229a9
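
The `reduce-overhead` change described in the Summary compiles each transformer block separately and clones CUDA tensor outputs so that CUDA-graph-backed buffers are not aliased by later operations. Below is a minimal sketch of that pattern; it mirrors the `clone_output_wrapper` / `apply_torch_compile` helpers added in the Python diff further down, and it assumes a diffusers Flux pipeline named `pipe` is already loaded (not shown here).

```python
# Sketch only: per-block compile with CUDA-graph-friendly output cloning,
# mirroring the helpers added in this commit. `pipe` is assumed to be a
# loaded diffusers FluxPipeline.
from functools import wraps

import torch
from torch.utils._pytree import tree_map_only


def clone_cuda_outputs(f):
    """Wrap f so any CUDA tensor outputs are cloned out of reused buffers."""

    @wraps(f)
    def wrapped(*args, **kwargs):
        outputs = f(*args, **kwargs)
        return tree_map_only(
            torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
        )

    return wrapped


def compile_transformer_blocks(pipe, mode: str = "reduce-overhead"):
    """Compile each block's forward individually instead of the whole transformer."""
    blocks = list(pipe.transformer.transformer_blocks) + list(
        pipe.transformer.single_transformer_blocks
    )
    for block in blocks:
        block.forward = clone_cuda_outputs(torch.compile(block.forward, mode=mode))
```

Compiling per block keeps compile times manageable while still letting CUDA graphs capture each block, which is what the Summary credits for the nvfp4 batch-size-1 improvement.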

3 files changed

Lines changed: 137 additions & 30 deletions


benchmarks/quantization/eval_accuracy_and_perf_of_flux.py

Lines changed: 100 additions & 23 deletions
@@ -8,6 +8,8 @@
 import os
 import random
 import time
+from functools import wraps
+from typing import Callable, TypeVar
 
 import diffusers
 import fire
@@ -25,6 +27,9 @@
     quantize_,
 )
 
+# Type variables for better type hinting
+T = TypeVar("T")
+
 # -----------------------------
 # Config
 # -----------------------------
@@ -71,12 +76,18 @@ def print_pipeline_architecture(pipe):
 
 
 def generate_image(
-    pipe, prompt: str, seed: int, device: str, num_inference_steps
+    pipe,
+    prompt: str,
+    seed: int,
+    device: str,
+    num_inference_steps: int,
+    batch_size: int = 1,
 ) -> Image.Image:
     generator = torch.Generator(device=device).manual_seed(seed)
 
+    prompts = [prompt] * batch_size
     image = pipe(
-        prompt=prompt,
+        prompt=prompts,
         num_inference_steps=num_inference_steps,  # can tweak for speed vs quality
         guidance_scale=7.5,
         generator=generator,
@@ -238,18 +249,58 @@ def pil_to_lpips_tensor(img: Image.Image, device: str):
     return t.to(device)
 
 
+from torch.utils._pytree import tree_map_only
+
+
+def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
+    """
+    Clone the CUDA output tensors of a function to avoid in-place operations.
+
+    This wrapper is useful when working with torch.compile to prevent errors
+    related to in-place operations on tensors.
+
+    Args:
+        f: The function whose CUDA tensor outputs should be cloned
+
+    Returns:
+        A wrapped function that clones any CUDA tensor outputs
+    """
+
+    @wraps(f)
+    def wrapped(*args, **kwargs):
+        outputs = f(*args, **kwargs)
+        return tree_map_only(
+            torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
+        )
+
+    return wrapped
+
+
+def apply_torch_compile(pipe, torch_compile_mode: str = "default"):
+    """Apply torch.compile to the transformer blocks in-place."""
+    for block in pipe.transformer.transformer_blocks:
+        block.forward = clone_output_wrapper(
+            torch.compile(block.forward, mode=torch_compile_mode)
+        )
+    for block in pipe.transformer.single_transformer_blocks:
+        block.forward = clone_output_wrapper(
+            torch.compile(block.forward, mode=torch_compile_mode)
+        )
+
+
 @torch.inference_mode()
 def run(
     mode: str = "accuracy",
     num_prompts: int = None,
-    num_inference_steps: int = 20,
+    num_inference_steps: int = 4,
     quant_config_str: str = "float8_rowwise",
     use_compile: bool = False,
     torch_compile_mode: str = "default",
     debug_prompt: str | None = None,
     print_model: bool = False,
     cache_baseline_images: bool = False,
     perf_n_iter: int = 10,
+    batch_size: int = 1,
     use_deterministic_algorithms: bool = False,
     num_gpus_used: int = None,
 ):
@@ -282,6 +333,7 @@ def run(
             instead of regenerated, if available. This is useful to make eval runs faster
             if we know the baseline is not changing.
         perf_n_iter: number of measurements to take for measuring performance
+        batch_size: batch size for performance_hp and performance_quant modes (default 1)
         use_deterministic_algorithms: if True, sets torch.use_deterministic_algorithms(True)
         num_gpus_used: For 'aggregate_accuracy' mode, the number of GPUs that were used
             to generate the data. Required for aggregate_accuracy mode.
@@ -314,6 +366,7 @@ def run(
     print(f"[Rank {local_rank}/{world_size}] use_compile: {use_compile}")
     print(f"[Rank {local_rank}/{world_size}] torch_compile_mode: {torch_compile_mode}")
     print(f"[Rank {local_rank}/{world_size}] {use_deterministic_algorithms=}")
+    print(f"[Rank {local_rank}/{world_size}] {batch_size=}")
    print(f"[Rank {local_rank}/{world_size}] {cache_baseline_images=}")
 
     assert mode in (
@@ -322,6 +375,11 @@ def run(
         "performance_quant",
         "aggregate_accuracy",
     )
+    assert batch_size >= 1, f"batch_size must be >= 1, got {batch_size}"
+    if mode in ("accuracy", "aggregate_accuracy"):
+        assert batch_size == 1, (
+            f"batch_size must be 1 for {mode} mode, got {batch_size}"
+        )
 
     # Handle aggregate_accuracy mode separately
     if mode == "aggregate_accuracy":
@@ -438,14 +496,6 @@ def run(
 
     loss_fn = lpips.LPIPS(net="vgg").to(device)
 
-    # Store original for restoration later, since we will quantize it
-    # and compile the quantized version again
-    orig_transformer = pipe.transformer
-
-    if use_compile:
-        pipe.transformer = torch.compile(orig_transformer, mode=torch_compile_mode)
-        pipe.vae.decode = torch.compile(pipe.vae.decode, mode=torch_compile_mode)
-
     # -----------------------------
     # 2. Baseline images (for all prompts)
     # -----------------------------
@@ -473,6 +523,8 @@ def run(
     baseline_times = []
 
     if mode == "accuracy":
+        # note: never compile for baseline images
+
         for local_idx, prompt in enumerate(my_prompts):
             # Calculate global prompt index
             global_idx = local_rank + local_idx * world_size
@@ -500,32 +552,39 @@ def run(
             baseline_times.append(t1 - t0)
 
     elif mode == "performance_hp":
+        if use_compile:
+            apply_torch_compile(pipe, torch_compile_mode)
+
         # High precision performance mode - measure baseline without quantization
         if local_rank == 0:
             # warm up compile
             _ = generate_image(
-                pipe, prompts_to_use[0], RANDOM_SEED, device, num_inference_steps
+                pipe,
+                prompts_to_use[0],
+                RANDOM_SEED,
+                device,
+                num_inference_steps,
+                batch_size=batch_size,
             )
 
             for _ in range(perf_n_iter):
                 t0 = time.time()
                 _ = generate_image(
-                    pipe, prompts_to_use[0], RANDOM_SEED, device, num_inference_steps
+                    pipe,
+                    prompts_to_use[0],
+                    RANDOM_SEED,
+                    device,
+                    num_inference_steps,
+                    batch_size=batch_size,
                 )
                 t1 = time.time()
                 baseline_times.append(t1 - t0)
 
-    if use_compile and mode in ("accuracy", "performance_quant"):
-        print(
-            f"[Rank {local_rank}/{world_size}] Restoring original (uncompiled) transformer before quantization"
-        )
-        pipe.transformer = orig_transformer
-
     # Only quantize for accuracy and performance_quant modes
     if mode in ("accuracy", "performance_quant"):
         # Inspect Linear layers in main component
         component_linear_fqns_and_weight_shapes = []
-        for fqn, module in orig_transformer.named_modules():
+        for fqn, module in pipe.transformer.named_modules():
            if isinstance(module, torch.nn.Linear):
                weight_shape = module.weight.shape
                if print_model:
@@ -545,15 +604,21 @@ def run(
                    continue
                elif fqn == "proj_out":
                    continue
+                elif "norm.linear" in fqn:
+                    # activations here have shape [batch_size, 3072], so
+                    # too small to see speedups from activation quantization
+                    continue
                elif weight_shape[0] < 1024 or weight_shape[1] < 1024:
                    continue
                fqn_to_config_dict[fqn] = config_obj
        fqn_to_config = FqnToConfig(fqn_to_config=fqn_to_config_dict)
 
        # Quantize the main component using this config
        quantize_(pipe.transformer, fqn_to_config, filter_fn=None)
+
        if use_compile:
-            pipe.transformer = torch.compile(pipe.transformer, mode=torch_compile_mode)
+            apply_torch_compile(pipe, torch_compile_mode)
+
        if print_model:
            print_pipeline_architecture(pipe)
 
@@ -615,13 +680,23 @@ def run(
        if local_rank == 0:
            # warm up compile
            _ = generate_image(
-                pipe, prompts_to_use[0], RANDOM_SEED, device, num_inference_steps
+                pipe,
+                prompts_to_use[0],
+                RANDOM_SEED,
+                device,
+                num_inference_steps,
+                batch_size=batch_size,
            )
 
            for _ in range(perf_n_iter):
                t0 = time.time()
                _ = generate_image(
-                    pipe, prompts_to_use[0], RANDOM_SEED, device, num_inference_steps
+                    pipe,
+                    prompts_to_use[0],
+                    RANDOM_SEED,
+                    device,
+                    num_inference_steps,
+                    batch_size=batch_size,
                )
                t1 = time.time()
                times.append(t1 - t0)
@@ -691,11 +766,13 @@ def run(
            writer.writerow(["average_quantized_time", f"{avg_quant_time:.4f}"])
        elif mode == "performance_hp":
            writer.writerow(["perf_n_iter", perf_n_iter])
+            writer.writerow(["batch_size", batch_size])
            writer.writerow(["average_time", f"{avg_time:.4f}"])
            for idx, val in enumerate(baseline_times):
                writer.writerow([f"time_{idx}", f"{val:.4f}"])
        elif mode == "performance_quant":
            writer.writerow(["perf_n_iter", perf_n_iter])
+            writer.writerow(["batch_size", batch_size])
            writer.writerow(["average_time", f"{avg_time:.4f}"])
            for idx, val in enumerate(times):
                writer.writerow([f"time_{idx}", f"{val:.4f}"])
benchmarks/quantization/eval_accuracy_and_perf_of_flux.sh

Lines changed: 9 additions & 7 deletions
@@ -1,25 +1,27 @@
 #!/bin/bash
 
 # number of local GPUs to use for accuracy eval
-NUM_GPUS=8
+NUM_GPUS=1
 
 # float8 rowwise
 # note: max-autotune performance is nearly identical to regular compile on b200, so skip it for now
 time torchrun --nproc_per_node=$NUM_GPUS benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str float8_rowwise --mode accuracy --use_deterministic_algorithms
 time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str float8_rowwise --mode aggregate_accuracy --num_gpus_used $NUM_GPUS
-time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str float8_rowwise --mode performance_hp --use_compile
-time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str float8_rowwise --mode performance_quant --use_compile
+time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str float8_rowwise --mode performance_hp --use_compile --torch_compile_mode reduce-overhead --batch_size 1
+time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str float8_rowwise --mode performance_hp --use_compile --torch_compile_mode reduce-overhead --batch_size 4
+time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str float8_rowwise --mode performance_quant --use_compile --torch_compile_mode reduce-overhead --batch_size 1
+time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str float8_rowwise --mode performance_quant --use_compile --torch_compile_mode reduce-overhead --batch_size 4
 
 # mxfp8
 time torchrun --nproc_per_node=$NUM_GPUS benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str mxfp8 --mode accuracy --cache_baseline_images --use_deterministic_algorithms
 time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str mxfp8 --mode aggregate_accuracy --num_gpus_used $NUM_GPUS
-# time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str mxfp8 --mode performance_hp --use_compile
-time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str mxfp8 --mode performance_quant --use_compile
+time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str mxfp8 --mode performance_quant --use_compile --torch_compile_mode reduce-overhead --batch_size 1
+time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str mxfp8 --mode performance_quant --use_compile --torch_compile_mode reduce-overhead --batch_size 4
 
 # nvfp4
 # note: even though we are using a triton kernel for to_nvfp4 cast, we still need
 # to enable compile for fast generation of the nvfp4 global scale
 time torchrun --nproc_per_node=$NUM_GPUS benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str nvfp4 --mode accuracy --cache_baseline_images --use_deterministic_algorithms
 time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str nvfp4 --mode aggregate_accuracy --num_gpus_used $NUM_GPUS
-# time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str nvfp4 --mode performance_hp --use_compile
-time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str nvfp4 --mode performance_quant --use_compile
+time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str nvfp4 --mode performance_quant --use_compile --torch_compile_mode reduce-overhead --batch_size 1
+time python -u benchmarks/quantization/eval_accuracy_and_perf_of_flux.py --quant_config_str nvfp4 --mode performance_quant --use_compile --torch_compile_mode reduce-overhead --batch_size 4

docs/source/workflows/inference.md

Lines changed: 28 additions & 0 deletions
@@ -178,6 +178,34 @@
 4 16384 16384 16384 3.82 2.31
 ```
 
+## e2e flux-1.schnell benchmarks
+
+These benchmarks compare accuracy and performance of torchao inference quantization on the
+[flux-1.schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) model.
+
+For accuracy, we measure the [LPIPS](https://github.com/richzhang/PerceptualSimilarity) score
+between images generated by the quantized model and the high precision (bfloat16) baseline,
+averaged over the prompts from the [sayakpaul/drawbench](https://huggingface.co/datasets/sayakpaul/drawbench) dataset —
+lower is better, with 0 meaning identical.
+
+Note that this benchmark optimizes for speed of iteration and does not represent
+the best possible metrics someone could achieve on this model. Instead, this is an
+apples-to-apples comparison intended to compare different quantization recipes at a
+high level, and measure performance improvements.
+
+| experiment | lpips_avg | time_s_bsz_1 | speedup_bsz_1 | time_s_bsz_4 | speedup_bsz_4 |
+| ---------- | --------- | ------------- | -------------- | ------------- | -------------- |
+| bfloat16 | 0 | 0.4178 | 1.00 | 1.4914 | 1.00 |
+| float8_rowwise | 0.1236 | 0.3455 | 1.21 | 1.1986 | 1.24 |
+| mxfp8 | 0.1260 | 0.3673 | 1.14 | 1.2820 | 1.16 |
+| nvfp4 | 0.2694 | 0.3308 | 1.26 | 1.1334 | 1.32 |
+
+To reproduce, run:
+
+```bash
+./benchmarks/quantization/eval_accuracy_and_perf_of_flux.sh
+```
+
 ## Other Available Quantization Techniques
 
 ### Int8DynamicActivationIntxWeightConfig Quantization
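
The LPIPS numbers in the new docs table come from comparing quantized and bfloat16 images with the `lpips` package (the benchmark script constructs `lpips.LPIPS(net="vgg")`). Below is a minimal sketch of that comparison for a single pair of images; the image file names are placeholders, and the preprocessing is a simplified stand-in for the script's own `pil_to_lpips_tensor` helper.

```python
# Sketch only: LPIPS distance between a baseline image and a quantized-model image.
# Lower is better; 0 means identical. File names are placeholders.
import lpips
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

loss_fn = lpips.LPIPS(net="vgg")


def to_lpips_tensor(img: Image.Image) -> torch.Tensor:
    # LPIPS expects float tensors of shape [N, 3, H, W] scaled to [-1, 1]
    return to_tensor(img).unsqueeze(0) * 2.0 - 1.0


baseline = to_lpips_tensor(Image.open("baseline.png").convert("RGB"))
quantized = to_lpips_tensor(Image.open("quantized.png").convert("RGB"))

with torch.no_grad():
    print(f"LPIPS: {loss_fn(baseline, quantized).item():.4f}")
```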
