Commit a302c10

[nvfp4] Make per_tensor_scale optional for triton kernel path (#4188)
* [nvfp4] Make per_tensor_scale optional for triton kernel path

Summary: MSLK now supports an optional global scale in its triton quantize kernel (MSLK#233, commit c01f06c). This change relaxes the corresponding constraint in torchao so the triton kernel path can be used without a per_tensor_scale (single-level block-wise scaling only).

Changes:
- Remove `assert per_tensor_scale is not None` from the `to_nvfp4` triton branch
- Update `mslk_quantize_nvfp4` and its custom op to accept `Optional[torch.Tensor]`, passing `None` through to MSLK (which treats it as global_scale=1.0)
- Relax `_addmm_nvfp4_dispatch` to allow mixed per_tensor_scale states between operands (treating None as 1.0) instead of asserting both-or-neither

Test Plan: requires an SM100+ GPU with the MSLK nightly installed.

```
python -m pytest test/prototype/mx_formats/test_nvfp4_tensor.py::test_triton_nvfp4_quantize_equivalence -v
python -m pytest test/prototype/mx_formats/test_nvfp4_tensor.py::test_nvfp4_matmul_optional_per_tensor_scale -v
python -m pytest test/prototype/mx_formats/test_inference_workflow.py::test_inference_workflow_nvfp4 -k "test_inference_workflow_nvfp4" -v
```

Performance, with global scale:

```
python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True

Parameter               Value
----------------------  ------------------------
GPU                     NVIDIA GB200
torch version           2.12.0.dev20260316+cu128
torchao version         0.17.0+git95281b63b
recipe_name             nvfp4
do_benchmarks           True
shape_gen_name          pow2
enable_fusion_modeling  True
op_name                 linear
MKN                     None None None
DHW                     None None None
kernel_size
stride                  1
padding                 0

bf16_gemm_time_sympy  Max(2.0e-6, 1.13960113960114e-15*K*M*N, 2.71739130434783e-13*K*M + 2.71739130434783e-13*K*N + 2.71739130434783e-13*M*N)
bf16_ovhd_time_sympy  Max(2.0e-6, 5.43478260869565e-13*K*M)
fp8_gemm_time_sympy   Max(2.0e-6, 2.84900284900285e-16*K*M*N, 6.79347826086956e-14*K*M + 6.79347826086956e-14*K*N + 2.71739130434783e-13*M*N + 6.79347826086956e-14*floor(K*M/16 + K*N/16))
fp8_ovhd_time_sympy   Max(2.0e-6, 6.11413043478261e-13*K*M + 1.35869565217391e-13*M*floor(K/16))

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.45
1   2048   2048   2048                      2.39            0.66
2   4096   4096   4096                      2.92            1.29
3   8192   8192   8192                      3.34            1.74
4  16384  16384  16384                      3.63            2.84
```

Performance, without global scale:

```
python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_no_global_scale --enable_fusion_modeling True --skip_printing_detailed_metrics True

Parameter               Value
----------------------  ------------------------
GPU                     NVIDIA GB200
torch version           2.12.0.dev20260316+cu128
torchao version         0.17.0+gitabb103d3b
recipe_name             nvfp4_no_global_scale
do_benchmarks           True
shape_gen_name          pow2
enable_fusion_modeling  True
op_name                 linear
MKN                     None None None
DHW                     None None None
kernel_size
stride                  1
padding                 0

bf16_gemm_time_sympy  Max(2.0e-6, 1.13960113960114e-15*K*M*N, 2.71739130434783e-13*K*M + 2.71739130434783e-13*K*N + 2.71739130434783e-13*M*N)
bf16_ovhd_time_sympy  Max(2.0e-6, 5.43478260869565e-13*K*M)
fp8_gemm_time_sympy   Max(2.0e-6, 2.84900284900285e-16*K*M*N, 6.79347826086956e-14*K*M + 6.79347826086956e-14*K*N + 2.71739130434783e-13*M*N + 6.79347826086956e-14*floor(K*M/16 + K*N/16))
fp8_ovhd_time_sympy   Max(2.0e-6, 3.39673913043478e-13*K*M + 1.35869565217391e-13*M*floor(K/16))

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.73
1   2048   2048   2048                      2.71            1.09
2   4096   4096   4096                      3.44            2.22
3   8192   8192   8192                      3.68            2.82
4  16384  16384  16384                      3.83            3.65
```

[ghstack-poisoned]
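For illustration, a minimal sketch of the call pattern this commit enables, mirroring the new test added below; it assumes an SM100+ GPU and an MSLK build with optional global-scale support (>= 1.1.0):

```python
# Minimal sketch (mirrors test_nvfp4_matmul_optional_per_tensor_scale below):
# quantize on the triton path without a per-tensor scale, so only the
# single-level 16-element block-wise scaling is applied.
import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

x_nvfp4 = NVFP4Tensor.to_nvfp4(
    x,
    per_tensor_scale=None,  # previously asserted non-None on the triton path
    is_swizzled_scales=True,
    use_triton_kernel=True,
)

x_dq = x_nvfp4.dequantize(torch.bfloat16)  # round-trip back to bf16
```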
1 parent b1ddd15 commit a302c10

7 files changed

Lines changed: 107 additions & 24 deletions

benchmarks/float8/float8_inference_roofline.py

Lines changed: 12 additions & 3 deletions
```diff
@@ -112,7 +112,12 @@ def get_gemm_times(

     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)

-    if recipe_name in ("mxfp4_cutlass", "nvfp4", "nvfp4_static"):
+    if recipe_name in (
+        "mxfp4_cutlass",
+        "nvfp4",
+        "nvfp4_static",
+        "nvfp4_no_global_scale",
+    ):
         d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
         A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
             d1
@@ -151,7 +156,7 @@ def get_gemm_times(
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_a = to_blocked(scale_a)
         scale_b = to_blocked(scale_b)
-    elif recipe_name in ("nvfp4", "nvfp4_static"):
+    elif recipe_name in ("nvfp4", "nvfp4_static", "nvfp4_no_global_scale"):
         scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_a = to_blocked(scale_a)
@@ -177,7 +182,7 @@ def do_matmul(A, B):
             swizzle_b=SwizzleType.SWIZZLE_32_4_4,
             output_dtype=d3,
         )
-        if recipe_name in ("nvfp4", "nvfp4_static"):
+        if recipe_name in ("nvfp4", "nvfp4_static", "nvfp4_no_global_scale"):
             return torch._scaled_mm(
                 A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
             )
@@ -797,6 +802,10 @@ def run(
         config = NVFP4DynamicActivationNVFP4WeightConfig(
             use_dynamic_per_tensor_scale=True,
         )
+    elif recipe_name == "nvfp4_no_global_scale":
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_dynamic_per_tensor_scale=False,
+        )
     elif recipe_name == "nvfp4_static":
         config_calib = NVFP4DynamicActivationNVFP4WeightConfig(
             step="prepare",
```
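Outside the benchmark script, the new recipe name corresponds to the config constructed above. A hedged sketch of applying it to a model follows; the import path for the prototype config is an assumption and may differ across torchao versions:

```python
# Sketch only: apply the equivalent of the "nvfp4_no_global_scale" recipe.
# The config's import path is an assumption, not taken from this commit.
import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).cuda().to(torch.bfloat16)

# use_dynamic_per_tensor_scale=False leaves per_tensor_scale as None, i.e.
# block-wise scaling only, which is what this commit makes work on triton.
quantize_(model, NVFP4DynamicActivationNVFP4WeightConfig(use_dynamic_per_tensor_scale=False))

with torch.no_grad():
    y = model(torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16))
```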

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -181,8 +181,6 @@ def test_inference_workflow_nvfp4(
         pytest.skip("TODO: weight_only quant currently errors w/ compile")
     if quant_type == "weight_only" and use_triton_kernel:
         pytest.skip("unsupported configuration")
-    if use_triton_kernel and not use_dynamic_per_tensor_scale:
-        pytest.skip("unsupported configuration")

     if use_inference_mode and (
         shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
```

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 58 additions & 3 deletions
```diff
@@ -369,9 +369,6 @@ def test_nvfp4_swizzled_scales_get_scales_method():
 @torch.no_grad()
 def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
-    if not use_per_tensor_scale:
-        pytest.skip("MSLK triton kernel requires per_tensor_scale")
-
     torch.manual_seed(42)
     x = torch.randn(M, N, dtype=dtype, device="cuda")

@@ -657,3 +654,61 @@ def test_nvfp4_pin_memory(use_per_tensor_scale):
     assert torch.equal(
         x_cpu.dequantize(torch.float32), x_pinned.dequantize(torch.float32)
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="requires sm100+ for nvfp4 triton kernel"
+)
+@pytest.mark.parametrize(
+    "shapes",
+    [
+        (128, 64, 256),
+        (256, 128, 512),
+    ],
+    ids=lambda s: f"{s[0]}x{s[1]}x{s[2]}",
+)
+@pytest.mark.parametrize(
+    "a_has_scale",
+    [True, False],
+    ids=["a_scale", "no_a_scale"],
+)
+@pytest.mark.parametrize("use_triton_kernel", [True, False])
+@torch.no_grad()
+@skip_if_rocm("ROCm float4 gemm require gfx950")
+def test_nvfp4_matmul_optional_per_tensor_scale(shapes, a_has_scale, use_triton_kernel):
+    """Test NVFP4 matmul works when per_tensor_scale is None for activation but always set for weight."""
+    m, k, n = shapes
+
+    A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+    B = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
+
+    C_ref = F.linear(A, B)
+
+    a_scale = (
+        per_tensor_amax_to_scale(torch.amax(torch.abs(A))) if a_has_scale else None
+    )
+    b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
+
+    act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
+
+    A_nvfp4 = NVFP4Tensor.to_nvfp4(
+        A,
+        per_tensor_scale=a_scale,
+        is_swizzled_scales=True,
+        use_triton_kernel=use_triton_kernel,
+    )
+    B_nvfp4 = NVFP4Tensor.to_nvfp4(
+        B,
+        per_tensor_scale=b_scale,
+        is_swizzled_scales=True,
+        use_triton_kernel=use_triton_kernel,
+        act_quant_kwargs=act_quant_kwargs,
+    )
+
+    C_nvfp4 = F.linear(A_nvfp4, B_nvfp4)
+    assert C_nvfp4.dtype == torch.bfloat16
+
+    sqnr = compute_error(C_ref, C_nvfp4)
+    SQNR_THRESHOLD = 16.0
+    assert sqnr >= SQNR_THRESHOLD, f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, {a_has_scale=}"
```

torchao/prototype/mx_formats/kernels.py

Lines changed: 18 additions & 7 deletions
```diff
@@ -6,7 +6,7 @@

 import importlib
 import logging
-from typing import Tuple
+from typing import Optional, Tuple

 import numpy as np
 import torch
@@ -22,6 +22,7 @@
 from torchao.utils import (
     is_cuda_version_at_least,
     is_MI350,
+    is_mslk_version_at_least,
     is_ROCM,
     is_sm_at_least_100,
     torch_version_at_least,
@@ -1175,30 +1176,34 @@ def mxfp8_quantize_cuda(


 def mslk_quantize_nvfp4(
-    x: torch.Tensor, per_tensor_scale: torch.Tensor
+    x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Quantize a tensor to NVFP4 using the MSLK triton kernel.

     Args:
         x: Input tensor to quantize.
-        per_tensor_scale: Per-tensor scale (TorchAO convention: amax / (F8E4M3_MAX * F4_E2M1_MAX)).
+        per_tensor_scale: Optional per-tensor scale (TorchAO convention: amax / (F8E4M3_MAX * F4_E2M1_MAX)).
+            If None, the global scale is not applied (single-level block-wise scaling only).

     Returns:
         Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention.
     """
-    mslk_global_scale = per_tensor_scale.reciprocal()
+    mslk_global_scale = (
+        per_tensor_scale.reciprocal() if per_tensor_scale is not None else None
+    )
     return _mslk_quantize_nvfp4_custom_op(x, mslk_global_scale)


 @torch.library.custom_op("ao::mslk_quantize_nvfp4", mutates_args=())
 def _mslk_quantize_nvfp4_custom_op(
-    x: torch.Tensor, global_scale: torch.Tensor
+    x: torch.Tensor, global_scale: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Inner custom op for MSLK NVFP4 quantization.

     Args:
         x: Input tensor to quantize.
-        global_scale: Global scale in MSLK convention (1.0 / per_tensor_scale).
+        global_scale: Optional global scale in MSLK convention (1.0 / per_tensor_scale).
+            If None, the global scale is not applied (treated as 1.0).

     Returns:
         Tuple of (blockwise_scales, quantized_data_uint8) matching TorchAO's convention.
@@ -1211,12 +1216,18 @@ def _mslk_quantize_nvfp4_custom_op(
         triton_quantize_nvfp4 as _mslk_triton_quantize_nvfp4,
     )

+    if global_scale is None:
+        assert is_mslk_version_at_least("1.1.0"), (
+            "Optional global_scale support requires MSLK >= 1.1.0, "
+            "Please upgrade MSLK: https://github.com/pytorch/MSLK"
+        )
+
     data_lp, blockwise_scales = _mslk_triton_quantize_nvfp4(x, global_scale)
     return blockwise_scales, data_lp.view(torch.uint8)


 @_mslk_quantize_nvfp4_custom_op.register_fake
-def _(x, global_scale):
+def _(x, global_scale=None):
     # Mirror the reshape logic from the real MSLK kernel
     orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1]
     x_2d = x.reshape(-1, orig_N)
```
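For context, the two scale conventions the wrapper translates between can be sketched as follows. The constants F8E4M3_MAX = 448.0 and F4_E2M1_MAX = 6.0 are the dtype maxima named in the docstring above; only the relationship, not the helper names, is taken from the code:

```python
# Relationship between the TorchAO and MSLK scale conventions.
import torch

F8E4M3_MAX = 448.0  # largest normal value of torch.float8_e4m3fn
F4_E2M1_MAX = 6.0   # largest representable value of fp4 e2m1

x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")
amax = torch.amax(torch.abs(x)).to(torch.float32)

per_tensor_scale = amax / (F8E4M3_MAX * F4_E2M1_MAX)  # TorchAO convention
global_scale = per_tensor_scale.reciprocal()          # MSLK convention

# With per_tensor_scale=None, the wrapper now forwards global_scale=None,
# which MSLK >= 1.1.0 treats as global_scale=1.0.
```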

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 10 additions & 8 deletions
```diff
@@ -155,9 +155,6 @@ def to_nvfp4(
             assert K % 16 == 0, (
                 f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
-            assert per_tensor_scale is not None, (
-                "Triton kernel requires per_tensor_scale"
-            )
             blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale)
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
@@ -497,11 +494,14 @@ def _addmm_nvfp4_dispatch(
     b_scale_blocked = to_blocked(b_scale)

     # Merge double quant scales into 1 scale for Scale_In^D
-    if a.per_tensor_scale is not None:
-        assert b.per_tensor_scale is not None
-        scale_result = a.per_tensor_scale * b.per_tensor_scale
+    # When per_tensor_scale is None for an operand, it's treated as 1.0
+    a_scale = a.per_tensor_scale
+    b_scale = b.per_tensor_scale
+    if a_scale is not None and b_scale is not None:
+        scale_result = a_scale * b_scale
+    elif a_scale is not None or b_scale is not None:
+        scale_result = a_scale if a_scale is not None else b_scale
     else:
-        assert b.per_tensor_scale is None and a.per_tensor_scale is None
         scale_result = None

     # THIS IS A WORKAROUND FOR TWO ERRORS:
@@ -720,7 +720,9 @@ def nvfp4_quantize(
             torch.float8_e4m3fn
         )
         block_scale_fp32 = block_scale_fp8.to(torch.float32)
-        data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
+        # Multiply by reciprocal instead of dividing to match MSLK triton kernel
+        # numerics (global_scale=None treated as 1.0): x * (1.0 / fp8_scale)
+        data_scaled = data_hp * (1.0 / block_scale_fp32).unsqueeze(-1)
         out_scales = block_scale_fp8
     else:
         # We are doing two level scaling,
```
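The divide-to-reciprocal change above is not a numerical no-op: `x / s` rounds once, while `x * (1.0 / s)` rounds the reciprocal and then the product, so the two forms can disagree in the last ulp. An illustrative check (not part of the commit):

```python
# Illustrative only: count elements where division and reciprocal-multiply
# round differently in float32.
import torch

torch.manual_seed(0)
x = torch.randn(1 << 16, dtype=torch.float32)
s = torch.rand(1 << 16, dtype=torch.float32) + 0.5  # avoid tiny divisors

div = x / s
recip_mul = x * (1.0 / s)

mismatches = (div != recip_mul).sum().item()
print(f"{mismatches} / {x.numel()} elements differ by at least one ulp")
```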

torchao/testing/training/roofline_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -519,7 +519,7 @@ def get_inference_tensor_memory_traffic_ovhd_s(
             )
             res_bytes = [kernel_1_rw + kernel_3_rw]

-        case "nvfp4_static":
+        case "nvfp4_static" | "nvfp4_no_global_scale":
             # nvfp4 with static global scaling
             # x_b16 = ...
             # static_max_abs = ...
```

torchao/utils.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -1277,6 +1277,14 @@ def _is_mslk_available():
     return True


+def is_mslk_version_at_least(min_version: str) -> bool:
+    if not _is_mslk_available():
+        return False
+    import mslk
+
+    return parse_version(mslk.__version__) >= parse_version(min_version)
+
+
 def _is_flashinfer_available():
     return (
         # flashinfer-python
```
