 import os
 import random
 import time
+from functools import wraps
+from typing import Callable, TypeVar
 
 import diffusers
 import fire
     quantize_,
 )
 
+# Type variables for better type hinting
+T = TypeVar("T")
+
 # -----------------------------
 # Config
 # -----------------------------
@@ -71,12 +76,18 @@ def print_pipeline_architecture(pipe):
 
 
 def generate_image(
-    pipe, prompt: str, seed: int, device: str, num_inference_steps: int
+    pipe,
+    prompt: str,
+    seed: int,
+    device: str,
+    num_inference_steps: int,
+    batch_size: int = 1,
 ) -> Image.Image:
     generator = torch.Generator(device=device).manual_seed(seed)
 
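+    # Repeat the single prompt so one pipeline call renders batch_size images.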
+    prompts = [prompt] * batch_size
     image = pipe(
-        prompt=prompt,
+        prompt=prompts,
         num_inference_steps=num_inference_steps,  # can tweak for speed vs quality
         guidance_scale=7.5,
         generator=generator,
@@ -238,18 +249,58 @@ def pil_to_lpips_tensor(img: Image.Image, device: str):
     return t.to(device)
 
 
+from torch.utils._pytree import tree_map_only
+
+
+def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
256+ """
257+ Clone the CUDA output tensors of a function to avoid in-place operations.
258+
259+ This wrapper is useful when working with torch.compile to prevent errors
260+ related to in-place operations on tensors.
261+
262+ Args:
263+ f: The function whose CUDA tensor outputs should be cloned
264+
265+ Returns:
266+ A wrapped function that clones any CUDA tensor outputs
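+
+    Example (illustrative sketch; assumes a CUDA device is available):
+
+        >>> @clone_output_wrapper
+        ... def scale(t):
+        ...     return t * 2
+        >>> y = scale(torch.ones(3, device="cuda"))  # y is a fresh clone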
267+ """
268+
269+ @wraps (f )
270+ def wrapped (* args , ** kwargs ):
271+ outputs = f (* args , ** kwargs )
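+        # tree_map_only applies the lambda to every torch.Tensor leaf in
+        # `outputs`, whether it is a bare tensor or nested in tuples/dicts.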
+        return tree_map_only(
+            torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
+        )
+
+    return wrapped
+
+
+def apply_torch_compile(pipe, torch_compile_mode: str = "default"):
+    """Apply torch.compile to the transformer blocks in-place."""
+    for block in pipe.transformer.transformer_blocks:
+        block.forward = clone_output_wrapper(
+            torch.compile(block.forward, mode=torch_compile_mode)
+        )
+    for block in pipe.transformer.single_transformer_blocks:
+        block.forward = clone_output_wrapper(
+            torch.compile(block.forward, mode=torch_compile_mode)
+        )
+
+
 @torch.inference_mode()
 def run(
     mode: str = "accuracy",
     num_prompts: int = None,
-    num_inference_steps: int = 20,
+    num_inference_steps: int = 4,
     quant_config_str: str = "float8_rowwise",
     use_compile: bool = False,
     torch_compile_mode: str = "default",
     debug_prompt: str | None = None,
     print_model: bool = False,
     cache_baseline_images: bool = False,
     perf_n_iter: int = 10,
+    batch_size: int = 1,
     use_deterministic_algorithms: bool = False,
     num_gpus_used: int = None,
 ):
@@ -282,6 +333,7 @@ def run(
             instead of regenerated, if available. This is useful to make eval runs faster
             if we know the baseline is not changing.
         perf_n_iter: number of measurements to take when timing performance
+        batch_size: batch size for the performance_hp and performance_quant modes (default 1)
         use_deterministic_algorithms: if True, sets torch.use_deterministic_algorithms(True)
         num_gpus_used: For 'aggregate_accuracy' mode, the number of GPUs that were used
             to generate the data. Required for aggregate_accuracy mode.
@@ -314,6 +366,7 @@ def run(
     print(f"[Rank {local_rank}/{world_size}] use_compile: {use_compile}")
     print(f"[Rank {local_rank}/{world_size}] torch_compile_mode: {torch_compile_mode}")
     print(f"[Rank {local_rank}/{world_size}] {use_deterministic_algorithms=}")
+    print(f"[Rank {local_rank}/{world_size}] {batch_size=}")
     print(f"[Rank {local_rank}/{world_size}] {cache_baseline_images=}")
 
     assert mode in (
@@ -322,6 +375,11 @@ def run(
322375 "performance_quant" ,
323376 "aggregate_accuracy" ,
324377 )
378+ assert batch_size >= 1 , f"batch_size must be >= 1, got { batch_size } "
379+ if mode in ("accuracy" , "aggregate_accuracy" ):
380+ assert batch_size == 1 , (
381+ f"batch_size must be 1 for { mode } mode, got { batch_size } "
382+ )
325383
326384 # Handle aggregate_accuracy mode separately
327385 if mode == "aggregate_accuracy" :
@@ -438,14 +496,6 @@ def run(
 
     loss_fn = lpips.LPIPS(net="vgg").to(device)
 
-    # Store original for restoration later, since we will quantize it
-    # and compile the quantized version again
-    orig_transformer = pipe.transformer
-
-    if use_compile:
-        pipe.transformer = torch.compile(orig_transformer, mode=torch_compile_mode)
-        pipe.vae.decode = torch.compile(pipe.vae.decode, mode=torch_compile_mode)
-
     # -----------------------------
     # 2. Baseline images (for all prompts)
     # -----------------------------
@@ -473,6 +523,8 @@ def run(
     baseline_times = []
 
     if mode == "accuracy":
+        # note: never compile when generating baseline images (keep the
+        # reference in eager mode)
+
         for local_idx, prompt in enumerate(my_prompts):
             # Calculate global prompt index
             global_idx = local_rank + local_idx * world_size
@@ -500,32 +552,39 @@ def run(
             baseline_times.append(t1 - t0)
 
     elif mode == "performance_hp":
+        if use_compile:
+            apply_torch_compile(pipe, torch_compile_mode)
+
         # High precision performance mode - measure baseline without quantization
         if local_rank == 0:
             # warm up compile
             _ = generate_image(
-                pipe, prompts_to_use[0], RANDOM_SEED, device, num_inference_steps
+                pipe,
+                prompts_to_use[0],
+                RANDOM_SEED,
+                device,
+                num_inference_steps,
+                batch_size=batch_size,
             )
 
             for _ in range(perf_n_iter):
                 t0 = time.time()
                 _ = generate_image(
-                    pipe, prompts_to_use[0], RANDOM_SEED, device, num_inference_steps
+                    pipe,
+                    prompts_to_use[0],
+                    RANDOM_SEED,
+                    device,
+                    num_inference_steps,
+                    batch_size=batch_size,
                 )
                 t1 = time.time()
                 baseline_times.append(t1 - t0)
 
-    if use_compile and mode in ("accuracy", "performance_quant"):
-        print(
-            f"[Rank {local_rank}/{world_size}] Restoring original (uncompiled) transformer before quantization"
-        )
-        pipe.transformer = orig_transformer
-
     # Only quantize for accuracy and performance_quant modes
     if mode in ("accuracy", "performance_quant"):
         # Inspect Linear layers in main component
         component_linear_fqns_and_weight_shapes = []
-        for fqn, module in orig_transformer.named_modules():
+        for fqn, module in pipe.transformer.named_modules():
             if isinstance(module, torch.nn.Linear):
                 weight_shape = module.weight.shape
                 if print_model:
@@ -545,15 +604,21 @@ def run(
                 continue
             elif fqn == "proj_out":
                 continue
+            elif "norm.linear" in fqn:
+                # activations here have shape [batch_size, 3072], so they are
+                # too small to see speedups from activation quantization
+                continue
             elif weight_shape[0] < 1024 or weight_shape[1] < 1024:
                 continue
             fqn_to_config_dict[fqn] = config_obj
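+        # fqn_to_config_dict maps fully-qualified module names to the quant
+        # config, e.g. {"transformer_blocks.0.attn.to_q": config_obj}
+        # (the name shown is hypothetical)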
         fqn_to_config = FqnToConfig(fqn_to_config=fqn_to_config_dict)
 
         # Quantize the main component using this config
         quantize_(pipe.transformer, fqn_to_config, filter_fn=None)
+
         if use_compile:
-            pipe.transformer = torch.compile(pipe.transformer, mode=torch_compile_mode)
+            apply_torch_compile(pipe, torch_compile_mode)
+
         if print_model:
             print_pipeline_architecture(pipe)
 
@@ -615,13 +680,23 @@ def run(
         if local_rank == 0:
             # warm up compile
             _ = generate_image(
-                pipe, prompts_to_use[0], RANDOM_SEED, device, num_inference_steps
+                pipe,
+                prompts_to_use[0],
+                RANDOM_SEED,
+                device,
+                num_inference_steps,
+                batch_size=batch_size,
             )
 
             for _ in range(perf_n_iter):
                 t0 = time.time()
                 _ = generate_image(
-                    pipe, prompts_to_use[0], RANDOM_SEED, device, num_inference_steps
+                    pipe,
+                    prompts_to_use[0],
+                    RANDOM_SEED,
+                    device,
+                    num_inference_steps,
+                    batch_size=batch_size,
                 )
                 t1 = time.time()
                 times.append(t1 - t0)
@@ -691,11 +766,13 @@ def run(
             writer.writerow(["average_quantized_time", f"{avg_quant_time:.4f}"])
         elif mode == "performance_hp":
             writer.writerow(["perf_n_iter", perf_n_iter])
+            writer.writerow(["batch_size", batch_size])
             writer.writerow(["average_time", f"{avg_time:.4f}"])
             for idx, val in enumerate(baseline_times):
                 writer.writerow([f"time_{idx}", f"{val:.4f}"])
         elif mode == "performance_quant":
             writer.writerow(["perf_n_iter", perf_n_iter])
+            writer.writerow(["batch_size", batch_size])
             writer.writerow(["average_time", f"{avg_time:.4f}"])
             for idx, val in enumerate(times):
                 writer.writerow([f"time_{idx}", f"{val:.4f}"])