diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 594f1151ad..74cf8397da 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -6,7 +6,7 @@
 import importlib
 import logging
 
-from typing import Optional, Tuple
+from typing import Tuple
 
 import numpy as np
 import torch
@@ -425,226 +425,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
         return out
 
-    @triton.jit
-    def convert_fp32_to_fp4_packed(x_pairs):
-        """Convert FP32 pairs to packed FP4 format.
-
-        This function takes a pair of tensors whose interleaved values along the
-        last dimension are packed together into single bytes.
-
-        Args:
-            x_pairs: [Tensor, Tensor], both with shape [..., 1], where the zipped
-                last dimension contains interleaved pairs of FP32 values to be
-                packed together.
-
-        Returns:
-            Packed tensor with shape [...] (last dimension removed) where each
-            element is an int8 containing 2 FP4 values:
-            - First value of pair → low nibble (bits 0-3)
-            - Second value of pair → high nibble (bits 4-7)
-
-        Example:
-            Input: [128, 32, 2] containing FP32 pairs
-            Output: [128, 32] containing packed FP4 bytes
-        """
-        x_fp4x2 = tl.inline_asm_elementwise(
-            asm="""
-            {
-                .reg .b8 byte0, byte1, byte2, byte3;
-                cvt.rn.satfinite.e2m1x2.f32 byte0, $5, $1;
-                cvt.rn.satfinite.e2m1x2.f32 byte1, $6, $2;
-                cvt.rn.satfinite.e2m1x2.f32 byte2, $7, $3;
-                cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $4;
-                mov.b32 $0, {byte0, byte1, byte2, byte3};
-            }
-            """,
-            constraints=("=r,r,r,r,r,r,r,r,r"),
-            args=x_pairs,
-            dtype=tl.uint8,
-            is_pure=True,
-            pack=4,
-        )
-
-        return x_fp4x2
-
-    # Sauce: https://github.com/gau-nernst/quantized-training
-    @triton.jit
-    def quantize_nvfp4_triton_kernel(
-        x_ptr,
-        tensor_scale_ptr,
-        q_ptr,
-        s_ptr,
-        stride_xm,
-        stride_xn,
-        M,
-        N,
-        USE_TENSOR_SCALE: tl.constexpr,
-        MASK_SCALES: tl.constexpr,
-    ):
-        F4_E2M1_MAX = 6.0
-        F8E4M3_MAX = 448.0
-        E4M3_EPS = 1.5258789e-05
-
-        pid_m = tl.program_id(1)
-        pid_n = tl.program_id(0)
-
-        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-        offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
-        if MASK_SCALES:
-            mask = (offs_m < M) & (offs_n < N)
-            other = 0.0
-        else:
-            mask = None
-            other = None
-        x = tl.load(
-            x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
-        )  # [128, 64]
-        x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
-
-        # Compute block-wise scales
-        block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
-
-        if USE_TENSOR_SCALE:
-            # Two-level scaling: quantize block scales with per-tensor scale
-            tensor_scale = tl.load(tensor_scale_ptr)
-
-            # First compute block scales
-            block_scale_f32 = (block_amax / F4_E2M1_MAX).to(tl.float32)
-
-            # Quantize the block scales with per-tensor scale
-            scaled_block_scales = block_scale_f32 / tensor_scale
-            scaled_block_scales = tl.clamp(scaled_block_scales, E4M3_EPS, F8E4M3_MAX)
-            scales = scaled_block_scales.to(tl.float8e4nv)
-
-            # Apply combined scale to data: per_tensor_scale * quantized_block_scale
-            total_scale = tensor_scale * scales.to(tl.float32)[:, :, None]
-            x_blocks = tl.div_rn(x_blocks, total_scale)
-        else:
-            # Single-level scaling: use block scales directly
-            scales_f32 = block_amax / F4_E2M1_MAX
-            scales_f32 = tl.clamp(scales_f32, E4M3_EPS, F8E4M3_MAX)
-            scales = scales_f32.to(tl.float8e4nv)
-
-            # Apply block scale to data
-            total_scale = scales.to(tl.float32)[:, :, None]
-            x_blocks = tl.div_rn(x_blocks, total_scale)
-
-        # NVIDIA layout for scales
-        if MASK_SCALES:
-            # Create offsets for the scale dimensions (4 blocks per row)
-            scale_offs_n = pid_n * 4 + tl.arange(0, 4)[None, :]
-
-            # Mask out scales to 0 if we are not aligned to 128 x 64
-            scales = tl.where(
-                (offs_m < M) & (scale_offs_n < N // 16),
-                scales,
-                0.0,
-            )
-        packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
-        offs_m = tl.arange(0, 32)[:, None]
-        offs_n = tl.arange(0, 16)[None, :]
-        tl.store(
-            s_ptr
-            + (pid_m * tl.num_programs(0) + pid_n) * (32 * 16)
-            + offs_m * 16
-            + offs_n,
-            packed_scales,
-        )
-
-        # Convert to FP4
-        x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-        offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
-        if MASK_SCALES:
-            mask = (offs_m < M) & (offs_n < N // 2)
-        else:
-            mask = None
-        tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
-
-    @torch.library.custom_op("ao::triton_quantize_nvfp4", mutates_args=())
-    def triton_quantize_nvfp4(
-        x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Quantize a tensor to NVFP4 format.
-
-        Args:
-            x (torch.Tensor): Input tensor to be quantized.
-            per_tensor_scale (Optional[torch.Tensor]): Per-tensor scale for two-level
-                quantization. If None, uses single-level block-wise quantization only.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Scales tensor in swizzled layout and
-            the packed quantized tensor.
-
-        Note:
-            Since vLLM does not use dynamo guards, we need to make this a custom op
-            to avoid the Triton kernel being invoked with the wrong value of
-            `MASK_SCALES`.
-        """
-        # reshape to 2d
-        orig_leading_dims, _orig_M, orig_N = x.shape[:-2], x.shape[-2], x.shape[-1]
-        x = x.reshape(-1, orig_N)
-
-        M, N = x.shape
-        # assert M % 128 == 0 and N % 64 == 0
-        assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"
-
-        # Calculate blocks needed
-        num_scales = N // 16
-        n_row_blocks = triton.cdiv(M, 128)
-        n_col_blocks = triton.cdiv(num_scales, 4)
-        padded_rows = n_row_blocks * 128
-        padded_cols = n_col_blocks * 4
-
-        # Mask out scales to 0 if we are not aligned to 128 x 64
-        MASK_SCALES = M % 128 != 0 or N % 64 != 0
-
-        xq = x.new_empty(M, N // 2, dtype=torch.uint8)
-        scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
-
-        grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
-
-        if per_tensor_scale is None:
-            # Don't allocate a new tensor; pass x as a placeholder pointer
-            # since it won't be read in the kernel
-            tensor_scale_ptr = x
-            use_tensor_scale = False
-        else:
-            tensor_scale_ptr = per_tensor_scale
-            use_tensor_scale = True
-
-        quantize_nvfp4_triton_kernel[grid](
-            x,
-            tensor_scale_ptr,
-            xq,
-            scales,
-            x.stride(0),
-            x.stride(1),
-            M,
-            N,
-            USE_TENSOR_SCALE=use_tensor_scale,
-            MASK_SCALES=MASK_SCALES,
-        )
-
-        # reshape back to original shape
-        scales = scales.view(*orig_leading_dims, -1, padded_cols)
-        xq = xq.view(*orig_leading_dims, -1, N // 2)
-
-        return scales, xq.view(torch.uint8)
-
-    @triton_quantize_nvfp4.register_fake
-    def _(x, per_tensor_scale=None):
-        M, N = x.shape
-        num_scales = N // 16
-        n_row_blocks = triton.cdiv(M, 128)
-        n_col_blocks = triton.cdiv(num_scales, 4)
-        padded_rows = n_row_blocks * 128
-        padded_cols = n_col_blocks * 4
-
-        scales = torch.empty(
-            padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn
-        )
-        xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8)
-        return scales, xq
-
     @triton_mx_block_rearrange.register_fake
     def _(scale_tensor):
         rows, cols = scale_tensor.shape
@@ -666,11 +446,6 @@ def triton_to_mxfp8_dim1_reference(
     def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
         raise AssertionError("needs torch version 2.8+ and triton")
 
-    def triton_quantize_nvfp4(
-        x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise AssertionError("needs torch version 2.8+ and triton")
-
     def triton_mxfp8_dequant_dim0(
         e4m3_data: torch.Tensor,
         e8m0_scales: torch.Tensor,
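
Since the two-level scaling logic is leaving this file, a plain-PyTorch restatement of the math in the removed `quantize_nvfp4_triton_kernel` may help reviewers compare behavior against wherever it lands. This is a minimal sketch: the helper name `nvfp4_scale_reference` is ours (not a torchao API), and the final e2m1 rounding, nibble packing, and swizzled scale layout are deliberately omitted.

```python
import torch
from typing import Optional, Tuple

F4_E2M1_MAX = 6.0
F8E4M3_MAX = 448.0
E4M3_EPS = 1.5258789e-05

def nvfp4_scale_reference(
    x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference for the scaling math only: one e4m3 scale per 16-value block."""
    M, N = x.shape
    blocks = x.float().reshape(M, N // 16, 16)
    block_amax = blocks.abs().amax(dim=-1)  # [M, N // 16]
    block_scale = block_amax / F4_E2M1_MAX
    if per_tensor_scale is not None:
        # Two-level scaling: quantize the block scales with the per-tensor scale
        block_scale = block_scale / per_tensor_scale
    scales = block_scale.clamp(E4M3_EPS, F8E4M3_MAX).to(torch.float8_e4m3fn)
    total_scale = scales.float()
    if per_tensor_scale is not None:
        # Combined scale applied to the data: per_tensor_scale * quantized block scale
        total_scale = total_scale * per_tensor_scale
    # The returned values are what the kernel feeds into e2m1 rounding (not shown)
    return scales, blocks / total_scale.unsqueeze(-1)
```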
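For the packing side, the nibble order documented in `convert_fp32_to_fp4_packed` (first value of each pair in the low nibble, second in the high nibble) can be checked on the packed `uint8` output directly; the byte value below is made up for illustration:

```python
import torch

# One packed byte per pair of FP4 values, as produced by the removed kernel.
packed = torch.tensor([0x21], dtype=torch.uint8)  # hypothetical packed output

first = packed & 0x0F   # first value of the pair lives in bits 0-3
second = packed >> 4    # second value of the pair lives in bits 4-7
assert first.item() == 0x1 and second.item() == 0x2
```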