227 changes: 1 addition & 226 deletions torchao/prototype/mx_formats/kernels.py
@@ -6,7 +6,7 @@

import importlib
import logging
from typing import Optional, Tuple
from typing import Tuple

import numpy as np
import torch
@@ -425,226 +425,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:

return out

@triton.jit
def convert_fp32_to_fp4_packed(x_pairs):
"""Convert FP32 pairs to packed FP4 format.

    This function takes a pair of tensors and packs corresponding FP32 values
    from the two inputs together into single bytes.

    Args:
        x_pairs: [Tensor, Tensor], both with shape [..., 1], whose zipped last
            dimension contains the interleaved pairs of FP32 values to pack.

Returns:
        Packed tensor with shape [...] (last dimension removed) where each
        element is a uint8 containing 2 FP4 values:
- First value of pair → low nibble (bits 0-3)
- Second value of pair → high nibble (bits 4-7)

Example:
Input: [128, 32, 2] containing FP32 pairs
Output: [128, 32] containing packed FP4 bytes

"""

x_fp4x2 = tl.inline_asm_elementwise(
asm="""
{
.reg .b8 byte0, byte1, byte2, byte3;
cvt.rn.satfinite.e2m1x2.f32 byte0, $5, $1;
cvt.rn.satfinite.e2m1x2.f32 byte1, $6, $2;
cvt.rn.satfinite.e2m1x2.f32 byte2, $7, $3;
cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $4;
mov.b32 $0, {byte0, byte1, byte2, byte3};
}
""",
constraints=("=r,r,r,r,r,r,r,r,r"),
args=x_pairs,
dtype=tl.uint8,
is_pure=True,
pack=4,
)

return x_fp4x2

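# Illustrative sketch, not part of the original file: the inline PTX above packs
# each pair of FP32 values into one byte via cvt.rn.satfinite.e2m1x2.f32. Given
# uint8 tensors that already hold 4-bit FP4 codes, the byte layout it produces
# can be reproduced in plain PyTorch; the helper name here is an assumption.
def _pack_fp4_codes_reference(lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor:
    """Pack 4-bit codes: first of pair -> low nibble, second -> high nibble."""
    return ((hi & 0xF) << 4) | (lo & 0xF)
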
# Source: https://github.com/gau-nernst/quantized-training
@triton.jit
def quantize_nvfp4_triton_kernel(
x_ptr,
tensor_scale_ptr,
q_ptr,
s_ptr,
stride_xm,
stride_xn,
M,
N,
USE_TENSOR_SCALE: tl.constexpr,
MASK_SCALES: tl.constexpr,
):
F4_E2M1_MAX = 6.0
F8E4M3_MAX = 448.0
E4M3_EPS = 1.5258789e-05

pid_m = tl.program_id(1)
pid_n = tl.program_id(0)

offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
if MASK_SCALES:
mask = (offs_m < M) & (offs_n < N)
other = 0.0
else:
mask = None
other = None
x = tl.load(
x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
) # [128, 64]
x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16]

# Compute block-wise scales
block_amax = tl.max(x_blocks.abs(), axis=2) # [128, 4]

if USE_TENSOR_SCALE:
# Two-level scaling: quantize block scales with per-tensor scale
tensor_scale = tl.load(tensor_scale_ptr)

# First compute block scales
block_scale_f32 = (block_amax / F4_E2M1_MAX).to(tl.float32)

# Quantize the block scales with per-tensor scale
scaled_block_scales = block_scale_f32 / tensor_scale
scaled_block_scales = tl.clamp(scaled_block_scales, E4M3_EPS, F8E4M3_MAX)
scales = scaled_block_scales.to(tl.float8e4nv)

# Apply combined scale to data: per_tensor_scale * quantized_block_scale
total_scale = tensor_scale * scales.to(tl.float32)[:, :, None]
x_blocks = tl.div_rn(x_blocks, total_scale)
else:
# Single-level scaling: use block scales directly
scales_f32 = block_amax / F4_E2M1_MAX
scales_f32 = tl.clamp(scales_f32, E4M3_EPS, F8E4M3_MAX)
scales = scales_f32.to(tl.float8e4nv)

# Apply block scale to data
total_scale = scales.to(tl.float32)[:, :, None]
x_blocks = tl.div_rn(x_blocks, total_scale)

# NVIDIA layout for scales
if MASK_SCALES:
# Create offsets for the scale dimensions (4 blocks per row)
scale_offs_n = pid_n * 4 + tl.arange(0, 4)[None, :]

# Mask out scales to 0 if we are not aligned to 128 x 64
scales = tl.where(
(offs_m < M) & (scale_offs_n < N // 16),
scales,
0.0,
)
packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
offs_m = tl.arange(0, 32)[:, None]
offs_n = tl.arange(0, 16)[None, :]
tl.store(
s_ptr
+ (pid_m * tl.num_programs(0) + pid_n) * (32 * 16)
+ offs_m * 16
+ offs_n,
packed_scales,
)

# Convert to FP4
x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
if MASK_SCALES:
mask = (offs_m < M) & (offs_n < N // 2)
else:
mask = None
tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)

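# Illustrative reference, not part of the original file: the two-level scaling
# performed by the kernel above, written out for one 16-element block in plain
# PyTorch. The helper name is an assumption for this sketch, not an API from
# this module.
def _nvfp4_block_scale_reference(
    x_block: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scale one [16]-element block the way quantize_nvfp4_triton_kernel does."""
    F4_E2M1_MAX, F8E4M3_MAX, E4M3_EPS = 6.0, 448.0, 1.5258789e-05
    block_scale = x_block.abs().amax().float() / F4_E2M1_MAX
    if tensor_scale is not None:
        # Two-level: quantize the block scale with the per-tensor scale first.
        scale_fp8 = (block_scale / tensor_scale).clamp(E4M3_EPS, F8E4M3_MAX).to(
            torch.float8_e4m3fn
        )
        total_scale = tensor_scale * scale_fp8.float()
    else:
        # Single-level: clamp and quantize the block scale directly.
        scale_fp8 = block_scale.clamp(E4M3_EPS, F8E4M3_MAX).to(torch.float8_e4m3fn)
        total_scale = scale_fp8.float()
    # The scaled values would then be converted to FP4 (e2m1) and packed.
    return scale_fp8, x_block.float() / total_scale
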
@torch.library.custom_op("ao::triton_quantize_nvfp4", mutates_args=())
def triton_quantize_nvfp4(
x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Quantize a tensor to NVFP4 format.

Args:
x (torch.Tensor): Input tensor to be quantized.
        per_tensor_scale (Optional[torch.Tensor]): Per-tensor scale for two-level quantization.
            If None, uses single-level block-wise quantization only.

Returns:
        Tuple[torch.Tensor, torch.Tensor]: Scales tensor in swizzled layout and the quantized tensor.

Note:
        Since vLLM does not use dynamo guards, we need to make this a custom op
        to avoid the Triton kernel being invoked with the wrong value of `MASK_SCALES`.
"""
# reshape to 2d
orig_leading_dims, _orig_M, orig_N = x.shape[:-2], x.shape[-2], x.shape[-1]
x = x.reshape(-1, orig_N)

M, N = x.shape
# assert M % 128 == 0 and N % 64 == 0
assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"

# Calculate blocks needed
num_scales = N // 16
n_row_blocks = triton.cdiv(M, 128)
n_col_blocks = triton.cdiv(num_scales, 4)
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4

# mask out scales to 0 if we are not aligned to 128 x 64
MASK_SCALES = M % 128 != 0 or N % 64 != 0

xq = x.new_empty(M, N // 2, dtype=torch.uint8)
scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)

grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))

if per_tensor_scale is None:
        # Don't allocate a new tensor; pass x as a dummy pointer since it is never read in the kernel
tensor_scale_ptr = x
use_tensor_scale = False
else:
tensor_scale_ptr = per_tensor_scale
use_tensor_scale = True

quantize_nvfp4_triton_kernel[grid](
x,
tensor_scale_ptr,
xq,
scales,
x.stride(0),
x.stride(1),
M,
N,
USE_TENSOR_SCALE=use_tensor_scale,
MASK_SCALES=MASK_SCALES,
)

# reshape back to original shape
scales = scales.view(*orig_leading_dims, -1, padded_cols)
xq = xq.view(*orig_leading_dims, -1, N // 2)

return scales, xq.view(torch.uint8)

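# Illustrative usage, not part of the original file. The shapes satisfy the
# N % 16 == 0 requirement above, and the per-tensor scale formula
# (amax / (448 * 6)) is the conventional NVFP4 choice, assumed here rather
# than taken from this PR.
def _example_triton_quantize_nvfp4() -> None:
    x = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
    per_tensor_scale = x.abs().amax().float() / (448.0 * 6.0)
    scales, xq = triton_quantize_nvfp4(x, per_tensor_scale)
    # scales: float8_e4m3fn in the NVIDIA swizzled layout; xq: uint8 holding
    # two FP4 values per byte, so its last dim is half of x's last dim.
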
@triton_quantize_nvfp4.register_fake
def _(x, per_tensor_scale=None):
    # Match the real op above, which flattens leading dims before quantizing.
    *leading_dims, _M, N = x.shape
    M = x.numel() // N
num_scales = N // 16
n_row_blocks = triton.cdiv(M, 128)
n_col_blocks = triton.cdiv(num_scales, 4)
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4

    scales = torch.empty(
        padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn
    ).view(*leading_dims, -1, padded_cols)
    xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8).view(
        *leading_dims, -1, N // 2
    )
return scales, xq

@triton_mx_block_rearrange.register_fake
def _(scale_tensor):
rows, cols = scale_tensor.shape
@@ -666,11 +446,6 @@ def triton_to_mxfp8_dim1_reference(
def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
raise AssertionError("needs torch version 2.8+ and triton")

def triton_quantize_nvfp4(
        x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
raise AssertionError("needs torch version 2.8+ and triton")

def triton_mxfp8_dequant_dim0(
e4m3_data: torch.Tensor,
e8m0_scales: torch.Tensor,