Skip to content

Commit 65e00ab

Browse files
committed
Add support for flashinfer quantize kernel option for nvfp4
Summary: Added the flashinfer option for better performance on some of the workflows we are interested in; also added a numerical equivalence test between the different quantize_kernel_preference options. Test Plan: pytest test/prototype/mx_formats/test_nvfp4_tensor.py -k test_kernel_preference_numerical_equivalence Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 4e6c022 Pull Request resolved: #3912
1 parent 253dce3 commit 65e00ab

4 files changed

Lines changed: 150 additions & 55 deletions

File tree

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 74 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from torchao.quantization.utils import compute_error
2424
from torchao.testing.utils import skip_if_rocm
2525
from torchao.utils import (
26+
_is_flashinfer_available,
2627
is_sm_at_least_100,
2728
torch_version_at_least,
2829
)
@@ -368,8 +369,10 @@ def test_nvfp4_swizzled_scales_get_scales_method():
368369
not is_sm_at_least_100(), reason="requires sm100+ for raw intrinsics"
369370
)
370371
@torch.no_grad()
371-
def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
372-
"""Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
372+
def test_quantize_to_nvfp4_kernel_numerical_equivalence(
373+
M, N, use_per_tensor_scale, dtype
374+
):
375+
"""Test that different quantize_to_nvfp4 kernel choices produce numerically equivalent results."""
373376
if not use_per_tensor_scale:
374377
pytest.skip("MSLK triton kernel requires per_tensor_scale")
375378

@@ -380,40 +383,55 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
380383
if use_per_tensor_scale:
381384
per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x)))
382385

383-
nvfp4_pt = NVFP4Tensor.to_nvfp4(
386+
# Reference: TORCH kernel choice
387+
nvfp4_ref = NVFP4Tensor.to_nvfp4(
384388
x.clone(),
385389
per_tensor_scale=per_tensor_scale,
386390
is_swizzled_scales=True,
387391
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH,
388392
)
389-
390-
nvfp4_triton = NVFP4Tensor.to_nvfp4(
391-
x.clone(),
392-
per_tensor_scale=per_tensor_scale,
393-
is_swizzled_scales=True,
394-
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.MSLK,
395-
)
396-
397-
torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
398-
pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
399-
triton_unpacked = unpack_uint4(nvfp4_triton.qdata.view(torch.uint8))
400-
torch.testing.assert_close(
401-
pt_unpacked,
402-
triton_unpacked,
403-
atol=0,
404-
rtol=0,
405-
)
406-
407-
x_pt_dequant = nvfp4_pt.dequantize(dtype)
408-
x_triton_dequant = nvfp4_triton.dequantize(dtype)
409-
410-
sqnr = compute_error(x_pt_dequant, x_triton_dequant)
411-
SQNR_THRESHOLD = 40.0
412-
413-
assert sqnr >= SQNR_THRESHOLD, (
414-
f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD} for M={M}, N={N}, "
415-
f"use_per_tensor_scale={use_per_tensor_scale}, dtype={dtype}"
416-
)
393+
ref_dequant = nvfp4_ref.dequantize(dtype)
394+
395+
other_kernel_choices = [QuantizeToNVFP4KernelChoice.MSLK]
396+
397+
# Flashinfer requires the library and per_tensor_scale
398+
if _is_flashinfer_available() and use_per_tensor_scale:
399+
other_kernel_choices.append(QuantizeToNVFP4KernelChoice.FLASHINFER)
400+
401+
SQNR_THRESHOLD = 28.0
402+
for kc in other_kernel_choices:
403+
nvfp4_other = NVFP4Tensor.to_nvfp4(
404+
x.clone(),
405+
per_tensor_scale=per_tensor_scale,
406+
is_swizzled_scales=True,
407+
quantize_to_nvfp4_kernel_choice=kc,
408+
)
409+
410+
# For kernel choices that use the same quantization algorithm as TORCH
411+
# (MSLK should be bitwise identical), verify internal data matches exactly
412+
if kc == QuantizeToNVFP4KernelChoice.MSLK:
413+
torch.testing.assert_close(
414+
nvfp4_ref.scale.flatten(),
415+
nvfp4_other.scale.flatten(),
416+
atol=0,
417+
rtol=0,
418+
)
419+
ref_unpacked = unpack_uint4(nvfp4_ref.qdata)
420+
other_unpacked = unpack_uint4(nvfp4_other.qdata)
421+
torch.testing.assert_close(
422+
ref_unpacked,
423+
other_unpacked,
424+
atol=0,
425+
rtol=0,
426+
)
427+
428+
# Verify dequantized values are numerically close for all kernel choices
429+
other_dequant = nvfp4_other.dequantize(dtype)
430+
sqnr = compute_error(ref_dequant, other_dequant)
431+
assert sqnr >= SQNR_THRESHOLD, (
432+
f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD} between TORCH and {kc}, "
433+
f"M={M}, N={N}, use_per_tensor_scale={use_per_tensor_scale}, dtype={dtype}"
434+
)
417435

418436

419437
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -430,7 +448,11 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
430448
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
431449
@pytest.mark.parametrize(
432450
"quantize_to_nvfp4_kernel_choice",
433-
[QuantizeToNVFP4KernelChoice.MSLK, QuantizeToNVFP4KernelChoice.TORCH],
451+
[
452+
QuantizeToNVFP4KernelChoice.MSLK,
453+
QuantizeToNVFP4KernelChoice.FLASHINFER,
454+
QuantizeToNVFP4KernelChoice.TORCH,
455+
],
434456
)
435457
@pytest.mark.parametrize(
436458
"shapes",
@@ -469,6 +491,10 @@ def test_nvfp4_matmul_with_amax(
469491
if quant_type == "weight_only" and compile:
470492
pytest.skip("TODO: weight_only currently errors w/ compile")
471493

494+
if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER:
495+
if not _is_flashinfer_available():
496+
pytest.skip("flashinfer not available")
497+
472498
m, k, n = shapes
473499

474500
# Create activation tensor
@@ -544,7 +570,11 @@ def test_nvfp4_to_copy():
544570
@pytest.mark.parametrize("transpose", [False, True])
545571
@pytest.mark.parametrize(
546572
"quantize_to_nvfp4_kernel_choice",
547-
[QuantizeToNVFP4KernelChoice.TORCH, QuantizeToNVFP4KernelChoice.MSLK],
573+
[
574+
QuantizeToNVFP4KernelChoice.TORCH,
575+
QuantizeToNVFP4KernelChoice.MSLK,
576+
QuantizeToNVFP4KernelChoice.FLASHINFER,
577+
],
548578
)
549579
@pytest.mark.parametrize("is_swizzled_scales", [False, True])
550580
@pytest.mark.parametrize(
@@ -570,11 +600,22 @@ def test_scale_shape_matches_qdata(
570600
and not is_swizzled_scales
571601
):
572602
pytest.skip("triton kernel requires swizzled scales")
603+
if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER:
604+
if not _is_flashinfer_available():
605+
pytest.skip("flashinfer not available")
606+
if not is_swizzled_scales:
607+
pytest.skip("flashinfer requires swizzled scales")
608+
if shape[-1] % 64 != 0:
609+
pytest.skip("flashinfer requires K to be divisible by 64")
573610

574611
block_size = 16
575612

576613
x_hp = torch.randn(*shape, device="cuda")
577614

615+
if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER:
616+
# flashinfer only supports fp16/bf16/e4m3 input
617+
x_hp = x_hp.to(torch.bfloat16)
618+
578619
per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp)))
579620

580621
x = NVFP4Tensor.to_nvfp4(

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ class QuantizeToNVFP4KernelChoice(str, Enum):
4343
MSLK = "mslk"
4444
"""Use MSLK triton high precision to nvfp4 quantize kernel"""
4545

46+
FLASHINFER = "flashinfer"
47+
"""Use flashinfer bf16 to nvfp4 quantize kernel"""
48+
4649

4750
torch.serialization.add_safe_globals([QuantizeToNVFP4KernelChoice])
4851

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import types
8+
import warnings
89
from dataclasses import dataclass
910
from functools import partial
1011
from typing import Optional
@@ -208,7 +209,7 @@ class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig):
208209
set to False.
209210
210211
Configuration parameters:
211-
- quantize_to_nvfp4_kernel_choice: QuantizeToNVFP4KernelChoice, kernel preference for quantization (default: QuantizeToNVFP4KernelChoice.MSLK)
212+
- quantize_to_nvfp4_kernel_choice: QuantizeToNVFP4KernelChoice, kernel choice for quantization (default: QuantizeToNVFP4KernelChoice.MSLK)
212213
Requires `MSLK <https://github.com/pytorch/MSLK>`__ to be installed.
213214
- use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True)
214215
- step: Optional[QuantizationStep], the quantization step for observer-based flow
@@ -249,6 +250,17 @@ def __post_init__(self):
249250
# Static quantization implies use_dynamic_per_tensor_scale=False
250251
self.use_dynamic_per_tensor_scale = False
251252

253+
if (
254+
self.quantize_to_nvfp4_kernel_choice
255+
== QuantizeToNVFP4KernelChoice.FLASHINFER
256+
):
257+
if self.step is None and not self.use_dynamic_per_tensor_scale:
258+
raise ValueError(
259+
"FLASHINFER kernel choice requires per_tensor_scale. "
260+
"Use step='prepare'/'convert' for static quantization, "
261+
"or set use_dynamic_per_tensor_scale=True."
262+
)
263+
252264

253265
@register_quantize_module_handler(NVFP4DynamicActivationNVFP4WeightConfig)
254266
def _nvfp4_inference_linear_transform(
@@ -269,6 +281,15 @@ def _nvfp4_inference_linear_transform(
269281
raise RuntimeError(
270282
f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
271283
)
284+
if (
285+
config.quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER
286+
and weight.shape[-1] % 64 != 0
287+
):
288+
warnings.warn(
289+
f"Skipping NVFP4 quantization for layer with K={weight.shape[-1]}: "
290+
f"flashinfer requires K to be divisible by 64."
291+
)
292+
return module
272293

273294
step = config.step
274295
if step == QuantizationStep.PREPARE or step == "prepare":

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from torchao.quantization.quantize_.common import (
3434
QuantizeTensorKwargs,
3535
)
36-
from torchao.utils import TorchAOBaseTensor, fill_defaults
36+
from torchao.utils import TorchAOBaseTensor, _is_flashinfer_available, fill_defaults
3737

3838
E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny
3939

@@ -46,25 +46,30 @@ def _handle_use_triton_kernel(
4646
) -> QuantizeToNVFP4KernelChoice:
4747
"""Handle deprecated use_triton_kernel parameter.
4848
49-
Raises an exception if use_triton_kernel does not match
50-
quantize_to_nvfp4_kernel_choice.
49+
Raises ValueError if use_triton_kernel does not match
50+
quantize_to_nvfp4_kernel_choice. use_triton_kernel=True corresponds to
51+
MSLK, use_triton_kernel=False corresponds to TORCH or FLASHINFER.
5152
"""
52-
expected = (
53-
QuantizeToNVFP4KernelChoice.MSLK
54-
if use_triton_kernel
55-
else QuantizeToNVFP4KernelChoice.TORCH
56-
)
57-
if expected != quantize_to_nvfp4_kernel_choice:
58-
raise ValueError(
59-
f"`use_triton_kernel={use_triton_kernel}` does not match "
60-
f"`quantize_to_nvfp4_kernel_choice={quantize_to_nvfp4_kernel_choice}`. "
61-
"`use_triton_kernel` is deprecated and will be removed after 0.17. "
62-
"Please use `quantize_to_nvfp4_kernel_choice` instead. "
63-
"`use_triton_kernel=True` is equivalent to "
64-
"`quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.MSLK`, "
65-
"`use_triton_kernel=False` is equivalent to "
66-
"`quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH`."
67-
)
53+
if use_triton_kernel:
54+
if quantize_to_nvfp4_kernel_choice != QuantizeToNVFP4KernelChoice.MSLK:
55+
raise ValueError(
56+
f"`use_triton_kernel=True` does not match "
57+
f"`quantize_to_nvfp4_kernel_choice={quantize_to_nvfp4_kernel_choice}`. "
58+
"`use_triton_kernel` is deprecated and will be removed after 0.17. "
59+
"Please use `quantize_to_nvfp4_kernel_choice` instead. "
60+
"`use_triton_kernel=True` is equivalent to "
61+
"`quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.MSLK`."
62+
)
63+
else:
64+
if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK:
65+
raise ValueError(
66+
f"`use_triton_kernel=False` does not match "
67+
f"`quantize_to_nvfp4_kernel_choice={quantize_to_nvfp4_kernel_choice}`. "
68+
"`use_triton_kernel` is deprecated and will be removed after 0.17. "
69+
"Please use `quantize_to_nvfp4_kernel_choice` instead. "
70+
"`use_triton_kernel=False` is equivalent to "
71+
"`quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH`."
72+
)
6873
return quantize_to_nvfp4_kernel_choice
6974

7075

@@ -98,7 +103,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
98103
block_size (int): Block size for quantization (fixed at 16)
99104
orig_dtype (torch.dtype): Original tensor dtype before quantization
100105
is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
101-
quantize_to_nvfp4_kernel_choice (QuantizeToNVFP4KernelChoice): Kernel preference for quantization
106+
quantize_to_nvfp4_kernel_choice (QuantizeToNVFP4KernelChoice): Kernel choice for quantization
102107
"""
103108

104109
tensor_data_names = ["qdata", "scale"]
@@ -179,7 +184,7 @@ def to_nvfp4(
179184
act_per_tensor_scale: Optional pre-computed absolute maximum for calibration for activation
180185
If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
181186
is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
182-
quantize_to_nvfp4_kernel_choice: Kernel preference for quantization
187+
quantize_to_nvfp4_kernel_choice: Kernel choice for quantization
183188
act_quant_kwargs: If specified, config for quantizing the activation
184189
185190
Returns:
@@ -201,6 +206,31 @@ def to_nvfp4(
201206
"Triton kernel requires per_tensor_scale"
202207
)
203208
blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale)
209+
elif quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER:
210+
from flashinfer import SfLayout
211+
from flashinfer import nvfp4_quantize as flashinfer_nvfp4_quantize
212+
213+
assert _is_flashinfer_available(), (
214+
"flashinfer is not available, please install flashinfer-python, apache-tvm-ffi, and nvidia-ml-py to use FLASHINFER kernel choice"
215+
)
216+
assert per_tensor_scale is not None, (
217+
"flashinfer nvfp4_quantize requires per_tensor_scale"
218+
)
219+
assert is_swizzled_scales, (
220+
"flashinfer nvfp4_quantize only supports swizzled scales"
221+
)
222+
assert K % 64 == 0, (
223+
f"flashinfer nvfp4_quantize requires K (dim -1) to be divisible by 64, got {K}"
224+
)
225+
# flashinfer uses global_sf = (F8E4M3_MAX * F4_E2M1_MAX) / amax
226+
# which is 1 / per_tensor_scale
227+
global_sf = 1.0 / per_tensor_scale
228+
data_lp, blockwise_scales = flashinfer_nvfp4_quantize(
229+
data_hp,
230+
global_sf,
231+
sfLayout=SfLayout.layout_128x4,
232+
do_shuffle=False,
233+
)
204234
elif quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.TORCH:
205235
blockwise_scales, data_lp = nvfp4_quantize(
206236
data_hp, block_size, per_tensor_scale

0 commit comments

Comments
 (0)