hook up mslk's to_nvfp4 kernel to torchao's inference nvfp4 workflows #4031
Conversation
Summary:

Decent speedups across the board. Note that we slightly modify the PyTorch reference code (the global scale in MSLK is the reciprocal of its meaning in torchao) to keep bitwise equivalency between the torchao reference and MSLK's kernel.

Test Plan:

performance: wins across the board

simple microbenchmark sweep, before

```
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
```

and after

```
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.39
1   2048   2048   2048                      2.36            0.68
2   4096   4096   4096                      2.89            1.27
3   8192   8192   8192                      3.32            1.93
4  16384  16384  16384                      3.62            2.73
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.48
1   2048   2048   2048                      2.74            0.88
2   4096   4096   4096                      3.42            1.62
3   8192   8192   8192                      3.67            2.27
4  16384  16384  16384                      3.82            2.98
```

TODO: verify e2e model accuracy

ghstack-source-id: c0ed8c1
ghstack-comment-id: 4027115619
Pull-Request: #4031
still seeing e2e regressions on diffusion models, need to chase it down
@claude review this pr, determine if there are any flaws in benchmarking methodology
Claude finished @drisspg's task in 3m 1s.

**Review: Benchmarking Methodology Analysis**
**Benchmarking Methodology Assessment**

The PR reports results from `benchmarks/float8/float8_inference_roofline.py`. The benchmark itself is sound.

**No issues with the benchmark comparison fairness**

The before/after comparison is apples-to-apples: same benchmark script, same recipe (`nvfp4_static`).

**One minor observation on presented results**

The PR description shows two runs before and two runs after. The variance between the two "before" runs and between the two "after" runs is notable.

**Code Review Findings**

**Scale convention change is mathematically consistent**

I traced through the old vs. new scale convention carefully: all five sites are updated consistently, and the reciprocal relationship is maintained throughout. The net mathematical effect on dequantized values and matmul results is identical.

The MSLK wrapper docstring has a minor inconsistency.
update: the e2e slowdowns I saw were in models where the GPU was not saturated. Things look good if we turn on CUDA graphs. Going to solidify nvfp4 e2e benchmarking, then rebase this PR on top so we clearly see the e2e wins.
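For reference, a minimal sketch of what turning on CUDA graphs looks like here; the toy model is hypothetical, while `torch.compile(..., mode="reduce-overhead")` is the standard PyTorch entry point:

```python
import torch
import torch.nn as nn

# hypothetical toy model standing in for an unsaturated-GPU workload
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))
model = model.cuda().eval()

# "reduce-overhead" uses CUDA graphs to amortize kernel launch overhead,
# which is where the e2e slowdowns showed up when the GPU was not saturated
compiled = torch.compile(model, mode="reduce-overhead")

x = torch.randn(1, 4096, device="cuda")
with torch.no_grad():
    for _ in range(3):  # warmup iterations let the CUDA graph get captured
        _ = compiled(x)
    out = compiled(x)   # later calls replay the captured graph
```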
Summary:

Combining various improvements to the flux-1.schnell benchmark in a single PR:

* set `num_inference_steps` to 4 to match the default for this model
* turn on CUDA graphs via `reduce-overhead` to improve nvfp4 performance for batch size 1; this required patching the transformer blocks' forward using code I got from jbschlosser
* add a larger batch size for additional performance metrics at larger shapes
* remove torch.compile from the vae, as compile times are too long and the goal of this benchmark is to compare recipes and track performance improvements, not achieve the best possible latency. Perf results for a single configuration now take ~30 s to ~1 min, and an overall run on one B200 takes ~20 mins.
* fix the quant recipe to exclude one more layer with small shapes
* add the results to the documentation

I want to land this so we can see the e2e metrics lift from eventually landing #4031.

Test Plan:

```
./benchmarks/quantization/eval_accuracy_and_perf_of_flux.sh
```

ghstack-source-id: 8fb74eb
ghstack-comment-id: 4055040194
Pull-Request: #4072
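As a sketch of the timing methodology implied above (warmup so CUDA graphs get captured, then synchronized timing); `run_pipeline` is a hypothetical stand-in for one flux-1.schnell inference call, not the benchmark script's actual harness:

```python
import time
import torch

def benchmark(run_pipeline, warmup: int = 2, iters: int = 5) -> float:
    """Return average seconds per call; warmup also captures CUDA graphs."""
    for _ in range(warmup):
        run_pipeline()
    torch.cuda.synchronize()  # drain async GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        run_pipeline()
    torch.cuda.synchronize()  # and again before stopping it
    return (time.perf_counter() - start) / iters
```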
Are there benchmark runs that report memory throughput?
We don't have it in torchao right now, but I validated with https://github.com/sayakpaul/diffusers-blackwell-quants that peak GPU memory does not change with this PR.
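For anyone wanting to reproduce the memory check, a minimal sketch using stock PyTorch peak-memory counters; `run_inference` is a hypothetical stand-in for one quantized model call:

```python
import torch

def run_inference():
    # hypothetical stand-in for one forward pass of the quantized model
    x = torch.randn(1, 4096, device="cuda")
    return x @ x.T

torch.cuda.reset_peak_memory_stats()  # clear the high-water mark
run_inference()
torch.cuda.synchronize()
peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"peak GPU memory: {peak_gib:.2f} GiB")
```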
Summary:

Same as `run_all_benchmarks.sh`, but with `reduce-overhead` and for a local machine.

Test Plan:

```
// full run
// note: this run used a local torchao build with pytorch/ao#4031
time HF_HUB_DISABLE_PROGRESS_BARS=1 ./run_all_benchmarks_local.sh 2>&1 | tee ~/tmp/20260313_diffusers_full_sweep_logs_mslk.tx
// output: https://gist.github.com/vkuzo/40ee0268a590e270900a2538055b13f0
```
Summary:

Changes the NVFP4 inference triton kernel to use `mslk` instead of the one checked in to torchao. Note that `mslk` (an optional dependency) is now required for the default usage of `NVFP4DynamicActivationNVFP4WeightConfig`. We can delete torchao's nvfp4 kernel in a future PR, to keep this one small.
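For readers unfamiliar with the workflow, a hedged sketch of applying this config; `quantize_` is torchao's standard entry point, but the import path and constructor arguments shown are assumptions, not confirmed by this PR:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_
# import path assumed; the config name itself comes from this PR
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

model = nn.Sequential(nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)

# with this PR, the default path below dispatches to mslk's to_nvfp4 kernel,
# so mslk must be installed
quantize_(model, NVFP4DynamicActivationNVFP4WeightConfig())
```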
Currently `torchao` defines the nvfp4 global scale as `amax / (448 * 6)`, and mslk (and flashinfer) define it as `(448 * 6) / amax`. In a future PR we should consider moving torchao to the mslk definition; saving that for later, as it will be a BC-breaking change that needs careful testing of existing checkpoints.
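An illustrative sketch of the two conventions (not the torchao or mslk implementation). 448 is the float8_e4m3 max and 6 is the float4_e2m1 max; the two scales are reciprocals in intent, but dividing by one is not guaranteed to be bitwise-identical to multiplying by the other in floating point, which is why the reference code had to be adjusted:

```python
import torch

x = torch.randn(128, 128)
amax = x.abs().max()

scale_torchao = amax / (448.0 * 6.0)  # torchao convention: divide x by this
scale_mslk = (448.0 * 6.0) / amax     # mslk/flashinfer convention: multiply x by this

# mathematically identical, numerically equal to within float rounding
torch.testing.assert_close(x / scale_torchao, x * scale_mslk)
```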
Test Plan:

* microbenchmark sweep, before and after
* e2e model accuracy