2323from torchao .quantization .utils import compute_error
2424from torchao .testing .utils import skip_if_rocm
2525from torchao .utils import (
26+ _is_flashinfer_available ,
2627 is_sm_at_least_100 ,
2728 torch_version_at_least ,
2829)
@@ -368,8 +369,10 @@ def test_nvfp4_swizzled_scales_get_scales_method():
368369 not is_sm_at_least_100 (), reason = "requires sm100+ for raw intrinsics"
369370)
370371@torch .no_grad ()
371- def test_triton_nvfp4_quantize_equivalence (M , N , use_per_tensor_scale , dtype ):
372- """Test that Triton and PyTorch NVFP4 quantization produce equivalent results."""
372+ def test_quantize_to_nvfp4_kernel_numerical_equivalence (
373+ M , N , use_per_tensor_scale , dtype
374+ ):
375+ """Test that different quantize_to_nvfp4 kernel choices produce numerically equivalent results."""
373376 if not use_per_tensor_scale :
374377 pytest .skip ("MSLK triton kernel requires per_tensor_scale" )
375378
@@ -380,40 +383,55 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
380383 if use_per_tensor_scale :
381384 per_tensor_scale = per_tensor_amax_to_scale (torch .amax (torch .abs (x )))
382385
383- nvfp4_pt = NVFP4Tensor .to_nvfp4 (
386+ # Reference: TORCH kernel choice
387+ nvfp4_ref = NVFP4Tensor .to_nvfp4 (
384388 x .clone (),
385389 per_tensor_scale = per_tensor_scale ,
386390 is_swizzled_scales = True ,
387391 quantize_to_nvfp4_kernel_choice = QuantizeToNVFP4KernelChoice .TORCH ,
388392 )
389-
390- nvfp4_triton = NVFP4Tensor .to_nvfp4 (
391- x .clone (),
392- per_tensor_scale = per_tensor_scale ,
393- is_swizzled_scales = True ,
394- quantize_to_nvfp4_kernel_choice = QuantizeToNVFP4KernelChoice .MSLK ,
395- )
396-
397- torch .testing .assert_close (nvfp4_pt .scale .flatten (), nvfp4_triton .scale .flatten ())
398- pt_unpacked = unpack_uint4 (nvfp4_pt .qdata .view (torch .uint8 ))
399- triton_unpacked = unpack_uint4 (nvfp4_triton .qdata .view (torch .uint8 ))
400- torch .testing .assert_close (
401- pt_unpacked ,
402- triton_unpacked ,
403- atol = 0 ,
404- rtol = 0 ,
405- )
406-
407- x_pt_dequant = nvfp4_pt .dequantize (dtype )
408- x_triton_dequant = nvfp4_triton .dequantize (dtype )
409-
410- sqnr = compute_error (x_pt_dequant , x_triton_dequant )
411- SQNR_THRESHOLD = 40.0
412-
413- assert sqnr >= SQNR_THRESHOLD , (
414- f"SQNR { sqnr :.2f} < { SQNR_THRESHOLD } for M={ M } , N={ N } , "
415- f"use_per_tensor_scale={ use_per_tensor_scale } , dtype={ dtype } "
416- )
393+ ref_dequant = nvfp4_ref .dequantize (dtype )
394+
395+ other_kernel_choices = [QuantizeToNVFP4KernelChoice .MSLK ]
396+
397+    # The FLASHINFER kernel choice requires the flashinfer library to be installed and a per_tensor_scale
398+ if _is_flashinfer_available () and use_per_tensor_scale :
399+ other_kernel_choices .append (QuantizeToNVFP4KernelChoice .FLASHINFER )
400+
401+ SQNR_THRESHOLD = 28.0
402+ for kc in other_kernel_choices :
403+ nvfp4_other = NVFP4Tensor .to_nvfp4 (
404+ x .clone (),
405+ per_tensor_scale = per_tensor_scale ,
406+ is_swizzled_scales = True ,
407+ quantize_to_nvfp4_kernel_choice = kc ,
408+ )
409+
410+        # Kernel choices that reuse TORCH's quantization algorithm (currently
411+        # MSLK) must produce bitwise-identical scales and packed qdata, so compare exactly
412+ if kc == QuantizeToNVFP4KernelChoice .MSLK :
413+ torch .testing .assert_close (
414+ nvfp4_ref .scale .flatten (),
415+ nvfp4_other .scale .flatten (),
416+ atol = 0 ,
417+ rtol = 0 ,
418+ )
419+ ref_unpacked = unpack_uint4 (nvfp4_ref .qdata )
420+ other_unpacked = unpack_uint4 (nvfp4_other .qdata )
421+ torch .testing .assert_close (
422+ ref_unpacked ,
423+ other_unpacked ,
424+ atol = 0 ,
425+ rtol = 0 ,
426+ )
427+
428+ # Verify dequantized values are numerically close for all kernel choices
429+ other_dequant = nvfp4_other .dequantize (dtype )
430+ sqnr = compute_error (ref_dequant , other_dequant )
431+ assert sqnr >= SQNR_THRESHOLD , (
432+ f"SQNR { sqnr :.2f} < { SQNR_THRESHOLD } between TORCH and { kc } , "
433+ f"M={ M } , N={ N } , use_per_tensor_scale={ use_per_tensor_scale } , dtype={ dtype } "
434+ )
417435
418436
419437@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
@@ -430,7 +448,11 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
430448@pytest .mark .parametrize ("inpt_dtype" , [torch .bfloat16 , torch .float32 ])
431449@pytest .mark .parametrize (
432450 "quantize_to_nvfp4_kernel_choice" ,
433- [QuantizeToNVFP4KernelChoice .MSLK , QuantizeToNVFP4KernelChoice .TORCH ],
451+ [
452+ QuantizeToNVFP4KernelChoice .MSLK ,
453+ QuantizeToNVFP4KernelChoice .FLASHINFER ,
454+ QuantizeToNVFP4KernelChoice .TORCH ,
455+ ],
434456)
435457@pytest .mark .parametrize (
436458 "shapes" ,
@@ -469,6 +491,10 @@ def test_nvfp4_matmul_with_amax(
469491 if quant_type == "weight_only" and compile :
470492 pytest .skip ("TODO: weight_only currently errors w/ compile" )
471493
494+ if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice .FLASHINFER :
495+ if not _is_flashinfer_available ():
496+ pytest .skip ("flashinfer not available" )
497+
472498 m , k , n = shapes
473499
474500 # Create activation tensor
@@ -544,7 +570,11 @@ def test_nvfp4_to_copy():
544570@pytest .mark .parametrize ("transpose" , [False , True ])
545571@pytest .mark .parametrize (
546572 "quantize_to_nvfp4_kernel_choice" ,
547- [QuantizeToNVFP4KernelChoice .TORCH , QuantizeToNVFP4KernelChoice .MSLK ],
573+ [
574+ QuantizeToNVFP4KernelChoice .TORCH ,
575+ QuantizeToNVFP4KernelChoice .MSLK ,
576+ QuantizeToNVFP4KernelChoice .FLASHINFER ,
577+ ],
548578)
549579@pytest .mark .parametrize ("is_swizzled_scales" , [False , True ])
550580@pytest .mark .parametrize (
@@ -570,11 +600,22 @@ def test_scale_shape_matches_qdata(
570600 and not is_swizzled_scales
571601 ):
572602 pytest .skip ("triton kernel requires swizzled scales" )
603+ if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice .FLASHINFER :
604+ if not _is_flashinfer_available ():
605+ pytest .skip ("flashinfer not available" )
606+ if not is_swizzled_scales :
607+ pytest .skip ("flashinfer requires swizzled scales" )
608+ if shape [- 1 ] % 64 != 0 :
609+ pytest .skip ("flashinfer requires K to be divisible by 64" )
573610
574611 block_size = 16
575612
576613 x_hp = torch .randn (* shape , device = "cuda" )
577614
615+ if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice .FLASHINFER :
616+        # flashinfer supports only fp16/bf16/e4m3 inputs, so cast the fp32 input to bf16
617+ x_hp = x_hp .to (torch .bfloat16 )
618+
578619 per_tensor_scale = per_tensor_amax_to_scale (torch .amax (torch .abs (x_hp )))
579620
580621 x = NVFP4Tensor .to_nvfp4 (
0 commit comments