hook up mslk's to_nvfp4 kernel to torchao's inference nvfp4 workflows #4031

Merged: vkuzo merged 8 commits into main from gh/vkuzo/233/head on Mar 13, 2026
Conversation

vkuzo (Contributor) commented Mar 9, 2026

Summary:

Changes the NVFP4 inference triton kernel to use mslk instead of the one checked in to torchao. Note that mslk (previously an optional dependency) is now required for the default usage of NVFP4DynamicActivationNVFP4WeightConfig.
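For context, a minimal usage sketch of the config named above (the import path is an assumption; only the config name and the mslk requirement come from this PR):

```python
import torch
from torchao.quantization import quantize_
# assumed import path; only the config name is taken from this PR
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

# requires the mslk package to be installed (see note above)
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().bfloat16()
quantize_(model, NVFP4DynamicActivationNVFP4WeightConfig())
```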

  • microbenchmark speedups - large
    • for example, for the dynamic recipe at M=K=N=4096, the e2e speedup increases from 0.90x to 1.27x. See the full tables below for more.
  • e2e speedups on flux-1.schnell - medium
    • 1.26x -> 1.30x at bsz 1
    • 1.32x -> 1.37x at bsz 4
  • e2e accuracy - no change

We can delete torchao's nvfp4 kernel in a future PR, to keep this one small.

Currently torchao defines the nvfp4 global scale as amax / (448 * 6), while mslk (and flashinfer) define it as (448 * 6) / amax. In a future PR we should consider moving torchao to the mslk definition; that will be a BC-breaking change which will need careful testing of existing checkpoints.
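Spelled out (an illustration, not torchao or mslk source):

```python
F8E4M3_MAX = 448.0   # max finite float8 e4m3 value
F4_E2M1_MAX = 6.0    # max finite float4 e2m1 value

def torchao_global_scale(amax: float) -> float:
    return amax / (F8E4M3_MAX * F4_E2M1_MAX)

def mslk_global_scale(amax: float) -> float:
    return (F8E4M3_MAX * F4_E2M1_MAX) / amax

# the two conventions are reciprocals of each other:
# torchao_global_scale(a) * mslk_global_scale(a) == 1.0 (up to rounding)
```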

Test Plan:

microbenchmark sweep before

> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.28
1   2048   2048   2048                      2.36            0.52
2   4096   4096   4096                      2.89            0.90
3   8192   8192   8192                      3.32            1.41
4  16384  16384  16384                      3.62            2.14

> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.34
1   2048   2048   2048                      2.74            0.64
2   4096   4096   4096                      3.42            1.06
3   8192   8192   8192                      3.67            1.58
4  16384  16384  16384                      3.82            2.31

and after

> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.39
1   2048   2048   2048                      2.36            0.68
2   4096   4096   4096                      2.89            1.27
3   8192   8192   8192                      3.32            1.93
4  16384  16384  16384                      3.62            2.73

> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.48
1   2048   2048   2048                      2.74            0.88
2   4096   4096   4096                      3.42            1.62
3   8192   8192   8192                      3.67            2.27
4  16384  16384  16384                      3.82            2.98

e2e model accuracy

./benchmarks/quantization/eval_accuracy_and_perf_of_flux.sh

vkuzo added 3 commits March 9, 2026 20:40

pytorch-bot (Bot) commented Mar 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4031

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 1 Unrelated Failure

As of commit d20435e with merge base e03f787:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

vkuzo added a commit that referenced this pull request Mar 9, 2026
Summary:

Decent speedups across the board. Note that we slightly modify the
PyTorch reference code (the global scale in MSLK is the reciprocal of
its meaning in torchao) to keep bitwise equivalency between the torchao
reference and MSLK's kernel.

Test Plan:

performance: wins across the board

simple microbenchmark sweep before

```
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True

> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
```

and after

```
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.39
1   2048   2048   2048                      2.36            0.68
2   4096   4096   4096                      2.89            1.27
3   8192   8192   8192                      3.32            1.93
4  16384  16384  16384                      3.62            2.73

> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.48
1   2048   2048   2048                      2.74            0.88
2   4096   4096   4096                      3.42            1.62
3   8192   8192   8192                      3.67            2.27
4  16384  16384  16384                      3.82            2.98

```

TODO verify e2e model accuracy
ghstack-source-id: c0ed8c1
ghstack-comment-id: 4027115619
Pull-Request: #4031
meta-cla bot added the CLA Signed label Mar 9, 2026
Comment thread test/prototype/mx_formats/test_nvfp4_tensor.py Outdated
vkuzo added the module: inference label Mar 9, 2026
vkuzo added a commit that referenced this pull request Mar 10, 2026 (same commit message as the Mar 9 commit above; ghstack-source-id: 80ddcef)
vkuzo changed the base branch from gh/vkuzo/232/head to main March 10, 2026 08:27
vkuzo added a commit that referenced this pull request Mar 10, 2026 (same commit message as the Mar 9 commit above; ghstack-source-id: fc673b2)
vkuzo (Contributor, Author) commented Mar 10, 2026

still seeing e2e regressions on diffusion models, need to chase it down

drisspg (Contributor) commented Mar 10, 2026

@claude review this pr, determine if there are any flaws in benchmarking methodology

claude (Bot) commented Mar 10, 2026

Claude finished @drisspg's task in 3m 1s


Review: Benchmarking Methodology Analysis

  • Read changed files and understand the kernel integration
  • Analyze benchmarking code and methodology
  • Check for correctness of performance comparisons
  • Post review findings

Benchmarking Methodology Assessment

The PR reports results from benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static. I reviewed the benchmark script and the code changes in detail. Here are my findings:

The benchmark itself is sound

The float8_inference_roofline.py benchmark methodology is reasonable:

  • Warmup: 2 warmup iterations before profiling (float8_inference_roofline.py:74-76)
  • Timing: 5 profiled iterations with torch.cuda.synchronize(), averaged (float8_inference_roofline.py:78-96)
  • Fair comparison: Both BF16 and NVFP4 paths go through torch.compile on deepcopy'd models
  • Roofline vs e2e separation: The r_fp8_gemm_and_ovhd_spdp column is a roofline model estimate (unchanged by this PR, which is expected since the GEMM kernel itself hasn't changed), while b_fp8_e2e_spdp is the actual measured end-to-end speedup. The improvement in b_fp8_e2e_spdp reflects the faster MSLK quantization kernel reducing overhead in the e2e path.
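For reference, that warmup-then-time pattern looks roughly like the following (a sketch, not the benchmark's actual code; CUDA events are one common way to time the loop):

```python
import torch

def time_fn_ms(fn, *args, n_warmup=2, n_iter=5):
    # warmup iterations, not timed
    for _ in range(n_warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iter  # mean latency, milliseconds
```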

No issues with the benchmark comparison fairness

The before/after comparison is apples-to-apples: same benchmark script, same recipe (nvfp4_static), same model structure (LN → Linear → Sigmoid with enable_fusion_modeling=True). The only change is the quantization kernel used internally. The roofline column (r_fp8_gemm_and_ovhd_spdp) stays identical before/after, confirming the GEMM path is unchanged — only the quantization overhead improved.

One minor observation on presented results

The PR description shows two runs before and two runs after. The variance between the two "before" runs and between the two "after" runs is notable (e.g., b_fp8_e2e_spdp at 16384: 2.14 vs 2.31 before, 2.73 vs 2.98 after — roughly 8% variance). This is typical for GPU microbenchmarks, but the improvement is large enough (30-40% e2e speedup improvement) that the signal is clearly above the noise. No methodology issue here, just a note.

Code Review Findings

Scale convention change is mathematically consistent

I traced through the old vs new scale convention carefully:

| Site | Old (torchao) | New (MSLK) |
|---|---|---|
| per_tensor_amax_to_scale(amax) | amax / (F8E4M3_MAX * F4_E2M1_MAX) | (F8E4M3_MAX * F4_E2M1_MAX) / amax |
| nvfp4_quantize block scale | block_scale / per_tensor_scale | block_scale * per_tensor_scale |
| nvfp4_quantize data | data / (pts * fp8_scale) | data * (pts / fp8_scale) |
| get_hp_scales dequant | pts * scale_e4m3 | scale_e4m3 / pts |
| _addmm_nvfp4_dispatch | a.pts * b.pts | 1 / (a.pts * b.pts) |

All five sites are updated consistently — the reciprocal relationship is maintained throughout. The net mathematical effect on dequantized values and matmul results is identical.
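A quick numeric sanity check of that reciprocal relationship (a standalone sketch, not torchao code; names follow the table above):

```python
import torch

F8E4M3_MAX, F4_E2M1_MAX = 448.0, 6.0
amax = torch.tensor(100.0)

pts_old = amax / (F8E4M3_MAX * F4_E2M1_MAX)   # old torchao convention
pts_new = (F8E4M3_MAX * F4_E2M1_MAX) / amax   # new MSLK convention

scale_e4m3 = torch.tensor(2.0)  # stand-in for a decoded fp8 block scale
# get_hp_scales dequant, old vs new formulation from the table above
assert torch.allclose(pts_old * scale_e4m3, scale_e4m3 / pts_new)
# _addmm_nvfp4_dispatch, old vs new
assert torch.allclose(pts_old * pts_old, 1.0 / (pts_new * pts_new))
```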

The MSLK wrapper docstring has a minor inconsistency

kernels.py:1350 says the global_scale parameter is in "MSLK convention (1.0 / per_tensor_scale)". But after this PR, per_tensor_scale itself IS in MSLK convention (i.e., (F8E4M3_MAX * F4_E2M1_MAX) / amax). So the wrapper passes per_tensor_scale directly to MSLK at nvfp4_tensor.py:161, not its reciprocal. The docstring should say something like "Global scale in MSLK convention" without "1.0 / per_tensor_scale" since the per_tensor_scale variable is already in MSLK convention after this PR.


cast_bench.py hardcodes per_tensor_scale=1.0

In cast_bench.py:86, the triton swizzle benchmark now hardcodes per_tensor_scale = torch.tensor(1.0). This is necessary because the MSLK kernel requires per_tensor_scale, but 1.0 in MSLK convention means amax = F8E4M3_MAX * F4_E2M1_MAX = 2688, which is a specific (not no-op) calibration point. This is fine for a cast/quantization microbenchmark since it only measures kernel throughput, but worth noting it's not a "neutral" scale.

nvfp4_quantize docstring is stale

nvfp4_tensor.py:657 still refers to the parameter as per_tensor_amax in the Args section, but the actual parameter name is per_tensor_scale and its semantics changed with this PR.

The author's self-identified TODO is the biggest concern

The comment at the bottom of the PR description ("TODO verify e2e model accuracy") and the author's follow-up comment ("still seeing e2e regressions on diffusion models, need to chase it down") are worth flagging. The benchmark only measures performance, not numerical accuracy. The x * (global_scale / fp8_scale) formulation (multiply-then-divide) can differ from x / (pts * fp8_scale) (single division) due to floating-point non-associativity, which could cause subtle accuracy regressions in sensitive models. This isn't a benchmarking methodology flaw per se, but it means the benchmark numbers alone are insufficient to validate this change.
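A tiny illustration of that non-associativity (arbitrary numbers, not values from this PR):

```python
import torch

x = torch.tensor(0.3)
pts_old = torch.tensor(1.7)        # old convention: amax / (448 * 6)
pts_new = 1.0 / pts_old            # new convention: (448 * 6) / amax
fp8_scale = torch.tensor(0.9)

old = x / (pts_old * fp8_scale)    # single division
new = x * (pts_new / fp8_scale)    # multiply by a precomputed ratio
# the two can differ in the last ulp due to different rounding order
print(old.item(), new.item(), (old - new).abs().item())
```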

Summary

No flaws in benchmarking methodology. The benchmark is well-structured with proper warmup, fair comparison, and the performance signal is well above noise. The code changes are mathematically consistent. The main risk is the numerical accuracy concern the author has already identified — the benchmarking only validates performance, not correctness.


vkuzo (Contributor, Author) commented Mar 12, 2026

update: the e2e slowdowns I saw were in models where the gpu was not saturated. Things look good if we turn on cuda graphs.

Going to solidify nvfp4 e2e benchmarking, then rebase this PR on top so we clearly see the e2e wins.
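For reference, enabling CUDA graphs via torch.compile looks like this (a generic sketch, not the flux benchmark code):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU()).cuda()
# "reduce-overhead" mode uses CUDA graphs, removing per-kernel launch
# overhead; this matters most when the GPU is not saturated (e.g. bsz 1)
model = torch.compile(model, mode="reduce-overhead")
x = torch.randn(16, 1024, device="cuda")
for _ in range(3):  # a few iterations to warm up and record the graph
    y = model(x)
```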

vkuzo added a commit that referenced this pull request Mar 13, 2026
Summary:

Combining various improvements to the flux-1.schnell benchmark in a
single PR:

* set `num_inference_steps` to 4 to match the default for this model
* turn on cuda graphs via `reduce-overhead` to improve nvfp4 performance for batch size 1,
  this required patching the transformer blocks forward using code I got
  from jbschlosser
* add a larger batch size for additional performance metrics at larger
  shapes
* remove torch.compile from vae, as compile times are too long and the
  goal of this benchmark is to compare recipes and track performance
  improvements, not achieve best possible latency. Now perf results for
  a single configuration take ~30s - ~1 min, and overall run on one B200
  is ~20 mins.
* fix quant recipe to exclude one more layer with small shapes
* add the results to documentation

I want to land this so we can see the e2e metrics lift from eventually
landing #4031

Test Plan:

```
./benchmarks/quantization/eval_accuracy_and_perf_of_flux.sh
```
ghstack-source-id: 8fb74eb
ghstack-comment-id: 4055040194
Pull-Request: #4072
vkuzo added a commit that referenced this pull request Mar 13, 2026 (same commit message as the Mar 9 commit above; ghstack-source-id: 4906a07)
vkuzo added a commit that referenced this pull request Mar 13, 2026 (same commit message as the Mar 9 commit above; ghstack-source-id: 3a7ee6c)
drisspg (Contributor) commented Mar 13, 2026

Are there benchmark runs that report memory throughput?

vkuzo (Contributor, Author) commented Mar 13, 2026

> Are there benchmark runs that report memory throughput?

we don't have it in torchao right now, but I validated with https://github.com/sayakpaul/diffusers-blackwell-quants that peak GPU memory does not change with this PR
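For reference, a minimal way to check peak GPU memory in PyTorch (not what the linked repo does):

```python
import torch

torch.cuda.reset_peak_memory_stats()
# ... run the inference workload under test here ...
peak_gib = torch.cuda.max_memory_allocated() / 2**30
print(f"peak GPU memory: {peak_gib:.2f} GiB")
```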

Comment thread torchao/prototype/mx_formats/kernels.py Outdated
Comment thread torchao/prototype/mx_formats/kernels.py Outdated
vkuzo added a commit that referenced this pull request Mar 13, 2026 (same commit message as the Mar 9 commit above; ghstack-source-id: c9e2e5b)
vkuzo added a commit that referenced this pull request Mar 13, 2026
Summary:

we moved to `mslk` in #4031,
deleting the old kernel

Separate PR to keep PRs small

Test Plan:

```
pytest test/prototype/mx_formats -s
```
ghstack-source-id: ed6aaca
ghstack-comment-id: 4057398637
Pull-Request: #4078
vkuzo added a commit to sayakpaul/diffusers-blackwell-quants that referenced this pull request Mar 13, 2026
Summary:

same as `run_all_benchmarks.sh`, but with `reduce-overhead` and for a
local machine

Test Plan:

```
// full run
// note: this run used a local torchao build with pytorch/ao#4031
time HF_HUB_DISABLE_PROGRESS_BARS=1 ./run_all_benchmarks_local.sh 2>&1 | tee ~/tmp/20260313_diffusers_full_sweep_logs_mslk.tx

// output: https://gist.github.com/vkuzo/40ee0268a590e270900a2538055b13f0
```
vkuzo merged commit 5c0c814 into main Mar 13, 2026
54 of 60 checks passed
vkuzo added a commit that referenced this pull request Mar 13, 2026 (same commit message as the #4078 commit above; ghstack-source-id: c6ee2bb)
sayakpaul pushed a commit to sayakpaul/diffusers-blackwell-quants that referenced this pull request Mar 14, 2026 (same commit message as above)

Labels: CLA Signed, module: inference

2 participants