Skip to content

Commit 6d6cc54

Browse files
committed
Update on "Add support for flashinfer quantize kernel option for nvfp4"
Summary: Added the flashinfer option for better performance on some of the workflows we are interested in. Also added a numerical equivalence test between the different nvfp4_quantize_kernel_choice options.
Test Plan: pytest test/prototype/mx_formats/test_nvfp4_tensor.py -k test_kernel_preference_numerical_equivalence
We'll test the speedup a bit later.
Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
2 parents f8cff6d + b1d2a60 commit 6d6cc54

2 files changed

Lines changed: 23 additions & 28 deletions

File tree

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 0 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -394,16 +394,6 @@ def test_quantize_to_nvfp4_kernel_numerical_equivalence(
394394

395395
other_kernel_choices = [QuantizeToNVFP4KernelChoice.MSLK]
396396

397-
torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
398-
pt_unpacked = unpack_uint4(nvfp4_pt.qdata.view(torch.uint8))
399-
triton_unpacked = unpack_uint4(nvfp4_triton.qdata.view(torch.uint8))
400-
torch.testing.assert_close(
401-
pt_unpacked,
402-
triton_unpacked,
403-
atol=0,
404-
rtol=0,
405-
)
406-
407397
# Flashinfer requires the library and per_tensor_scale
408398
if _is_flashinfer_available() and use_per_tensor_scale:
409399
other_kernel_choices.append(QuantizeToNVFP4KernelChoice.FLASHINFER)

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -46,25 +46,30 @@ def _handle_use_triton_kernel(
4646
) -> QuantizeToNVFP4KernelChoice:
4747
"""Handle deprecated use_triton_kernel parameter.
4848
49-
Raises an exception if use_triton_kernel does not match
50-
quantize_to_nvfp4_kernel_choice.
49+
Raises ValueError if use_triton_kernel does not match
50+
quantize_to_nvfp4_kernel_choice. use_triton_kernel=True corresponds to
51+
MSLK, use_triton_kernel=False corresponds to TORCH or FLASHINFER.
5152
"""
52-
expected = (
53-
QuantizeToNVFP4KernelChoice.MSLK
54-
if use_triton_kernel
55-
else QuantizeToNVFP4KernelChoice.TORCH
56-
)
57-
if expected != quantize_to_nvfp4_kernel_choice:
58-
raise ValueError(
59-
f"`use_triton_kernel={use_triton_kernel}` does not match "
60-
f"`quantize_to_nvfp4_kernel_choice={quantize_to_nvfp4_kernel_choice}`. "
61-
"`use_triton_kernel` is deprecated and will be removed after 0.17. "
62-
"Please use `quantize_to_nvfp4_kernel_choice` instead. "
63-
"`use_triton_kernel=True` is equivalent to "
64-
"`quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.MSLK`, "
65-
"`use_triton_kernel=False` is equivalent to "
66-
"`quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH`."
67-
)
53+
if use_triton_kernel:
54+
if quantize_to_nvfp4_kernel_choice != QuantizeToNVFP4KernelChoice.MSLK:
55+
raise ValueError(
56+
f"`use_triton_kernel=True` does not match "
57+
f"`quantize_to_nvfp4_kernel_choice={quantize_to_nvfp4_kernel_choice}`. "
58+
"`use_triton_kernel` is deprecated and will be removed after 0.17. "
59+
"Please use `quantize_to_nvfp4_kernel_choice` instead. "
60+
"`use_triton_kernel=True` is equivalent to "
61+
"`quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.MSLK`."
62+
)
63+
else:
64+
if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK:
65+
raise ValueError(
66+
f"`use_triton_kernel=False` does not match "
67+
f"`quantize_to_nvfp4_kernel_choice={quantize_to_nvfp4_kernel_choice}`. "
68+
"`use_triton_kernel` is deprecated and will be removed after 0.17. "
69+
"Please use `quantize_to_nvfp4_kernel_choice` instead. "
70+
"`use_triton_kernel=False` is equivalent to "
71+
"`quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH`."
72+
)
6873
return quantize_to_nvfp4_kernel_choice
6974

7075

0 commit comments

Comments
 (0)