Skip to content

Commit 52a83c8

Browse files
committed
Refactor use_triton_kernel to use quantize_kernel_preference
Summary: This is to prepare for the addition of the flashinfer quantize kernel path in the next PR. Test Plan: python test/prototype/mx_formats/test_inference_workflow.py Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 47457cf Pull Request resolved: #3911
1 parent 5e79e5e commit 52a83c8

10 files changed

Lines changed: 196 additions & 61 deletions

File tree

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch.nn as nn
1313
from torch.profiler import ProfilerActivity, profile
1414

15+
from torchao.prototype.mx_formats.config import QuantizeToNVFP4KernelChoice
1516
from torchao.prototype.mx_formats.inference_workflow import (
1617
MXDynamicActivationMXWeightConfig,
1718
NVFP4DynamicActivationNVFP4WeightConfig,
@@ -140,7 +141,10 @@ def test_inference_workflow_mx(
140141
@pytest.mark.parametrize("compile", [True, False])
141142
@pytest.mark.parametrize("quant_type", ["dynamic", "weight_only"])
142143
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
143-
@pytest.mark.parametrize("use_triton_kernel", [True, False])
144+
@pytest.mark.parametrize(
145+
"quantize_to_nvfp4_kernel_choice",
146+
[QuantizeToNVFP4KernelChoice.TORCH, QuantizeToNVFP4KernelChoice.MSLK],
147+
)
144148
@pytest.mark.parametrize("use_dynamic_per_tensor_scale", [True, False])
145149
@pytest.mark.parametrize(
146150
"shapes",
@@ -164,7 +168,7 @@ def test_inference_workflow_nvfp4(
164168
compile: bool,
165169
quant_type: str,
166170
inpt_dtype: torch.dtype,
167-
use_triton_kernel: bool,
171+
quantize_to_nvfp4_kernel_choice: QuantizeToNVFP4KernelChoice,
168172
use_dynamic_per_tensor_scale: bool,
169173
shapes: tuple,
170174
use_inference_mode: bool,
@@ -179,17 +183,24 @@ def test_inference_workflow_nvfp4(
179183
pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")
180184
if quant_type == "weight_only" and compile:
181185
pytest.skip("TODO: weight_only quant currently errors w/ compile")
182-
if quant_type == "weight_only" and use_triton_kernel:
186+
if (
187+
quant_type == "weight_only"
188+
and quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK
189+
):
183190
pytest.skip("unsupported configuration")
184-
if use_triton_kernel and not use_dynamic_per_tensor_scale:
191+
if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK and not use_dynamic_per_tensor_scale:
185192
pytest.skip("unsupported configuration")
186193

187194
if use_inference_mode and (
188-
shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
195+
shapes != (128, 64, 256)
196+
or inpt_dtype != torch.bfloat16
197+
or quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK
189198
):
190199
pytest.skip("skipping unnecessary tests for inference mode")
191200
if x_rank == 3 and (
192-
shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
201+
shapes != (128, 64, 256)
202+
or inpt_dtype != torch.bfloat16
203+
or quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK
193204
):
194205
pytest.skip("skipping unnecessary tests for x_rank 3")
195206

@@ -200,7 +211,7 @@ def test_inference_workflow_nvfp4(
200211

201212
if quant_type == "dynamic":
202213
config = NVFP4DynamicActivationNVFP4WeightConfig(
203-
use_triton_kernel=use_triton_kernel,
214+
quantize_to_nvfp4_kernel_choice=quantize_to_nvfp4_kernel_choice,
204215
use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
205216
)
206217
else:
@@ -218,7 +229,10 @@ def test_inference_workflow_nvfp4(
218229

219230
y_ref = m(x)
220231

221-
if use_triton_kernel and quant_type == "dynamic":
232+
if (
233+
quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK
234+
and quant_type == "dynamic"
235+
):
222236
with cuda_kernel_profiler("triton_quantize_nvfp4_kernel") as result:
223237
y_mx = m_mx(x)
224238
assert result["found"], "Expected quantize_nvfp4 kernel to be found"
@@ -393,7 +407,7 @@ def test_nvfp4_static_vs_dynamic_quantization():
393407
quantize_(
394408
m_dynamic,
395409
NVFP4DynamicActivationNVFP4WeightConfig(
396-
use_triton_kernel=False,
410+
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH,
397411
use_dynamic_per_tensor_scale=True,
398412
),
399413
)
@@ -406,7 +420,7 @@ def test_nvfp4_static_vs_dynamic_quantization():
406420
m_static,
407421
NVFP4DynamicActivationNVFP4WeightConfig(
408422
step="prepare",
409-
use_triton_kernel=False,
423+
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH,
410424
),
411425
)
412426
# Calibrate with the same input used for testing
@@ -416,7 +430,7 @@ def test_nvfp4_static_vs_dynamic_quantization():
416430
m_static,
417431
NVFP4DynamicActivationNVFP4WeightConfig(
418432
step="convert",
419-
use_triton_kernel=False,
433+
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH,
420434
),
421435
)
422436

test/prototype/mx_formats/test_mx_serialization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
import torch.nn as nn
1414

15+
from torchao.prototype.mx_formats.config import QuantizeToNVFP4KernelChoice
1516
from torchao.prototype.mx_formats.inference_workflow import (
1617
MXDynamicActivationMXWeightConfig,
1718
NVFP4DynamicActivationNVFP4WeightConfig,
@@ -48,7 +49,7 @@ def test_serialization(recipe_name):
4849
else:
4950
assert recipe_name == "nvfp4", "unsupported"
5051
config = NVFP4DynamicActivationNVFP4WeightConfig(
51-
use_triton_kernel=False,
52+
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH,
5253
use_dynamic_per_tensor_scale=False,
5354
)
5455

test/prototype/mx_formats/test_nvfp4_tensor.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010
import torch.nn.functional as F
1111

12+
from torchao.prototype.mx_formats.config import QuantizeToNVFP4KernelChoice
1213
from torchao.prototype.mx_formats.constants import (
1314
F4_E2M1_MAX,
1415
)
@@ -383,14 +384,14 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
383384
x.clone(),
384385
per_tensor_scale=per_tensor_scale,
385386
is_swizzled_scales=True,
386-
use_triton_kernel=False,
387+
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH,
387388
)
388389

389390
nvfp4_triton = NVFP4Tensor.to_nvfp4(
390391
x.clone(),
391392
per_tensor_scale=per_tensor_scale,
392393
is_swizzled_scales=True,
393-
use_triton_kernel=True,
394+
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.MSLK,
394395
)
395396

396397
torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
@@ -427,7 +428,10 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
427428
@pytest.mark.parametrize("compile", [False])
428429
@pytest.mark.parametrize("bias", [True, False])
429430
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
430-
@pytest.mark.parametrize("use_triton_kernel", [True, False])
431+
@pytest.mark.parametrize(
432+
"quantize_to_nvfp4_kernel_choice",
433+
[QuantizeToNVFP4KernelChoice.MSLK, QuantizeToNVFP4KernelChoice.TORCH],
434+
)
431435
@pytest.mark.parametrize(
432436
"shapes",
433437
[
@@ -452,7 +456,7 @@ def test_nvfp4_matmul_with_amax(
452456
compile: bool,
453457
bias: bool,
454458
inpt_dtype: torch.dtype,
455-
use_triton_kernel: bool,
459+
quantize_to_nvfp4_kernel_choice: QuantizeToNVFP4KernelChoice,
456460
shapes: tuple,
457461
):
458462
# DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
@@ -489,13 +493,13 @@ def test_nvfp4_matmul_with_amax(
489493
A,
490494
per_tensor_scale=a_scale,
491495
is_swizzled_scales=True,
492-
use_triton_kernel=use_triton_kernel,
496+
quantize_to_nvfp4_kernel_choice=quantize_to_nvfp4_kernel_choice,
493497
)
494498
B_nvfp4 = NVFP4Tensor.to_nvfp4(
495499
B,
496500
per_tensor_scale=b_scale,
497501
is_swizzled_scales=True,
498-
use_triton_kernel=use_triton_kernel,
502+
quantize_to_nvfp4_kernel_choice=quantize_to_nvfp4_kernel_choice,
499503
act_quant_kwargs=act_quant_kwargs,
500504
)
501505

@@ -527,7 +531,7 @@ def test_nvfp4_to_copy():
527531
assert x.act_per_tensor_scale is None
528532
assert y.act_per_tensor_scale is None
529533
assert x.block_size == y.block_size
530-
assert x.use_triton_kernel == y.use_triton_kernel
534+
assert x.quantize_to_nvfp4_kernel_choice == y.quantize_to_nvfp4_kernel_choice
531535
assert x.act_quant_kwargs == y.act_quant_kwargs
532536
assert x.dtype == torch.float32
533537
assert y.dtype == torch.bfloat16
@@ -538,7 +542,10 @@ def test_nvfp4_to_copy():
538542
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
539543
)
540544
@pytest.mark.parametrize("transpose", [False, True])
541-
@pytest.mark.parametrize("use_triton_kernel", [False, True])
545+
@pytest.mark.parametrize(
546+
"quantize_to_nvfp4_kernel_choice",
547+
[QuantizeToNVFP4KernelChoice.TORCH, QuantizeToNVFP4KernelChoice.MSLK],
548+
)
542549
@pytest.mark.parametrize("is_swizzled_scales", [False, True])
543550
@pytest.mark.parametrize(
544551
"shape",
@@ -551,11 +558,17 @@ def test_nvfp4_to_copy():
551558
),
552559
)
553560
def test_scale_shape_matches_qdata(
554-
transpose, use_triton_kernel, is_swizzled_scales, shape
561+
transpose, quantize_to_nvfp4_kernel_choice, is_swizzled_scales, shape
555562
):
556-
if use_triton_kernel and not is_sm_at_least_100():
563+
if (
564+
quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK
565+
and not is_sm_at_least_100()
566+
):
557567
pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
558-
if use_triton_kernel and not is_swizzled_scales:
568+
if (
569+
quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice.MSLK
570+
and not is_swizzled_scales
571+
):
559572
pytest.skip("triton kernel requires swizzled scales")
560573

561574
block_size = 16
@@ -568,7 +581,7 @@ def test_scale_shape_matches_qdata(
568581
x_hp,
569582
per_tensor_scale=per_tensor_scale,
570583
is_swizzled_scales=is_swizzled_scales,
571-
use_triton_kernel=use_triton_kernel,
584+
quantize_to_nvfp4_kernel_choice=quantize_to_nvfp4_kernel_choice,
572585
)
573586

574587
if len(shape) == 2:

torchao/prototype/mx_formats/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from torchao.prototype.mx_formats.config import (
2+
QuantizeToNVFP4KernelChoice,
23
ScaleCalculationMode,
34
)
45

@@ -15,5 +16,6 @@
1516
"MXDynamicActivationMXWeightConfig",
1617
"NVFP4DynamicActivationNVFP4WeightConfig",
1718
"NVFP4ObservedLinear",
19+
"QuantizeToNVFP4KernelChoice",
1820
"NVFP4WeightOnlyConfig",
1921
]

torchao/prototype/mx_formats/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ class MXFP8Dim1CastKernelChoice(Enum):
3232
TORCH = "torch"
3333

3434

35+
class QuantizeToNVFP4KernelChoice(str, Enum):
36+
"""Enum for specifying the kernel used for quantizing a high precision
37+
tensor (float32/bfloat16/float16) to nvfp4 tensor with blockwise quantization
38+
"""
39+
40+
TORCH = "torch"
41+
"""Use torch native high precision to nvfp4 quantize kernel implemented with torch ops"""
42+
43+
MSLK = "mslk"
44+
"""Use MSLK triton high precision to nvfp4 quantize kernel"""
45+
46+
47+
torch.serialization.add_safe_globals([QuantizeToNVFP4KernelChoice])
48+
3549
# register as pytree constant so we can use dynamo nonstrict trace in torchao.prototype.moe_training.ep
3650
@register_as_pytree_constant
3751
class ScaleCalculationMode(Enum):

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@
1414
from torch import Tensor
1515

1616
from torchao.core.config import AOBaseConfig
17-
from torchao.prototype.mx_formats.config import _validate_elem_dtype
17+
from torchao.prototype.mx_formats.config import (
18+
QuantizeToNVFP4KernelChoice,
19+
_validate_elem_dtype,
20+
)
1821
from torchao.prototype.mx_formats.mx_tensor import (
1922
MXTensor,
2023
QuantizeTensorToMXKwargs,
@@ -23,6 +26,7 @@
2326
from torchao.prototype.mx_formats.nvfp4_tensor import (
2427
NVFP4Tensor,
2528
QuantizeTensorToNVFP4Kwargs,
29+
_handle_use_triton_kernel,
2630
per_tensor_amax_to_scale,
2731
)
2832
from torchao.quantization.quant_api import _module_extra_repr, _quantization_type
@@ -204,7 +208,7 @@ class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig):
204208
set to False.
205209
206210
Configuration parameters:
207-
- use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True).
211+
- quantize_to_nvfp4_kernel_choice: QuantizeToNVFP4KernelChoice, kernel preference for quantization (default: QuantizeToNVFP4KernelChoice.MSLK)
208212
Requires `MSLK <https://github.com/pytorch/MSLK>`__ to be installed.
209213
- use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True)
210214
- step: Optional[QuantizationStep], the quantization step for observer-based flow
@@ -221,11 +225,18 @@ class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig):
221225
:language: python
222226
"""
223227

224-
use_triton_kernel: bool = True
228+
quantize_to_nvfp4_kernel_choice: QuantizeToNVFP4KernelChoice = (
229+
QuantizeToNVFP4KernelChoice.MSLK
230+
)
225231
use_dynamic_per_tensor_scale: bool = True
226232
step: Optional["QuantizationStep"] = None
233+
use_triton_kernel: bool = True
227234

228235
def __post_init__(self):
236+
self.quantize_to_nvfp4_kernel_choice = _handle_use_triton_kernel(
237+
self.use_triton_kernel, self.quantize_to_nvfp4_kernel_choice
238+
)
239+
229240
if isinstance(self.step, str):
230241
self.step = QuantizationStep(self.step)
231242
# Validate PyTorch version
@@ -277,7 +288,7 @@ def _nvfp4_inference_linear_transform(
277288

278289
act_quant_kwargs = QuantizeTensorToNVFP4Kwargs(
279290
use_dynamic_per_tensor_scale=False,
280-
use_triton_kernel=config.use_triton_kernel,
291+
quantize_to_nvfp4_kernel_choice=config.quantize_to_nvfp4_kernel_choice,
281292
is_swizzled_scales=True,
282293
)
283294

@@ -286,10 +297,12 @@ def _nvfp4_inference_linear_transform(
286297
per_tensor_scale=weight_per_tensor_scale,
287298
act_per_tensor_scale=act_per_tensor_scale.detach(),
288299
is_swizzled_scales=True,
289-
use_triton_kernel=False, # Always use traditional construction for weights
300+
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH, # Always use traditional construction for weights
290301
act_quant_kwargs=act_quant_kwargs,
291302
)
292-
quantized_weight.use_triton_kernel = config.use_triton_kernel
303+
quantized_weight.quantize_to_nvfp4_kernel_choice = (
304+
config.quantize_to_nvfp4_kernel_choice
305+
)
293306

294307
# Create new Linear (not observed) with quantized weight
295308
linear = torch.nn.Linear(
@@ -319,18 +332,20 @@ def _nvfp4_inference_linear_transform(
319332

320333
act_quant_kwargs = QuantizeTensorToNVFP4Kwargs(
321334
use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale,
322-
use_triton_kernel=config.use_triton_kernel,
335+
quantize_to_nvfp4_kernel_choice=config.quantize_to_nvfp4_kernel_choice,
323336
is_swizzled_scales=True,
324337
)
325338

326339
quantized_weight = NVFP4Tensor.to_nvfp4(
327340
weight,
328341
per_tensor_scale=per_tensor_scale,
329342
is_swizzled_scales=True,
330-
use_triton_kernel=False, # Always use traditional construction for weights
343+
quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH, # Always use traditional construction for weights
331344
act_quant_kwargs=act_quant_kwargs,
332345
)
333-
quantized_weight.use_triton_kernel = config.use_triton_kernel
346+
quantized_weight.quantize_to_nvfp4_kernel_choice = (
347+
config.quantize_to_nvfp4_kernel_choice
348+
)
334349
setattr(
335350
module,
336351
parameter_name,

0 commit comments

Comments
 (0)