Commit a302c10
authored
[nvfp4] Make per_tensor_scale optional for triton kernel path (#4188)
* [nvfp4] Make per_tensor_scale optional for triton kernel path
Summary:
MSLK now supports optional global scale in its triton quantize kernel
(MSLK#233, commit c01f06c). This change relaxes the corresponding
constraint in torchao so the triton kernel path can be used without
a per_tensor_scale (single-level block-wise scaling only).
Changes:
- Remove `assert per_tensor_scale is not None` from `to_nvfp4` triton branch
- Update `mslk_quantize_nvfp4` and its custom op to accept `Optional[torch.Tensor]`,
passing `None` through to MSLK (which treats it as global_scale=1.0)
- Relax `_addmm_nvfp4_dispatch` to allow mixed per_tensor_scale states between
operands (treat None as 1.0) instead of asserting both-or-neither
Test Plan:
Requires SM100+ GPU with MSLK nightly installed.
```
python -m pytest test/prototype/mx_formats/test_nvfp4_tensor.py::test_triton_nvfp4_quantize_equivalence -v
python -m pytest test/prototype/mx_formats/test_nvfp4_tensor.py::test_nvfp4_matmul_optional_per_tensor_scale -v
python -m pytest test/prototype/mx_formats/test_inference_workflow.py::test_inference_workflow_nvfp4 -k "test_inference_workflow_nvfp4" -v
```
Performance:
with global scale:
```
python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True
Parameter Value
---------------------- ------------------------
GPU NVIDIA GB200
torch version 2.12.0.dev20260316+cu128
torchao version 0.17.0+git95281b63b
recipe_name nvfp4
do_benchmarks True
shape_gen_name pow2
enable_fusion_modeling True
op_name linear
MKN None None None
DHW None None None
kernel_size
stride 1
padding 0
bf16_gemm_time_sympy Max(2.0e-6, 1.13960113960114e-15*K*M*N, 2.71739130434783e-13*K*M + 2.71739130434783e-13*K*N + 2.71739130434783e-13*M*N)
bf16_ovhd_time_sympy Max(2.0e-6, 5.43478260869565e-13*K*M)
fp8_gemm_time_sympy Max(2.0e-6, 2.84900284900285e-16*K*M*N, 6.79347826086956e-14*K*M + 6.79347826086956e-14*K*N + 2.71739130434783e-13*M*N + 6.79347826086956e-14*floor(K*M/16 + K*N/16))
fp8_ovhd_time_sympy Max(2.0e-6, 6.11413043478261e-13*K*M + 1.35869565217391e-13*M*floor(K/16))
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
0 1024 1024 1024 1.00 0.45
1 2048 2048 2048 2.39 0.66
2 4096 4096 4096 2.92 1.29
3 8192 8192 8192 3.34 1.74
4 16384 16384 16384 3.63 2.84
```
without global scale:
```
python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_no_global_scale --enable_fusion_modeling True --skip_printing_detailed_metrics True
Parameter Value
---------------------- ------------------------
GPU NVIDIA GB200
torch version 2.12.0.dev20260316+cu128
torchao version 0.17.0+gitabb103d3b
recipe_name nvfp4_no_global_scale
do_benchmarks True
shape_gen_name pow2
enable_fusion_modeling True
op_name linear
MKN None None None
DHW None None None
kernel_size
stride 1
padding 0
bf16_gemm_time_sympy Max(2.0e-6, 1.13960113960114e-15*K*M*N, 2.71739130434783e-13*K*M + 2.71739130434783e-13*K*N + 2.71739130434783e-13*M*N)
bf16_ovhd_time_sympy Max(2.0e-6, 5.43478260869565e-13*K*M)
fp8_gemm_time_sympy Max(2.0e-6, 2.84900284900285e-16*K*M*N, 6.79347826086956e-14*K*M + 6.79347826086956e-14*K*N + 2.71739130434783e-13*M*N + 6.79347826086956e-14*floor(K*M/16 + K*N/16))
fp8_ovhd_time_sympy Max(2.0e-6, 3.39673913043478e-13*K*M + 1.35869565217391e-13*M*floor(K/16))
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
0 1024 1024 1024 1.00 0.73
1 2048 2048 2048 2.71 1.09
2 4096 4096 4096 3.44 2.22
3 8192 8192 8192 3.68 2.82
4 16384 16384 16384 3.83 3.65
```
[ghstack-poisoned]1 parent b1ddd15 commit a302c10
7 files changed
Lines changed: 107 additions & 24 deletions
File tree
- benchmarks/float8
- test/prototype/mx_formats
- torchao
- prototype/mx_formats
- testing/training
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
112 | 112 | | |
113 | 113 | | |
114 | 114 | | |
115 | | - | |
| 115 | + | |
| 116 | + | |
| 117 | + | |
| 118 | + | |
| 119 | + | |
| 120 | + | |
116 | 121 | | |
117 | 122 | | |
118 | 123 | | |
| |||
151 | 156 | | |
152 | 157 | | |
153 | 158 | | |
154 | | - | |
| 159 | + | |
155 | 160 | | |
156 | 161 | | |
157 | 162 | | |
| |||
177 | 182 | | |
178 | 183 | | |
179 | 184 | | |
180 | | - | |
| 185 | + | |
181 | 186 | | |
182 | 187 | | |
183 | 188 | | |
| |||
797 | 802 | | |
798 | 803 | | |
799 | 804 | | |
| 805 | + | |
| 806 | + | |
| 807 | + | |
| 808 | + | |
800 | 809 | | |
801 | 810 | | |
802 | 811 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
181 | 181 | | |
182 | 182 | | |
183 | 183 | | |
184 | | - | |
185 | | - | |
186 | 184 | | |
187 | 185 | | |
188 | 186 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
369 | 369 | | |
370 | 370 | | |
371 | 371 | | |
372 | | - | |
373 | | - | |
374 | | - | |
375 | 372 | | |
376 | 373 | | |
377 | 374 | | |
| |||
657 | 654 | | |
658 | 655 | | |
659 | 656 | | |
| 657 | + | |
| 658 | + | |
| 659 | + | |
| 660 | + | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
| 665 | + | |
| 666 | + | |
| 667 | + | |
| 668 | + | |
| 669 | + | |
| 670 | + | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
| 680 | + | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
| 684 | + | |
| 685 | + | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
| 689 | + | |
| 690 | + | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
| 697 | + | |
| 698 | + | |
| 699 | + | |
| 700 | + | |
| 701 | + | |
| 702 | + | |
| 703 | + | |
| 704 | + | |
| 705 | + | |
| 706 | + | |
| 707 | + | |
| 708 | + | |
| 709 | + | |
| 710 | + | |
| 711 | + | |
| 712 | + | |
| 713 | + | |
| 714 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
9 | | - | |
| 9 | + | |
10 | 10 | | |
11 | 11 | | |
12 | 12 | | |
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | 27 | | |
27 | 28 | | |
| |||
1175 | 1176 | | |
1176 | 1177 | | |
1177 | 1178 | | |
1178 | | - | |
| 1179 | + | |
1179 | 1180 | | |
1180 | 1181 | | |
1181 | 1182 | | |
1182 | 1183 | | |
1183 | 1184 | | |
1184 | | - | |
| 1185 | + | |
| 1186 | + | |
1185 | 1187 | | |
1186 | 1188 | | |
1187 | 1189 | | |
1188 | 1190 | | |
1189 | | - | |
| 1191 | + | |
| 1192 | + | |
| 1193 | + | |
1190 | 1194 | | |
1191 | 1195 | | |
1192 | 1196 | | |
1193 | 1197 | | |
1194 | 1198 | | |
1195 | | - | |
| 1199 | + | |
1196 | 1200 | | |
1197 | 1201 | | |
1198 | 1202 | | |
1199 | 1203 | | |
1200 | 1204 | | |
1201 | | - | |
| 1205 | + | |
| 1206 | + | |
1202 | 1207 | | |
1203 | 1208 | | |
1204 | 1209 | | |
| |||
1211 | 1216 | | |
1212 | 1217 | | |
1213 | 1218 | | |
| 1219 | + | |
| 1220 | + | |
| 1221 | + | |
| 1222 | + | |
| 1223 | + | |
| 1224 | + | |
1214 | 1225 | | |
1215 | 1226 | | |
1216 | 1227 | | |
1217 | 1228 | | |
1218 | 1229 | | |
1219 | | - | |
| 1230 | + | |
1220 | 1231 | | |
1221 | 1232 | | |
1222 | 1233 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
155 | 155 | | |
156 | 156 | | |
157 | 157 | | |
158 | | - | |
159 | | - | |
160 | | - | |
161 | 158 | | |
162 | 159 | | |
163 | 160 | | |
| |||
497 | 494 | | |
498 | 495 | | |
499 | 496 | | |
500 | | - | |
501 | | - | |
502 | | - | |
| 497 | + | |
| 498 | + | |
| 499 | + | |
| 500 | + | |
| 501 | + | |
| 502 | + | |
| 503 | + | |
503 | 504 | | |
504 | | - | |
505 | 505 | | |
506 | 506 | | |
507 | 507 | | |
| |||
720 | 720 | | |
721 | 721 | | |
722 | 722 | | |
723 | | - | |
| 723 | + | |
| 724 | + | |
| 725 | + | |
724 | 726 | | |
725 | 727 | | |
726 | 728 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
519 | 519 | | |
520 | 520 | | |
521 | 521 | | |
522 | | - | |
| 522 | + | |
523 | 523 | | |
524 | 524 | | |
525 | 525 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1277 | 1277 | | |
1278 | 1278 | | |
1279 | 1279 | | |
| 1280 | + | |
| 1281 | + | |
| 1282 | + | |
| 1283 | + | |
| 1284 | + | |
| 1285 | + | |
| 1286 | + | |
| 1287 | + | |
1280 | 1288 | | |
1281 | 1289 | | |
1282 | 1290 | | |
| |||
0 commit comments