
Commit ded3400
hook up mslk's to_nvfp4 kernel to torchao's inference nvfp4 workflows
Summary:

Decent speedups across the board. Note that we slightly modify the PyTorch reference code (MSLK's global scale is the reciprocal of torchao's per-tensor scale) to keep bitwise equivalency between the torchao reference and MSLK's kernel.

Test Plan:

Performance: wins across the board in a simple microbenchmark sweep.

Before:

```
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.28
1   2048   2048   2048                      2.36            0.52
2   4096   4096   4096                      2.89            0.90
3   8192   8192   8192                      3.32            1.41
4  16384  16384  16384                      3.62            2.14
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.34
1   2048   2048   2048                      2.74            0.64
2   4096   4096   4096                      3.42            1.06
3   8192   8192   8192                      3.67            1.58
4  16384  16384  16384                      3.82            2.31
```

After:

```
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.39
1   2048   2048   2048                      2.36            0.68
2   4096   4096   4096                      2.89            1.27
3   8192   8192   8192                      3.32            1.93
4  16384  16384  16384                      3.62            2.73
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.48
1   2048   2048   2048                      2.74            0.88
2   4096   4096   4096                      3.42            1.62
3   8192   8192   8192                      3.67            2.27
4  16384  16384  16384                      3.82            2.98
```

TODO: verify e2e model accuracy.

ghstack-source-id: fc673b2
ghstack-comment-id: 4027115619
Pull-Request: #4031
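For context on the convention flip mentioned above: a minimal sketch (not library code; constants match torchao's `F8E4M3_MAX = 448.0` and `F4_E2M1_MAX = 6.0`, and the function names here are illustrative) showing that the old torchao per-tensor scale and the MSLK global scale are exact reciprocals, so switching conventions only turns divides into multiplies:

```python
import torch

F8E4M3_MAX = 448.0  # largest normal float8_e4m3fn value
F4_E2M1_MAX = 6.0   # largest fp4 e2m1 value

def old_torchao_scale(amax: torch.Tensor) -> torch.Tensor:
    # pre-commit torchao convention: data is *divided* by this scale
    return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)

def mslk_global_scale(amax: torch.Tensor) -> torch.Tensor:
    # MSLK convention (adopted by this commit): data is *multiplied* by this
    return (F8E4M3_MAX * F4_E2M1_MAX) / amax.to(torch.float32)

amax = torch.tensor(123.0)
# the two conventions are reciprocals of each other
assert torch.allclose(
    old_torchao_scale(amax) * mslk_global_scale(amax), torch.tensor(1.0)
)
```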

6 files changed: 94 additions & 26 deletions


benchmarks/mx_formats/cast_bench.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -83,8 +83,12 @@ def to_nvfp4_reference(x_hp):
 
 
 def to_nvfp4_reference_triton_swizzle(x_hp):
+    per_tensor_scale = torch.tensor(1.0, dtype=torch.float32, device=x_hp.device)
     nvfp4_tensor = NVFP4Tensor.to_nvfp4(
-        x_hp, use_triton_kernel=True, is_swizzled_scales=True
+        x_hp,
+        per_tensor_scale=per_tensor_scale,
+        use_triton_kernel=True,
+        is_swizzled_scales=True,
     )
     return nvfp4_tensor.qdata, nvfp4_tensor.scale
 
```

docs/source/workflows/inference.md

Lines changed: 10 additions & 10 deletions
````diff
@@ -155,11 +155,11 @@ torch version 2.12.0.dev20260218+cu130
 torchao version 0.17.0+git3075bb624
 ...
    fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
-0   1024   1024   1024                      1.00            0.28
-1   2048   2048   2048                      2.36            0.52
-2   4096   4096   4096                      2.89            0.90
-3   8192   8192   8192                      3.32            1.41
-4  16384  16384  16384                      3.62            2.14
+0   1024   1024   1024                      1.00            0.39
+1   2048   2048   2048                      2.36            0.68
+2   4096   4096   4096                      2.89            1.27
+3   8192   8192   8192                      3.32            1.93
+4  16384  16384  16384                      3.62            2.73
 
 #
 # nvfp4 with static global scaling (user API in progress)
@@ -171,11 +171,11 @@ torch version 2.12.0.dev20260218+cu130
 torchao version 0.17.0+git3075bb624
 ...
    fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
-0   1024   1024   1024                      1.00            0.34
-1   2048   2048   2048                      2.74            0.64
-2   4096   4096   4096                      3.42            1.06
-3   8192   8192   8192                      3.67            1.58
-4  16384  16384  16384                      3.82            2.31
+0   1024   1024   1024                      1.00            0.48
+1   2048   2048   2048                      2.74            0.88
+2   4096   4096   4096                      3.42            1.62
+3   8192   8192   8192                      3.67            2.27
+4  16384  16384  16384                      3.82            2.98
 ```
 
 ## Other Available Quantization Techniques
````

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -217,7 +217,7 @@ def test_inference_workflow_nvfp4(
     y_ref = m(x)
 
     if use_triton_kernel and quant_type == "dynamic":
-        with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
+        with cuda_kernel_profiler("triton_quantize_nvfp4_kernel") as result:
             y_mx = m_mx(x)
         assert result["found"], "Expected quantize_nvfp4 kernel to be found"
     else:
```

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -369,6 +369,8 @@ def test_nvfp4_swizzled_scales_get_scales_method():
 @torch.no_grad()
 def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
+    if not use_per_tensor_scale:
+        pytest.skip("MSLK triton kernel requires per_tensor_scale")
 
     torch.manual_seed(42)
     x = torch.randn(M, N, dtype=dtype, device="cuda")
@@ -559,8 +561,14 @@ def test_scale_shape_matches_qdata(
     block_size = 16
 
     x_hp = torch.randn(*shape, device="cuda")
+
+    per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp)))
+
     x = NVFP4Tensor.to_nvfp4(
-        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+        x_hp,
+        per_tensor_scale=per_tensor_scale,
+        is_swizzled_scales=is_swizzled_scales,
+        use_triton_kernel=use_triton_kernel,
     )
 
     if len(shape) == 2:
```
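For reference, a minimal sketch of the calling pattern these updated tests exercise (assumes a CUDA device supported by the triton kernel and torchao built with this change; the `qdata`/`scale` attribute access mirrors cast_bench.py above):

```python
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    per_tensor_amax_to_scale,
)

x_hp = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# Global scale from the tensor's amax, in MSLK convention: (448 * 6) / amax.
per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp)))

# per_tensor_scale is now required whenever use_triton_kernel=True.
x = NVFP4Tensor.to_nvfp4(
    x_hp,
    per_tensor_scale=per_tensor_scale,
    is_swizzled_scales=True,
    use_triton_kernel=True,
)
print(x.qdata.shape, x.scale.shape)
```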

torchao/prototype/mx_formats/kernels.py

Lines changed: 53 additions & 0 deletions
```diff
@@ -1327,3 +1327,56 @@ def mxfp8_quantize_cuda(
         raise NotImplementedError(
             "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details."
         )
+
+
+try:
+    from mslk.quantize.triton.fp4_quantize import (
+        triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4,
+    )
+
+    _mslk_available = True
+except ImportError:
+    _mslk_available = False
+
+
+@torch.library.custom_op("ao::mslk_quantize_nvfp4", mutates_args=())
+def mslk_quantize_nvfp4(
+    x: torch.Tensor, global_scale: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Quantize a tensor to NVFP4 using the MSLK triton kernel.
+
+    Args:
+        x: Input tensor to quantize.
+        global_scale: Global scale in MSLK convention (1.0 / per_tensor_scale).
+
+    Returns:
+        Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention.
+    """
+    assert _mslk_available, (
+        "mslk is required for NVFP4 triton quantization. "
+        "Install from https://github.com/pytorch/MSLK"
+    )
+    data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, global_scale)
+    return blockwise_scales, data_lp.view(torch.uint8)
+
+
+@mslk_quantize_nvfp4.register_fake
+def _(x, global_scale):
+    # Mirror the reshape logic from the real MSLK kernel
+    orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1]
+    x_2d = x.reshape(-1, orig_N)
+    M, N = x_2d.shape
+
+    num_scales = N // 16
+    n_row_blocks = triton.cdiv(M, 128)
+    n_col_blocks = triton.cdiv(num_scales, 4)
+    padded_rows = n_row_blocks * 128
+    padded_cols = n_col_blocks * 4
+
+    scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
+    xq = x.new_empty(M, N // 2, dtype=torch.uint8)
+
+    # Reshape back to match original leading dims
+    scales = scales.view(*orig_leading_dims, -1, padded_cols)
+    xq = xq.view(*orig_leading_dims, -1, N // 2)
+    return scales, xq
```
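The `register_fake` half is what lets `torch.compile` and FakeTensor tracing see correct output shapes without launching the real kernel. A small sketch of exercising just the fake path (an illustrative usage, not from the commit; it assumes `FakeTensorMode` can fabricate CUDA tensors on the current build):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torchao.prototype.mx_formats.kernels import mslk_quantize_nvfp4

# Under FakeTensorMode, the registered fake impl runs instead of the
# triton kernel, so we can inspect output shapes/dtypes with no GPU work.
with FakeTensorMode():
    x = torch.empty(256, 128, dtype=torch.bfloat16, device="cuda")
    global_scale = torch.empty((), dtype=torch.float32, device="cuda")
    scales, xq = mslk_quantize_nvfp4(x, global_scale)
    # 128 columns -> 8 scale groups of 16; swizzling pads rows to a
    # multiple of 128 and scale columns to a multiple of 4.
    print(scales.shape, scales.dtype)  # torch.Size([256, 8]) torch.float8_e4m3fn
    print(xq.shape, xq.dtype)          # torch.Size([256, 64]) torch.uint8 (2 fp4 per byte)
```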

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 16 additions & 13 deletions
```diff
@@ -15,8 +15,8 @@
 from torchao.prototype.mx_formats.kernels import (
     f4_unpacked_to_f32,
     f32_to_f4_unpacked,
+    mslk_quantize_nvfp4,
     pack_uint4,
-    triton_quantize_nvfp4,
     unpack_uint4,
 )
 from torchao.prototype.mx_formats.mx_tensor import (
@@ -155,7 +155,10 @@ def to_nvfp4(
             assert K % 16 == 0, (
                 f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
-            blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
+            assert per_tensor_scale is not None, (
+                "Triton kernel requires per_tensor_scale"
+            )
+            blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
@@ -245,7 +248,7 @@ def get_hp_scales(self) -> torch.Tensor:
         return (
             scale_e4m3.to(self.orig_dtype)
             if self.per_tensor_scale is None
-            else self.per_tensor_scale * scale_e4m3.to(self.orig_dtype)
+            else scale_e4m3.to(self.orig_dtype) / self.per_tensor_scale
         )
 
     @classmethod
@@ -465,7 +468,7 @@ def _addmm_nvfp4_dispatch(
     # Merge double quant scales into 1 scale for Scale_In^D
     if a.per_tensor_scale is not None:
         assert b.per_tensor_scale is not None
-        scale_result = a.per_tensor_scale * b.per_tensor_scale
+        scale_result = 1.0 / (a.per_tensor_scale * b.per_tensor_scale)
     else:
         assert b.per_tensor_scale is None and a.per_tensor_scale is None
         scale_result = None
@@ -625,17 +628,17 @@ def nvfp4_addmm(func, types, args, kwargs):
 def per_tensor_amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
     """Convert per-tensor amax to per-tensor scale for NVFP4 quantization.
 
-    Divides by both F8E4M3_MAX and F4_E2M1_MAX to ensure block scales can utilize
-    the full FP8 E4M3 range (up to 448) when block_max equals tensor_max.
-    Without F4_E2M1_MAX, the maximum scale would only reach FP8_MAX / FP4_MAX.
+    Returns the global scale in MSLK convention: (F8E4M3_MAX * F4_E2M1_MAX) / amax.
+    This ensures block scales can utilize the full FP8 E4M3 range (up to 448)
+    when block_max equals tensor_max.
 
     Args:
         amax: Per-tensor absolute maximum value from calibration
 
    Returns:
        torch.Tensor: Per-tensor scale for two-level NVFP4 scaling
    """
-    return amax.to(torch.float32) / (F8E4M3_MAX * F4_E2M1_MAX)
+    return (F8E4M3_MAX * F4_E2M1_MAX) / amax.to(torch.float32)
 
 
 def nvfp4_quantize(
@@ -694,15 +697,15 @@ def nvfp4_quantize(
         # we want the per_tensor_scale ~= amax of the block_scale_fp32
         block_scale_fp32 = block_scale.to(torch.float32)
         # Quantize the blockwise scales w/ the per_tensor_scale
-        scaled_block_scales = block_scale_fp32 / per_tensor_scale
+        scaled_block_scales = block_scale_fp32 * per_tensor_scale
         scaled_block_scales_fp8 = torch.clamp(
             scaled_block_scales, min=E4M3_EPS, max=F8E4M3_MAX
         ).to(torch.float8_e4m3fn)
         scaled_block_scales_fp32 = scaled_block_scales_fp8.to(torch.float32)
-        # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
-        # To apply to data
-        total_scale = per_tensor_scale * scaled_block_scales_fp32
-        data_scaled = data_hp / total_scale.unsqueeze(-1)
+        # Multiply by the reciprocal of the combined scale instead of dividing,
+        # to match the MSLK triton kernel numerics: x * (global_scale / fp8_scale)
+        reciprocal_scale = per_tensor_scale / scaled_block_scales_fp32
+        data_scaled = data_hp * reciprocal_scale.unsqueeze(-1)
         out_scales = scaled_block_scales_fp8
 
         data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
```
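A self-contained numeric check (stand-in values, not the library code) of why the reciprocal convention leaves the quantize/dequantize math unchanged:

```python
import torch

data_hp = torch.randn(4, 16)
s = torch.tensor(2.0)         # decoded fp8 block scale (scaled_block_scales_fp32)
pts_old = torch.tensor(0.25)  # old convention: amax / (448 * 6)
pts_new = 1.0 / pts_old       # new convention: (448 * 6) / amax

# quantize side: divide-by-total-scale == multiply-by-reciprocal
old = data_hp / (pts_old * s)
new = data_hp * (pts_new / s)
assert torch.allclose(old, new)

# dequantize side (get_hp_scales): pts_old * s == s / pts_new
assert torch.allclose(pts_old * s, s / pts_new)
```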
