2323from torchao .quantization .utils import compute_error
2424from torchao .testing .utils import skip_if_rocm
2525from torchao .utils import (
26+ _is_flashinfer_available ,
2627 is_sm_at_least_100 ,
2728 torch_version_at_least ,
2829)
@@ -368,8 +369,10 @@ def test_nvfp4_swizzled_scales_get_scales_method():
368369 not is_sm_at_least_100 (), reason = "requires sm100+ for raw intrinsics"
369370)
370371@torch .no_grad ()
def test_quantize_to_nvfp4_kernel_numerical_equivalence(
    M, N, use_per_tensor_scale, dtype
):
    """Test that different quantize_to_nvfp4 kernel choices produce numerically
    equivalent results.

    The TORCH kernel choice is the reference. Every other available kernel
    choice quantizes the same input and is compared against it:

    * MSLK shares the TORCH quantization algorithm, so its blockwise scales
      and packed qdata must match the reference bitwise.
    * All kernel choices (including FLASHINFER, when the library is
      installed) must dequantize to within ``SQNR_THRESHOLD`` dB of the
      reference.
    """
    if not use_per_tensor_scale:
        pytest.skip("MSLK triton kernel requires per_tensor_scale")

    torch.manual_seed(42)
    x = torch.randn(M, N, dtype=dtype, device="cuda")

    # NOTE(review): past the skip above use_per_tensor_scale is always True,
    # so this guard is redundant; kept for symmetry with sibling tests.
    per_tensor_scale = None
    if use_per_tensor_scale:
        per_tensor_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(x)))

    # Reference: TORCH kernel choice
    nvfp4_ref = NVFP4Tensor.to_nvfp4(
        x.clone(),
        per_tensor_scale=per_tensor_scale,
        is_swizzled_scales=True,
        quantize_to_nvfp4_kernel_choice=QuantizeToNVFP4KernelChoice.TORCH,
    )
    ref_dequant = nvfp4_ref.dequantize(dtype)

    other_kernel_choices = [QuantizeToNVFP4KernelChoice.MSLK]

    # Flashinfer requires the library and per_tensor_scale
    if _is_flashinfer_available() and use_per_tensor_scale:
        other_kernel_choices.append(QuantizeToNVFP4KernelChoice.FLASHINFER)

    SQNR_THRESHOLD = 28.0
    for kc in other_kernel_choices:
        nvfp4_other = NVFP4Tensor.to_nvfp4(
            x.clone(),
            per_tensor_scale=per_tensor_scale,
            is_swizzled_scales=True,
            quantize_to_nvfp4_kernel_choice=kc,
        )

        # For kernel choices that use the same quantization algorithm as TORCH
        # (MSLK should be bitwise identical), verify internal data matches
        # exactly. This replaces the old module-level nvfp4_pt/nvfp4_triton
        # comparison, whose variables no longer exist.
        if kc == QuantizeToNVFP4KernelChoice.MSLK:
            torch.testing.assert_close(
                nvfp4_ref.scale.flatten(),
                nvfp4_other.scale.flatten(),
                atol=0,
                rtol=0,
            )
            # View as uint8 before unpacking: qdata may be stored as a packed
            # fp4x2 dtype rather than uint8 (no-op when already uint8) —
            # matches the pre-refactor comparison path.
            ref_unpacked = unpack_uint4(nvfp4_ref.qdata.view(torch.uint8))
            other_unpacked = unpack_uint4(nvfp4_other.qdata.view(torch.uint8))
            torch.testing.assert_close(
                ref_unpacked,
                other_unpacked,
                atol=0,
                rtol=0,
            )

        # Verify dequantized values are numerically close for all kernel choices
        other_dequant = nvfp4_other.dequantize(dtype)
        sqnr = compute_error(ref_dequant, other_dequant)
        assert sqnr >= SQNR_THRESHOLD, (
            f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD} between TORCH and {kc}, "
            f"M={M}, N={N}, use_per_tensor_scale={use_per_tensor_scale}, dtype={dtype}"
        )
417445
418446
419447@pytest .mark .skipif (not torch .cuda .is_available (), reason = "CUDA not available" )
@@ -430,7 +458,11 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
430458@pytest .mark .parametrize ("inpt_dtype" , [torch .bfloat16 , torch .float32 ])
431459@pytest .mark .parametrize (
432460 "quantize_to_nvfp4_kernel_choice" ,
433- [QuantizeToNVFP4KernelChoice .MSLK , QuantizeToNVFP4KernelChoice .TORCH ],
461+ [
462+ QuantizeToNVFP4KernelChoice .MSLK ,
463+ QuantizeToNVFP4KernelChoice .FLASHINFER ,
464+ QuantizeToNVFP4KernelChoice .TORCH ,
465+ ],
434466)
435467@pytest .mark .parametrize (
436468 "shapes" ,
@@ -469,6 +501,10 @@ def test_nvfp4_matmul_with_amax(
469501 if quant_type == "weight_only" and compile :
470502 pytest .skip ("TODO: weight_only currently errors w/ compile" )
471503
504+ if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice .FLASHINFER :
505+ if not _is_flashinfer_available ():
506+ pytest .skip ("flashinfer not available" )
507+
472508 m , k , n = shapes
473509
474510 # Create activation tensor
@@ -544,7 +580,11 @@ def test_nvfp4_to_copy():
544580@pytest .mark .parametrize ("transpose" , [False , True ])
545581@pytest .mark .parametrize (
546582 "quantize_to_nvfp4_kernel_choice" ,
547- [QuantizeToNVFP4KernelChoice .TORCH , QuantizeToNVFP4KernelChoice .MSLK ],
583+ [
584+ QuantizeToNVFP4KernelChoice .TORCH ,
585+ QuantizeToNVFP4KernelChoice .MSLK ,
586+ QuantizeToNVFP4KernelChoice .FLASHINFER ,
587+ ],
548588)
549589@pytest .mark .parametrize ("is_swizzled_scales" , [False , True ])
550590@pytest .mark .parametrize (
@@ -570,11 +610,22 @@ def test_scale_shape_matches_qdata(
570610 and not is_swizzled_scales
571611 ):
572612 pytest .skip ("triton kernel requires swizzled scales" )
613+ if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice .FLASHINFER :
614+ if not _is_flashinfer_available ():
615+ pytest .skip ("flashinfer not available" )
616+ if not is_swizzled_scales :
617+ pytest .skip ("flashinfer requires swizzled scales" )
618+ if shape [- 1 ] % 64 != 0 :
619+ pytest .skip ("flashinfer requires K to be divisible by 64" )
573620
574621 block_size = 16
575622
576623 x_hp = torch .randn (* shape , device = "cuda" )
577624
625+ if quantize_to_nvfp4_kernel_choice == QuantizeToNVFP4KernelChoice .FLASHINFER :
626+ # flashinfer only supports fp16/bf16/e4m3 input
627+ x_hp = x_hp .to (torch .bfloat16 )
628+
578629 per_tensor_scale = per_tensor_amax_to_scale (torch .amax (torch .abs (x_hp )))
579630
580631 x = NVFP4Tensor .to_nvfp4 (
0 commit comments