66
77import importlib
88import logging
9- from typing import Optional , Tuple
9+ from typing import Tuple
1010
1111import numpy as np
1212import torch
@@ -425,226 +425,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
425425
426426 return out
427427
@triton.jit
def convert_fp32_to_fp4_packed(x_pairs):
    """Convert FP32 pairs to packed FP4 format.

    This function takes tensors where consecutive values along the last
    dimension are packed together into single bytes.

    Args:
        x_pairs: [Tensor, Tensor] both w/ shapes [..., 1] where the zipped
            last dimension contains interleaved pairs of FP32 values to be
            packed together.

    Returns:
        Packed tensor with shape [...] (last dimension removed) where each
        element is an int8 containing 2 FP4 values:
        - First value of pair → low nibble (bits 0-3)
        - Second value of pair → high nibble (bits 4-7)

    Example:
        Input: [128, 32, 2] containing FP32 pairs
        Output: [128, 32] containing packed FP4 bytes

    """
    # Each `cvt.rn.satfinite.e2m1x2.f32` converts two fp32 operands into one
    # byte holding two fp4 (e2m1) values, round-to-nearest with saturation.
    # NOTE(review): with pack=4 and two input tensors, triton appears to bind
    # $0 as the packed 4-byte output and $1-$8 as the eight input registers
    # ($1-$4 from the first tensor, $5-$8 from the second) — verify against
    # the triton inline-asm docs if the pairing order is changed.
    x_fp4x2 = tl.inline_asm_elementwise(
        asm="""
        {
            .reg .b8 byte0, byte1, byte2, byte3;
            cvt.rn.satfinite.e2m1x2.f32 byte0, $5, $1;
            cvt.rn.satfinite.e2m1x2.f32 byte1, $6, $2;
            cvt.rn.satfinite.e2m1x2.f32 byte2, $7, $3;
            cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $4;
            mov.b32 $0, {byte0, byte1, byte2, byte3};
        }
        """,
        constraints=("=r,r,r,r,r,r,r,r,r"),  # one u32 output, eight u32 inputs
        args=x_pairs,
        dtype=tl.uint8,
        is_pure=True,
        pack=4,  # asm consumes 4 elements per input tensor per invocation
    )

    return x_fp4x2
470-
# Sauce: https://github.com/gau-nernst/quantized-training
@triton.jit
def quantize_nvfp4_triton_kernel(
    x_ptr,  # input tensor, strided by (stride_xm, stride_xn)
    tensor_scale_ptr,  # scalar per-tensor scale; read only if USE_TENSOR_SCALE
    q_ptr,  # out: packed FP4 data, row-major [M, N // 2] bytes
    s_ptr,  # out: FP8-E4M3 block scales, swizzled layout (see store below)
    stride_xm,
    stride_xn,
    M,
    N,
    USE_TENSOR_SCALE: tl.constexpr,
    MASK_SCALES: tl.constexpr,
):
    """Quantize one 128x64 tile of `x` to NVFP4.

    Each program computes one FP8-E4M3 scale per 16-element block (4 scales
    per row of the tile), stores the scales swizzled into s_ptr, and stores
    the data quantized to FP4 (2 values per byte) into q_ptr.
    MASK_SCALES enables bounds masking when (M, N) is not a multiple of
    the (128, 64) tile.
    """
    F4_E2M1_MAX = 6.0  # largest finite FP4 (e2m1) magnitude
    F8E4M3_MAX = 448.0  # largest finite FP8 (e4m3) magnitude
    E4M3_EPS = 1.5258789e-05  # lower clamp so scales never round to zero

    pid_m = tl.program_id(1)
    pid_n = tl.program_id(0)

    # Offsets of this program's 128-row x 64-col tile.
    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
    offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
    if MASK_SCALES:
        # Tile may hang past the edge of x; out-of-bounds lanes load 0.0.
        mask = (offs_m < M) & (offs_n < N)
        other = 0.0
    else:
        mask = None
        other = None
    x = tl.load(
        x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
    )  # [128, 64]
    x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]

    # Compute block-wise scales (amax over each 16-element block)
    block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]

    if USE_TENSOR_SCALE:
        # Two-level scaling: quantize block scales with per-tensor scale
        tensor_scale = tl.load(tensor_scale_ptr)

        # First compute block scales
        block_scale_f32 = (block_amax / F4_E2M1_MAX).to(tl.float32)

        # Quantize the block scales with per-tensor scale
        scaled_block_scales = block_scale_f32 / tensor_scale
        scaled_block_scales = tl.clamp(scaled_block_scales, E4M3_EPS, F8E4M3_MAX)
        scales = scaled_block_scales.to(tl.float8e4nv)

        # Apply combined scale to data: per_tensor_scale * quantized_block_scale
        total_scale = tensor_scale * scales.to(tl.float32)[:, :, None]
        x_blocks = tl.div_rn(x_blocks, total_scale)
    else:
        # Single-level scaling: use block scales directly
        scales_f32 = block_amax / F4_E2M1_MAX
        scales_f32 = tl.clamp(scales_f32, E4M3_EPS, F8E4M3_MAX)
        scales = scales_f32.to(tl.float8e4nv)

        # Apply block scale to data
        total_scale = scales.to(tl.float32)[:, :, None]
        x_blocks = tl.div_rn(x_blocks, total_scale)

    # NVIDIA layout for scales
    if MASK_SCALES:
        # Create offsets for the scale dimensions (4 blocks per row)
        scale_offs_n = pid_n * 4 + tl.arange(0, 4)[None, :]

        # Mask out scales to 0 if we are not aligned to 128 x 64
        scales = tl.where(
            (offs_m < M) & (scale_offs_n < N // 16),
            scales,
            0.0,
        )
    # Swizzle the [128, 4] scale tile into [32, 16]: rows are interleaved in
    # groups of 32 so consecutive stores hit the layout NVIDIA hardware expects.
    packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
    offs_m = tl.arange(0, 32)[:, None]
    offs_n = tl.arange(0, 16)[None, :]
    # Scale tiles are stored back-to-back: 32*16 entries per (pid_m, pid_n).
    tl.store(
        s_ptr
        + (pid_m * tl.num_programs(0) + pid_n) * (32 * 16)
        + offs_m * 16
        + offs_n,
        packed_scales,
    )

    # Convert to FP4; .split() yields the (even, odd) halves of each pair.
    x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
    offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]  # 32 bytes = 64 FP4 values
    if MASK_SCALES:
        mask = (offs_m < M) & (offs_n < N // 2)
    else:
        mask = None
    tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
564-
@torch.library.custom_op("ao::triton_quantize_nvfp4", mutates_args=())
def triton_quantize_nvfp4(
    x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor to NVFP4 format.

    Args:
        x (torch.Tensor): Input tensor to be quantized. The last dimension
            must be divisible by 16 (one scale per 16-element block).
        per_tensor_scale (Optional[torch.Tensor]): Per-tensor scale for
            two-level quantization. If None, uses single-level block-wise
            quantization only.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(scales, xq)`` — the FP8-E4M3
        block scales in swizzled layout, followed by the quantized data with
        two FP4 values packed per uint8 byte.

    Note:
        Since VLLM does not use dynamo guards we need to make this a custom op
        to avoid the triton kernel being invoked w/ the wrong use of
        ``MASK_SCALES``.
    """
    # Flatten any leading dims into the row dim so the kernel sees 2D.
    orig_leading_dims, orig_N = x.shape[:-2], x.shape[-1]
    x = x.reshape(-1, orig_N)

    M, N = x.shape
    assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"

    # One FP8 scale per 16-wide block; the swizzled scale layout is padded
    # to whole 128-row x 4-col tiles.
    num_scales = N // 16
    n_row_blocks = triton.cdiv(M, 128)
    n_col_blocks = triton.cdiv(num_scales, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    # Kernel must mask loads/stores (and zero-fill pad scales) whenever the
    # input is not an exact multiple of the 128 x 64 tile.
    MASK_SCALES = M % 128 != 0 or N % 64 != 0

    xq = x.new_empty(M, N // 2, dtype=torch.uint8)
    scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)

    grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))

    if per_tensor_scale is None:
        # Don't allocate a tensor; pass x as a dummy pointer since the kernel
        # never dereferences it when USE_TENSOR_SCALE is False.
        tensor_scale_ptr = x
        use_tensor_scale = False
    else:
        tensor_scale_ptr = per_tensor_scale
        use_tensor_scale = True

    quantize_nvfp4_triton_kernel[grid](
        x,
        tensor_scale_ptr,
        xq,
        scales,
        x.stride(0),
        x.stride(1),
        M,
        N,
        USE_TENSOR_SCALE=use_tensor_scale,
        MASK_SCALES=MASK_SCALES,
    )

    # Restore the original leading dims on both outputs.
    scales = scales.view(*orig_leading_dims, -1, padded_cols)
    xq = xq.view(*orig_leading_dims, -1, N // 2)

    return scales, xq.view(torch.uint8)
632-
@triton_quantize_nvfp4.register_fake
def _(x, per_tensor_scale=None):
    """Fake (meta) implementation of ``triton_quantize_nvfp4``.

    Mirrors the real op's output shapes for inputs with arbitrary leading
    dims, not just 2D: the real op flattens all leading dims into the row
    dim and reshapes both outputs back before returning.
    """
    orig_leading_dims, N = x.shape[:-2], x.shape[-1]
    # Rows after flattening leading dims — matches x.reshape(-1, N) above.
    M = x.numel() // N
    num_scales = N // 16
    n_row_blocks = triton.cdiv(M, 128)
    n_col_blocks = triton.cdiv(num_scales, 4)
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    scales = torch.empty(
        padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn
    )
    xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8)
    # Apply the same final reshapes as the real op.
    scales = scales.view(*orig_leading_dims, -1, padded_cols)
    xq = xq.view(*orig_leading_dims, -1, N // 2)
    return scales, xq
647-
648428 @triton_mx_block_rearrange .register_fake
649429 def _ (scale_tensor ):
650430 rows , cols = scale_tensor .shape
@@ -666,11 +446,6 @@ def triton_to_mxfp8_dim1_reference(
def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
    """Fallback stub used when torch 2.8+/triton are unavailable; always raises."""
    message = "needs torch version 2.8+ and triton"
    raise AssertionError(message)
668448
def triton_quantize_nvfp4(
    x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Fallback stub used when torch 2.8+/triton are unavailable; always raises.

    The keyword parameter is named ``per_tensor_scale`` to match the real
    triton-backed op's signature (it was ``tensor_scale`` here, so keyword
    callers hit a TypeError instead of the intended AssertionError).
    """
    raise AssertionError("needs torch version 2.8+ and triton")
673-
674449 def triton_mxfp8_dequant_dim0 (
675450 e4m3_data : torch .Tensor ,
676451 e8m0_scales : torch .Tensor ,
0 commit comments