Skip to content

Commit 44788cd

Browse files
committed
hook up mslk's to_nvfp4 kernel to torchao's inference nvfp4 workflows
Summary: Decent speedups across the board. Note that we slightly modify the
PyTorch reference code (the global scale in MSLK is the reciprocal of its
meaning in torchao) to keep bitwise equivalency between the torchao reference
and MSLK's kernel.

Test Plan:

performance: wins across the board — simple microbenchmark sweep

before:

```
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
```

and after:

```
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.39
1   2048   2048   2048                      2.36            0.68
2   4096   4096   4096                      2.89            1.27
3   8192   8192   8192                      3.32            1.93
4  16384  16384  16384                      3.62            2.73
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.48
1   2048   2048   2048                      2.74            0.88
2   4096   4096   4096                      3.42            1.62
3   8192   8192   8192                      3.67            2.27
4  16384  16384  16384                      3.82            2.98
```

TODO: verify e2e model accuracy

ghstack-source-id: 3a7ee6c
ghstack-comment-id: 4027115619
Pull-Request: #4031
1 parent e03f787 commit 44788cd

6 files changed

Lines changed: 111 additions & 26 deletions

File tree

benchmarks/mx_formats/cast_bench.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,12 @@ def to_nvfp4_reference(x_hp):
8383

8484

8585
def to_nvfp4_reference_triton_swizzle(x_hp):
    """Benchmark helper: quantize ``x_hp`` to NVFP4 via the triton kernel.

    Uses a unit per-tensor scale and the swizzled-scales layout, and returns
    the packed quantized data together with the blockwise scales.
    """
    # The triton path requires an explicit per-tensor scale; use 1.0 so the
    # benchmark measures only the blockwise quantization work.
    unit_scale = torch.tensor(1.0, dtype=torch.float32, device=x_hp.device)
    quantized = NVFP4Tensor.to_nvfp4(
        x_hp,
        per_tensor_scale=unit_scale,
        use_triton_kernel=True,
        is_swizzled_scales=True,
    )
    return quantized.qdata, quantized.scale
9094

docs/source/workflows/inference.md

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -151,31 +151,32 @@ torchao version 0.17.0+git3075bb624
151151
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True
152152
...
153153
GPU NVIDIA B200
154-
torch version 2.12.0.dev20260218+cu130
155-
torchao version 0.17.0+git3075bb624
154+
torch version 2.12.0.dev20260312+cu130
155+
torchao version 0.17.0+gitbd7717d20
156156
...
157157
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
158-
0 1024 1024 1024 1.00 0.28
159-
1 2048 2048 2048 2.36 0.52
160-
2 4096 4096 4096 2.89 0.90
161-
3 8192 8192 8192 3.32 1.41
162-
4 16384 16384 16384 3.62 2.14
158+
0 1024 1024 1024 1.00 0.46
159+
1 2048 2048 2048 2.36 0.76
160+
2 4096 4096 4096 2.89 1.37
161+
3 8192 8192 8192 3.32 1.97
162+
4 16384 16384 16384 3.62 2.77
163163

164164
#
165165
# nvfp4 with static global scaling (user API in progress)
166166
#
167167
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
168168
...
169169
GPU NVIDIA B200
170-
torch version 2.12.0.dev20260218+cu130
171-
torchao version 0.17.0+git3075bb624
170+
torch version 2.12.0.dev20260312+cu130
171+
torchao version 0.17.0+gitbd7717d20
172172
...
173173
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
174-
0 1024 1024 1024 1.00 0.34
175-
1 2048 2048 2048 2.74 0.64
176-
2 4096 4096 4096 3.42 1.06
177-
3 8192 8192 8192 3.67 1.58
178-
4 16384 16384 16384 3.82 2.31
174+
0 1024 1024 1024 1.00 0.55
175+
1 2048 2048 2048 2.74 0.95
176+
2 4096 4096 4096 3.42 1.69
177+
3 8192 8192 8192 3.67 2.29
178+
4 16384 16384 16384 3.82 2.98
179+
179180
```
180181

181182
## e2e flux-1.schnell benchmarks
@@ -198,7 +199,7 @@ high level, and measure performance improvements.
198199
| bfloat16 | 0 | 0.4178 | 1.00 | 1.4914 | 1.00 |
199200
| float8_rowwise | 0.1236| 0.3455 | 1.21 | 1.1986 | 1.24 |
200201
| mxfp8 | 0.1260 | 0.3673 | 1.14 | 1.2820 | 1.16 |
201-
| nvfp4 | 0.2694 | 0.3308 | 1.26 | 1.1334 | 1.32 |
202+
| nvfp4 | 0.2694 | 0.3203 | 1.30 | 1.0913 | 1.37 |
202203

203204
To reproduce, run:
204205

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def test_inference_workflow_nvfp4(
217217
y_ref = m(x)
218218

219219
if use_triton_kernel and quant_type == "dynamic":
220-
with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
220+
with cuda_kernel_profiler("triton_quantize_nvfp4_kernel") as result:
221221
y_mx = m_mx(x)
222222
assert result["found"], "Expected quantize_nvfp4 kernel to be found"
223223
else:

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ def test_nvfp4_swizzled_scales_get_scales_method():
369369
@torch.no_grad()
370370
def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
371371
"""Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
372+
if not use_per_tensor_scale:
373+
pytest.skip("MSLK triton kernel requires per_tensor_scale")
372374

373375
torch.manual_seed(42)
374376
x = torch.randn(M, N, dtype=dtype, device="cuda")
@@ -392,8 +394,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
392394
)
393395

394396
torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
395-
pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
396-
triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
397+
pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
398+
triton_unpacked = unpack_uint4(nvfp4_triton.qdata.view(torch.uint8))
397399
torch.testing.assert_close(
398400
pt_unpacked,
399401
triton_unpacked,
@@ -559,8 +561,14 @@ def test_scale_shape_matches_qdata(
559561
block_size = 16
560562

561563
x_hp = torch.randn(*shape, device="cuda")
564+
565+
per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp)))
566+
562567
x = NVFP4Tensor.to_nvfp4(
563-
x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
568+
x_hp,
569+
per_tensor_scale=per_tensor_scale,
570+
is_swizzled_scales=is_swizzled_scales,
571+
use_triton_kernel=use_triton_kernel,
564572
)
565573

566574
if len(shape) == 2:

torchao/prototype/mx_formats/kernels.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,3 +1387,72 @@ def mxfp8_quantize_cuda(
13871387
raise NotImplementedError(
13881388
"`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details."
13891389
)
1390+
1391+
1392+
# MSLK is an optional dependency; record whether its fused NVFP4 triton
# quantize kernel can be imported so callers can fail with a clear message.
try:
    from mslk.quantize.triton.fp4_quantize import (
        triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4,
    )
except ImportError:
    _mslk_available = False
else:
    _mslk_available = True
1400+
1401+
1402+
def mslk_quantize_nvfp4(
    x: torch.Tensor, per_tensor_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to NVFP4 using the MSLK triton kernel.

    Args:
        x: Input tensor to quantize.
        per_tensor_scale: Per-tensor scale (TorchAO convention:
            amax / (F8E4M3_MAX * F4_E2M1_MAX)).

    Returns:
        Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's
        convention.
    """
    # MSLK's global scale is the reciprocal of TorchAO's per-tensor scale;
    # convert here so the custom op speaks MSLK's convention.
    return _mslk_quantize_nvfp4_custom_op(x, per_tensor_scale.reciprocal())
1416+
1417+
1418+
@torch.library.custom_op("ao::mslk_quantize_nvfp4", mutates_args=())
def _mslk_quantize_nvfp4_custom_op(
    x: torch.Tensor, global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Inner custom op for MSLK NVFP4 quantization.

    Args:
        x: Input tensor to quantize.
        global_scale: Global scale in MSLK convention (1.0 / per_tensor_scale).

    Returns:
        Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's
        convention.

    Raises:
        ImportError: If the optional ``mslk`` package is not installed.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would turn a missing optional dependency into an
    # opaque NameError on the call below.
    if not _mslk_available:
        raise ImportError(
            "mslk is required for NVFP4 triton quantization. "
            "Install from https://github.com/pytorch/MSLK"
        )
    data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, global_scale)
    # MSLK returns (data, scales); TorchAO callers expect (scales, data), with
    # the packed fp4 payload reinterpreted as uint8.
    return blockwise_scales, data_lp.view(torch.uint8)
1437+
1438+
1439+
@_mslk_quantize_nvfp4_custom_op.register_fake
def _(x, global_scale):
    # Fake (meta) implementation: produce empty outputs with the shapes and
    # dtypes the real MSLK kernel returns, so fake-tensor tracing and
    # torch.compile can propagate metadata without running the kernel.
    # Mirror the reshape logic from the real MSLK kernel
    orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1]
    x_2d = x.reshape(-1, orig_N)
    M, N = x_2d.shape

    # One fp8 blockwise scale per 16 elements along the last dim.
    num_scales = N // 16
    # Scales are padded up to multiples of 128 rows x 4 columns.
    n_row_blocks = triton.cdiv(M, 128)
    n_col_blocks = triton.cdiv(num_scales, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
    # Two fp4 values are packed per byte, hence N // 2 uint8 columns.
    xq = x.new_empty(M, N // 2, dtype=torch.uint8)

    # Reshape back to match original leading dims
    # NOTE(review): `x.shape[:-2]` keeps all but the last TWO dims as leading
    # dims, while the reshape above collapses all but the last ONE. For >2-D
    # inputs the `-1` in these views must divide evenly — presumably this
    # matches the real kernel's output layout; confirm against MSLK's
    # triton_quantize_nvfp4 before relying on >2-D shapes here.
    scales = scales.view(*orig_leading_dims, -1, padded_cols)
    xq = xq.view(*orig_leading_dims, -1, N // 2)
    return scales, xq

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from torchao.prototype.mx_formats.kernels import (
1616
f4_unpacked_to_f32,
1717
f32_to_f4_unpacked,
18+
mslk_quantize_nvfp4,
1819
pack_uint4,
19-
triton_quantize_nvfp4,
2020
unpack_uint4,
2121
)
2222
from torchao.prototype.mx_formats.mx_tensor import (
@@ -155,7 +155,10 @@ def to_nvfp4(
155155
assert K % 16 == 0, (
156156
f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
157157
)
158-
blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
158+
assert per_tensor_scale is not None, (
159+
"Triton kernel requires per_tensor_scale"
160+
)
161+
blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale)
159162
else:
160163
blockwise_scales, data_lp = nvfp4_quantize(
161164
data_hp, block_size, per_tensor_scale
@@ -699,10 +702,10 @@ def nvfp4_quantize(
699702
scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
700703
).to(torch.float8_e4m3fn)
701704
scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
702-
# We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
703-
# To apply to data
704-
total_scale = per_tensor_scale * scaled_block_scales_fp32
705-
data_scaled = data_hp / total_scale.unsqueeze(-1)
705+
# Multiply by reciprocal of combined scale instead of dividing,
706+
# to match the MSLK triton kernel numerics: x * (global_scale / fp8_scale)
707+
reciprocal_scale = (1.0 / per_tensor_scale) / scaled_block_scales_fp32
708+
data_scaled = data_hp * reciprocal_scale.unsqueeze(-1)
706709
out_scales = scaled_block_scales_fp8
707710

708711
data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)

0 commit comments

Comments
 (0)