99import torch
1010import torch .nn .functional as F
1111
12+ from torchao .prototype .mx_formats .config import QuantizeToNVFP4KernelChoice
1213from torchao .prototype .mx_formats .constants import (
1314 F4_E2M1_MAX ,
1415)
@@ -383,14 +384,14 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
383384 x .clone (),
384385 per_tensor_scale = per_tensor_scale ,
385386 is_swizzled_scales = True ,
386- use_triton_kernel = False ,
387+ quantize_to_nvfp4_kernel_choice = QuantizeToNVFP4KernelChoice . TORCH ,
387388 )
388389
389390 nvfp4_triton = NVFP4Tensor .to_nvfp4 (
390391 x .clone (),
391392 per_tensor_scale = per_tensor_scale ,
392393 is_swizzled_scales = True ,
393- use_triton_kernel = True ,
394+ quantize_to_nvfp4_kernel_choice = QuantizeToNVFP4KernelChoice . MSLK ,
394395 )
395396
396397 torch .testing .assert_close (nvfp4_pt .scale .flatten (), nvfp4_triton .scale .flatten ())
@@ -427,7 +428,10 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
427428@pytest .mark .parametrize ("compile" , [False ])
428429@pytest .mark .parametrize ("bias" , [True , False ])
429430@pytest .mark .parametrize ("inpt_dtype" , [torch .bfloat16 , torch .float32 ])
430- @pytest .mark .parametrize ("use_triton_kernel" , [True , False ])
431+ @pytest .mark .parametrize (
432+ "quantize_to_nvfp4_kernel_choice" ,
433+ [QuantizeToNVFP4KernelChoice .MSLK , QuantizeToNVFP4KernelChoice .TORCH ],
434+ )
431435@pytest .mark .parametrize (
432436 "shapes" ,
433437 [
@@ -452,7 +456,7 @@ def test_nvfp4_matmul_with_amax(
452456 compile : bool ,
453457 bias : bool ,
454458 inpt_dtype : torch .dtype ,
455- use_triton_kernel : bool ,
459+ quantize_to_nvfp4_kernel_choice : QuantizeToNVFP4KernelChoice ,
456460 shapes : tuple ,
457461):
458462 # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
@@ -489,13 +493,13 @@ def test_nvfp4_matmul_with_amax(
489493 A ,
490494 per_tensor_scale = a_scale ,
491495 is_swizzled_scales = True ,
492- use_triton_kernel = use_triton_kernel ,
496+ quantize_to_nvfp4_kernel_choice = quantize_to_nvfp4_kernel_choice ,
493497 )
494498 B_nvfp4 = NVFP4Tensor .to_nvfp4 (
495499 B ,
496500 per_tensor_scale = b_scale ,
497501 is_swizzled_scales = True ,
498- use_triton_kernel = use_triton_kernel ,
502+ quantize_to_nvfp4_kernel_choice = quantize_to_nvfp4_kernel_choice ,
499503 act_quant_kwargs = act_quant_kwargs ,
500504 )
501505
@@ -527,7 +531,7 @@ def test_nvfp4_to_copy():
527531 assert x .act_per_tensor_scale is None
528532 assert y .act_per_tensor_scale is None
529533 assert x .block_size == y .block_size
530- assert x .use_triton_kernel == y .use_triton_kernel
534+ assert x .quantize_to_nvfp4_kernel_choice == y .quantize_to_nvfp4_kernel_choice
531535 assert x .act_quant_kwargs == y .act_quant_kwargs
532536 assert x .dtype == torch .float32
533537 assert y .dtype == torch .bfloat16
@@ -538,7 +542,10 @@ def test_nvfp4_to_copy():
538542 not torch_version_at_least ("2.8.0" ), reason = "NVFP4 requires PyTorch 2.8+"
539543)
540544@pytest .mark .parametrize ("transpose" , [False , True ])
541- @pytest .mark .parametrize ("use_triton_kernel" , [False , True ])
545+ @pytest .mark .parametrize (
546+ "quantize_to_nvfp4_kernel_choice" ,
547+ [QuantizeToNVFP4KernelChoice .TORCH , QuantizeToNVFP4KernelChoice .MSLK ],
548+ )
542549@pytest .mark .parametrize ("is_swizzled_scales" , [False , True ])
543550@pytest .mark .parametrize (
544551 "shape" ,
@@ -551,11 +558,17 @@ def test_nvfp4_to_copy():
551558 ),
552559)
553560def test_scale_shape_matches_qdata (
554- transpose , use_triton_kernel , is_swizzled_scales , shape
561+ transpose , quantize_to_nvfp4_kernel_choice , is_swizzled_scales , shape
555562):
556- if use_triton_kernel and not is_sm_at_least_100 ():
563+ if (
564+ quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice .MSLK
565+ and not is_sm_at_least_100 ()
566+ ):
557567 pytest .skip ("CUDA capability >= 10.0 required for nvfp4 triton kernel" )
558- if use_triton_kernel and not is_swizzled_scales :
568+ if (
569+ quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice .MSLK
570+ and not is_swizzled_scales
571+ ):
559572 pytest .skip ("triton kernel requires swizzled scales" )
560573
561574 block_size = 16
@@ -568,7 +581,7 @@ def test_scale_shape_matches_qdata(
568581 x_hp ,
569582 per_tensor_scale = per_tensor_scale ,
570583 is_swizzled_scales = is_swizzled_scales ,
571- use_triton_kernel = use_triton_kernel ,
584+ quantize_to_nvfp4_kernel_choice = quantize_to_nvfp4_kernel_choice ,
572585 )
573586
574587 if len (shape ) == 2 :
0 commit comments