 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.testing.training.roofline_utils import (
+    get_inference_bf16_activation_mem_sympy,
     get_inference_float8_mem_sympy,
     get_inference_gemm_time_sympy,
 )
@@ -111,7 +112,7 @@ def get_gemm_times(
 
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
-    if recipe_name in ("mxfp4_cutlass", "nvfp4"):
+    if recipe_name in ("mxfp4_cutlass", "nvfp4", "nvfp4_static"):
         d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
         A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
             d1
@@ -150,7 +151,7 @@ def get_gemm_times(
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_a = to_blocked(scale_a)
         scale_b = to_blocked(scale_b)
-    elif recipe_name == "nvfp4":
+    elif recipe_name in ("nvfp4", "nvfp4_static"):
         scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_a = to_blocked(scale_a)
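Note: the scale shapes above follow from the block sizes of the two formats; nvfp4 and nvfp4_static share the nvfp4 layout (an fp8 e4m3 scale per 16-element block along K), while the mx recipes use an e8m0 scale per 32-element block. A tiny sketch with a made-up shape:

    # made-up example shape, just to show the scale shapes used above
    M, K = 1024, 4096
    nvfp4_scale_shape = (M, K // 16)   # torch.float8_e4m3fn, one scale per 16 elements
    mxfp4_scale_shape = (M, K // 32)   # torch.float8_e8m0fnu, one scale per 32 elements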
@@ -176,7 +177,7 @@ def do_matmul(A, B):
                 swizzle_b=SwizzleType.SWIZZLE_32_4_4,
                 output_dtype=d3,
             )
-        if recipe_name == "nvfp4":
+        if recipe_name in ("nvfp4", "nvfp4_static"):
             return torch._scaled_mm(
                 A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
             )
@@ -468,8 +469,8 @@ def _stack_layers_conv(
 
 
 def run(
-    outfile: str,
     recipe_name: str,
+    outfile: str | None = None,
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
     M: Optional[int] = None,
@@ -485,6 +486,7 @@ def run(
     kernel_size: Optional[int] = None,
     stride: int = 1,
     padding: int = 0,
+    skip_printing_detailed_metrics: bool = False,
 ):
     """
     Args:
@@ -500,6 +502,8 @@ def run(
     * `kernel_size`: kernel_size for conv3d / conv2d
     * `stride`: stride for conv ops (default: 1)
     * `padding`: padding for conv ops (default: 0)
+    * `skip_printing_detailed_metrics`: if True, prints e2e roofline
+      and observed speedups only, skipping all other intermediate metrics
     """
     _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"]
     assert op_name in _SUPPORTED_OPS, (
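Note: with the reordered signature, `recipe_name` is the only required argument and a CSV is written only when `outfile` is passed. A minimal sketch of a call using the new defaults and the new flag (argument values are illustrative):

    # illustrative call; outfile omitted, so no CSV is written
    run(
        recipe_name="nvfp4_static",
        do_benchmarks=True,
        skip_printing_detailed_metrics=True,  # print only shape + e2e speedup columns
    )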
@@ -561,6 +565,11 @@ def run(
         # TODO(future): also enable fusion modeling here
     )
     bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+    if enable_fusion_modeling and op_name == "linear":
+        bf16_ovhd_time_sympy = get_inference_bf16_activation_mem_sympy(M, K, N)
+    else:
+        # multiply by M to ensure we get a sympy symbol
+        bf16_ovhd_time_sympy = M * 0
 
     if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
         fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
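Note on the `M * 0` fallback: `M` is a sympy symbol here, so `M * 0` is a sympy zero rather than a plain Python int, and the later `.subs(...)` / `float(...)` calls keep working whether or not fusion modeling is enabled. A standalone sketch:

    import sympy

    M = sympy.Symbol("M")
    ovhd = M * 0                       # sympy Zero, not the Python int 0
    print(float(ovhd.subs(M, 1024)))   # 0.0; a plain int 0 has no .subs()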
@@ -572,6 +581,7 @@ def run(
             M, K, N, torch.float8_e4m3fn, gemm_recipe_name
         )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
+    print("bf16_ovhd_time_sympy", bf16_ovhd_time_sympy)
     print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
     print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
     print()
@@ -587,6 +597,8 @@ def run(
         # roofline - gemm time (fwd + bwd, 3 gemms; for conv: using equivalent implicit gemm dims)
         "r_bf16_gemm_s",
         "r_fp8_gemm_s",
+        # roofline - bf16 overhead time (read-write prev activation, only if fusion modeling is on)
+        "r_bf16_ovhd_s",
         # roofline - fp8 overhead time (by counting reads/writes in the ideal case)
         "r_fp8_ovhd_s",
         # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid)
@@ -628,11 +640,16 @@ def run(
         )
 
         # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+        r_bf16_ovhd_time_s = float(
+            bf16_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
         r_fp8_ovhd_time_s = float(
             fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
         )
         r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
-        r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+        r_speedup = (r_bf16_gemm_time_s + r_bf16_ovhd_time_s) / (
+            r_fp8_gemm_time_s + r_fp8_ovhd_time_s
+        )
 
         # if enabled, also measured observed gemm time
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
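For intuition on the speedup change: the numerator now also counts the bf16 activation read/write overhead, so when fusion modeling is on, the modeled speedup of the fused low-precision path goes up accordingly. Toy numbers (made up) to illustrate:

    # made-up times, in seconds
    r_bf16_gemm_s, r_bf16_ovhd_s = 1.00e-3, 0.20e-3
    r_fp8_gemm_s, r_fp8_ovhd_s = 0.55e-3, 0.25e-3
    old_speedup = r_bf16_gemm_s / (r_fp8_gemm_s + r_fp8_ovhd_s)                    # 1.25x
    new_speedup = (r_bf16_gemm_s + r_bf16_ovhd_s) / (r_fp8_gemm_s + r_fp8_ovhd_s)  # 1.50x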
@@ -679,11 +696,16 @@ def run(
         r_fp8_gemm_time_s = float(
             fp8_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
         )
+        r_bf16_ovhd_time_s = float(
+            bf16_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+        )
         r_fp8_ovhd_time_s = float(
             fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
         )
         r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
-        r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+        r_speedup = (r_bf16_gemm_time_s + r_bf16_ovhd_time_s) / (
+            r_fp8_gemm_time_s + r_fp8_ovhd_time_s
+        )
 
         # measure actual conv kernel times (without quant overhead)
         b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
@@ -773,12 +795,29 @@ def run(
                 )
             elif recipe_name == "nvfp4":
                 config = NVFP4DynamicActivationNVFP4WeightConfig(
-                    use_dynamic_per_tensor_scale=False,
+                    use_dynamic_per_tensor_scale=True,
+                )
+            elif recipe_name == "nvfp4_static":
+                config_calib = NVFP4DynamicActivationNVFP4WeightConfig(
+                    step="prepare",
+                )
+                config = NVFP4DynamicActivationNVFP4WeightConfig(
+                    step="convert",
                 )
             else:
                 assert False, "unsupported"
 
             m_fp8_dyn = copy.deepcopy(m_orig)
+
+            if recipe_name == "nvfp4_static":
+                # calibrate with sample data
+                # this benchmark is performance-only, so a toy datum is fine
+                quantize_(m_fp8_dyn, config_calib)
+                toy_datum = torch.randn(
+                    M_val, K_val, dtype=torch.bfloat16, device="cuda"
+                )
+                m_fp8_dyn(toy_datum)
+
             if op_name == "linear":
                 quantize_(m_fp8_dyn, config)
             elif op_name == "conv2d":
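For context, the nvfp4_static path is a prepare / calibrate / convert flow. A minimal sketch mirroring the benchmark code above, assuming the same imports as this file (`torch`, `quantize_`, `NVFP4DynamicActivationNVFP4WeightConfig`); the model and shapes are illustrative:

    m = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).bfloat16().cuda()
    # 1) prepare: apply the calibration config (config_calib above)
    quantize_(m, NVFP4DynamicActivationNVFP4WeightConfig(step="prepare"))
    # 2) calibrate: run representative (here: toy) data through the model
    m(torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda"))
    # 3) convert: finalize static-activation NVFP4 quantization
    quantize_(m, NVFP4DynamicActivationNVFP4WeightConfig(step="convert"))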
@@ -813,7 +852,8 @@ def run(
                 # roofline - gemm
                 r_bf16_gemm_time_s,
                 r_fp8_gemm_time_s,
-                # roofline - fp8 overhead
+                # roofline - overhead
+                r_bf16_ovhd_time_s,
                 r_fp8_ovhd_time_s,
                 # roofline - gemm + overhead, and speedup
                 r_fp8_gemm_and_ovhd_s,
@@ -833,8 +873,20 @@ def run(
 
     pd.set_option("display.precision", 2)
     df = pd.DataFrame(results, columns=headers)
+
+    if outfile is not None:
+        df.to_csv(outfile)
+
+    if op_name == "linear":
+        # drop conv-only columns to simplify linear results
+        df = df.drop(columns=["D", "H", "W", "kernel_size"])
+
+    if skip_printing_detailed_metrics:
+        df = df[
+            ["fwd_M", "fwd_K", "fwd_N", "r_fp8_gemm_and_ovhd_spdp", "b_fp8_e2e_spdp"]
+        ]
+
     print(df)
-    df.to_csv(outfile)
     print("done")
 
 