
Commit 488e447

Add support for flashinfer quantize kernel option for nvfp4
Summary: Added the flashinfer option for better performance on some of the workflows we are interested in; also added a numerical equivalence test between the different quantize_to_nvfp4_kernel_choice options.

Test Plan: pytest test/prototype/mx_formats/test_nvfp4_tensor.py -k test_quantize_to_nvfp4_kernel_numerical_equivalence

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 0db4b3f
Pull Request resolved: #3912
1 parent 32f6744 commit 488e447
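
For context, a minimal usage sketch of the new kernel option (a sketch only: import paths are taken from the files touched below, and it assumes a CUDA device with flashinfer installed; per the diff, flashinfer needs bf16/fp16/e4m3 input, a per-tensor scale, swizzled scales, and the last dim divisible by 64):

import torch
from torchao.prototype.mx_formats.config import QuantizeToNVFP4KernelChoice
from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    per_tensor_amax_to_scale,
)

# bf16 input with K = 256 (divisible by 64), as flashinfer requires
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x)))

x_nvfp4 = NVFP4Tensor.to_nvfp4(
    x,
    per_tensor_scale=per_tensor_scale,
    is_swizzled_scales=True,
    quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.FLASHINFER,
)
x_dq = x_nvfp4.dequantize(torch.bfloat16)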

4 files changed: 125 additions & 25 deletions

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 72 additions & 21 deletions
@@ -23,6 +23,7 @@
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
+    _is_flashinfer_available,
     is_sm_at_least_100,
     torch_version_at_least,
 )
@@ -368,8 +369,10 @@ def test_nvfp4_swizzled_scales_get_scales_method():
     not is_sm_at_least_100(), reason="requires sm100+ for raw intrinsics"
 )
 @torch.no_grad()
-def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
-    """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
+def test_quantize_to_nvfp4_kernel_numerical_equivalence(
+    M, N, use_per_tensor_scale, dtype
+):
+    """Test that different quantize_to_nvfp4 kernel choices produce numerically equivalent results."""
     if not use_per_tensor_scale:
         pytest.skip("MSLK triton kernel requires per_tensor_scale")

@@ -380,19 +383,16 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
     if use_per_tensor_scale:
         per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x)))

-    nvfp4_pt = NVFP4Tensor.to_nvfp4(
+    # Reference: TORCH kernel choice
+    nvfp4_ref = NVFP4Tensor.to_nvfp4(
         x.clone(),
         per_tensor_scale=per_tensor_scale,
         is_swizzled_scales=True,
         quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH,
     )
+    ref_dequant = nvfp4_ref.dequantize(dtype)

-    nvfp4_triton = NVFP4Tensor.to_nvfp4(
-        x.clone(),
-        per_tensor_scale=per_tensor_scale,
-        is_swizzled_scales=True,
-        quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.MSLK,
-    )
+    other_kernel_choices = [QuantizeToNVFP4KernelChoice.MSLK]

     torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
     pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
@@ -404,16 +404,44 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
         rtol=0,
     )

-    x_pt_dequant = nvfp4_pt.dequantize(dtype)
-    x_triton_dequant = nvfp4_triton.dequantize(dtype)
-
-    sqnr = compute_error(x_pt_dequant, x_triton_dequant)
-    SQNR_THRESHOLD = 40.0
-
-    assert sqnr >= SQNR_THRESHOLD, (
-        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD} for M={M}, N={N}, "
-        f"use_per_tensor_scale={use_per_tensor_scale}, dtype={dtype}"
-    )
+    # Flashinfer requires the library and per_tensor_scale
+    if _is_flashinfer_available() and use_per_tensor_scale:
+        other_kernel_choices.append(QuantizeToNVFP4KernelChoice.FLASHINFER)
+
+    SQNR_THRESHOLD = 28.0
+    for kc in other_kernel_choices:
+        nvfp4_other = NVFP4Tensor.to_nvfp4(
+            x.clone(),
+            per_tensor_scale=per_tensor_scale,
+            is_swizzled_scales=True,
+            quantize_to_nvfp4_kernel_choice=kc,
+        )
+
+        # For kernel choices that use the same quantization algorithm as TORCH
+        # (MSLK should be bitwise identical), verify internal data matches exactly
+        if kc == QuantizeToNVFP4KernelChoice.MSLK:
+            torch.testing.assert_close(
+                nvfp4_ref.scale.flatten(),
+                nvfp4_other.scale.flatten(),
+                atol=0,
+                rtol=0,
+            )
+            ref_unpacked = unpack_uint4(nvfp4_ref.qdata)
+            other_unpacked = unpack_uint4(nvfp4_other.qdata)
+            torch.testing.assert_close(
+                ref_unpacked,
+                other_unpacked,
+                atol=0,
+                rtol=0,
+            )
+
+        # Verify dequantized values are numerically close for all kernel choices
+        other_dequant = nvfp4_other.dequantize(dtype)
+        sqnr = compute_error(ref_dequant, other_dequant)
+        assert sqnr >= SQNR_THRESHOLD, (
+            f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD} between TORCH and {kc}, "
+            f"M={M}, N={N}, use_per_tensor_scale={use_per_tensor_scale}, dtype={dtype}"
+        )


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -430,7 +458,11 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
 @pytest.mark.parametrize(
     "quantize_to_nvfp4_kernel_choice",
-    [QuantizeToNVFP4KernelChoice.MSLK, QuantizeToNVFP4KernelChoice.TORCH],
+    [
+        QuantizeToNVFP4KernelChoice.MSLK,
+        QuantizeToNVFP4KernelChoice.FLASHINFER,
+        QuantizeToNVFP4KernelChoice.TORCH,
+    ],
 )
 @pytest.mark.parametrize(
     "shapes",
@@ -469,6 +501,10 @@ def test_nvfp4_matmul_with_amax(
     if quant_type == "weight_only" and compile:
         pytest.skip("TODO: weight_only currently errors w/ compile")

+    if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER:
+        if not _is_flashinfer_available():
+            pytest.skip("flashinfer not available")
+
     m, k, n = shapes

     # Create activation tensor
@@ -544,7 +580,11 @@ def test_nvfp4_to_copy():
 @pytest.mark.parametrize("transpose", [False, True])
 @pytest.mark.parametrize(
     "quantize_to_nvfp4_kernel_choice",
-    [QuantizeToNVFP4KernelChoice.TORCH, QuantizeToNVFP4KernelChoice.MSLK],
+    [
+        QuantizeToNVFP4KernelChoice.TORCH,
+        QuantizeToNVFP4KernelChoice.MSLK,
+        QuantizeToNVFP4KernelChoice.FLASHINFER,
+    ],
 )
 @pytest.mark.parametrize("is_swizzled_scales", [False, True])
 @pytest.mark.parametrize(
@@ -570,11 +610,22 @@ def test_scale_shape_matches_qdata(
         and not is_swizzled_scales
     ):
         pytest.skip("triton kernel requires swizzled scales")
+    if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER:
+        if not _is_flashinfer_available():
+            pytest.skip("flashinfer not available")
+        if not is_swizzled_scales:
+            pytest.skip("flashinfer requires swizzled scales")
+        if shape[-1] % 64 != 0:
+            pytest.skip("flashinfer requires K to be divisible by 64")

     block_size = 16

     x_hp = torch.randn(*shape, device="cuda")

+    if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER:
+        # flashinfer only supports fp16/bf16/e4m3 input
+        x_hp = x_hp.to(torch.bfloat16)
+
     per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x_hp)))

     x = NVFP4Tensor.to_nvfp4(
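
The 28 dB gate above is an SQNR check between dequantized outputs. A standalone sketch of the metric, assuming torchao.quantization.utils.compute_error is the usual 20·log10 ratio of signal norm to error norm (sqnr_db below is a hypothetical stand-in, not the torchao function):

import torch

def sqnr_db(ref: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # signal-to-quantization-noise ratio in dB
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - other))

ref = torch.randn(1024)
noisy = ref + 1e-2 * torch.randn(1024)
print(sqnr_db(ref, noisy))  # roughly 40 dB here; the test requires >= 28 dB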

torchao/prototype/mx_formats/config.py

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,9 @@ class QuantizeToNVFP4KernelChoice(str, Enum):
     MSLK = "mslk"
     """Use MSLK triton high precision to nvfp4 quantize kernel"""

+    FLASHINFER = "flashinfer"
+    """Use flashinfer bf16 to nvfp4 quantize kernel"""
+

 torch.serialization.add_safe_globals([QuantizeToNVFP4KernelChoice])

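
Because QuantizeToNVFP4KernelChoice is a str Enum already registered via torch.serialization.add_safe_globals, the new member participates in value-based lookup and serialization like the existing ones; a small sketch (the temp path is illustrative):

import torch
from torchao.prototype.mx_formats.config import QuantizeToNVFP4KernelChoice

kc = QuantizeToNVFP4KernelChoice("flashinfer")  # value-based lookup on the str Enum
assert kc is QuantizeToNVFP4KernelChoice.FLASHINFER

torch.save(kc, "/tmp/kernel_choice.pt")
loaded = torch.load("/tmp/kernel_choice.pt")  # safe under weights_only loading
assert loaded == QuantizeToNVFP4KernelChoice.FLASHINFER
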
torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 22 additions & 1 deletion
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import types
+import warnings
 from dataclasses import dataclass
 from functools import partial
 from typing import Optional
@@ -208,7 +209,7 @@ class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig):
     set to False.

     Configuration parameters:
-    - quantize_to_nvfp4_kernel_choice: QuantizeToNVFP4KernelChoice, kernel preference for quantization (default: QuantizeToNVFP4KernelChoice.MSLK)
+    - quantize_to_nvfp4_kernel_choice: QuantizeToNVFP4KernelChoice, kernel choice for quantization (default: QuantizeToNVFP4KernelChoice.MSLK)
       Requires `MSLK <https://github.com/pytorch/MSLK>`__ to be installed.
     - use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True)
     - step: Optional[QuantizationStep], the quantization step for observer-based flow
@@ -249,6 +250,17 @@ def __post_init__(self):
             # Static quantization implies use_dynamic_per_tensor_scale=False
             self.use_dynamic_per_tensor_scale = False

+        if (
+            self.quantize_to_nvfp4_kernel_choice
+            == QuantizeToNVFP4KernelChoice.FLASHINFER
+        ):
+            if self.step is None and not self.use_dynamic_per_tensor_scale:
+                raise ValueError(
+                    "FLASHINFER kernel choice requires per_tensor_scale. "
+                    "Use step='prepare'/'convert' for static quantization, "
+                    "or set use_dynamic_per_tensor_scale=True."
+                )
+

 @register_quantize_module_handler(NVFP4DynamicActivationNVFP4WeightConfig)
 def _nvfp4_inference_linear_transform(
@@ -269,6 +281,15 @@ def _nvfp4_inference_linear_transform(
         raise RuntimeError(
             f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
         )
+    if (
+        config.quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER
+        and weight.shape[-1] % 64 != 0
+    ):
+        warnings.warn(
+            f"Skipping NVFP4 quantization for layer with K={weight.shape[-1]}: "
+            f"flashinfer requires K to be divisible by 64."
+        )
+        return module

     step = config.step
     if step == QuantizationStep.PREPARE or step == "prepare":
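
A sketch of the new __post_init__ validation from the caller's side (keyword names follow the docstring above; other constructor defaults are assumed):

from torchao.prototype.mx_formats.config import QuantizeToNVFP4KernelChoice
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,
)

# OK: dynamic per-tensor scale gives flashinfer the per_tensor_scale it needs
ok = NVFP4DynamicActivationNVFP4WeightConfig(
    quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.FLASHINFER,
    use_dynamic_per_tensor_scale=True,
)

# Rejected: no step and no dynamic per-tensor scale -> no per_tensor_scale source
try:
    NVFP4DynamicActivationNVFP4WeightConfig(
        quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.FLASHINFER,
        use_dynamic_per_tensor_scale=False,
    )
except ValueError as e:
    print(e)  # "FLASHINFER kernel choice requires per_tensor_scale. ..."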

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 28 additions & 3 deletions
@@ -33,7 +33,7 @@
 from torchao.quantization.quantize_.common import (
     QuantizeTensorKwargs,
 )
-from torchao.utils import TorchAOBaseTensor, fill_defaults
+from torchao.utils import TorchAOBaseTensor, _is_flashinfer_available, fill_defaults

 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny

@@ -98,7 +98,7 @@ class NVFP4Tensor(TorchAOBaseTensor):
         block_size (int): Block size for quantization (fixed at 16)
         orig_dtype (torch.dtype): Original tensor dtype before quantization
         is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format
-        quantize_to_nvfp4_kernel_choice (QuantizeToNVFP4KernelChoice): Kernel preference for quantization
+        quantize_to_nvfp4_kernel_choice (QuantizeToNVFP4KernelChoice): Kernel choice for quantization
     """

     tensor_data_names = ["qdata", "scale"]
@@ -179,7 +179,7 @@ def to_nvfp4(
         act_per_tensor_scale: Optional pre-computed absolute maximum for calibration for activation
             If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
         is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
-        quantize_to_nvfp4_kernel_choice: Kernel preference for quantization
+        quantize_to_nvfp4_kernel_choice: Kernel choice for quantization
         act_quant_kwargs: If specified, config for quantizing the activation

     Returns:
@@ -201,6 +201,31 @@ def to_nvfp4(
                 "Triton kernel requires per_tensor_scale"
             )
             blockwise_scales, data_lp = mslk_quantize_nvfp4(data_hp, per_tensor_scale)
+        elif quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.FLASHINFER:
+            from flashinfer import SfLayout
+            from flashinfer import nvfp4_quantize as flashinfer_nvfp4_quantize
+
+            assert _is_flashinfer_available(), (
+                "flashinfer is not available, please install flashinfer-python, apache-tvm-ffi, and nvidia-ml-py to use FLASHINFER kernel choice"
+            )
+            assert per_tensor_scale is not None, (
+                "flashinfer nvfp4_quantize requires per_tensor_scale"
+            )
+            assert is_swizzled_scales, (
+                "flashinfer nvfp4_quantize only supports swizzled scales"
+            )
+            assert K % 64 == 0, (
+                f"flashinfer nvfp4_quantize requires K (dim -1) to be divisible by 64, got {K}"
+            )
+            # flashinfer uses global_sf = (F8E4M3_MAX * F4_E2M1_MAX) / amax
+            # which is 1 / per_tensor_scale
+            global_sf = 1.0 / per_tensor_scale
+            data_lp, blockwise_scales = flashinfer_nvfp4_quantize(
+                data_hp,
+                global_sf,
+                sfLayout=SfLayout.layout_128x4,
+                do_shuffle=False,
+            )
         elif quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.TORCH:
             blockwise_scales, data_lp = nvfp4_quantize(
                 data_hp, block_size, per_tensor_scale
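
The scale conversion in the FLASHINFER branch follows from the two conventions named in the comment; a worked check, assuming per_tensor_amax_to_scale uses the amax / (F8E4M3_MAX * F4_E2M1_MAX) convention that the comment implies:

import torch

F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
F4_E2M1_MAX = 6.0  # largest magnitude representable in fp4 e2m1

amax = torch.tensor(100.0)

# torchao convention implied above: per_tensor_scale = amax / (448 * 6)
per_tensor_scale = amax / (F8E4M3_MAX * F4_E2M1_MAX)

# flashinfer convention: global_sf = (448 * 6) / amax
global_sf = (F8E4M3_MAX * F4_E2M1_MAX) / amax

assert torch.allclose(global_sf, 1.0 / per_tensor_scale)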
