Commit fd2a42e
delete old torchao to_nvfp4 triton kernel
Summary: we moved to `mslk` in #4031; this PR deletes the old kernel. Separate PR to keep PRs small.

Test Plan:

```
pytest test/prototype/mx_formats -s
```

ghstack-source-id: c6ee2bb
ghstack-comment-id: 4057398637
Pull-Request: #4078

1 parent: 769416b

1 file changed: 1 addition and 226 deletions

torchao/prototype/mx_formats/kernels.py
```diff
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -6,7 +6,7 @@
 
 import importlib
 import logging
-from typing import Optional, Tuple
+from typing import Tuple
 
 import numpy as np
 import torch
@@ -425,226 +425,6 @@ def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
 
         return out
 
-    @triton.jit
-    def convert_fp32_to_fp4_packed(x_pairs):
-        """Convert FP32 pairs to packed FP4 format.
-
-        This function takes a tensor where consecutive values along the last dimension
-        are packed together into single bytes.
-
-        Args:
-            x_pairs: [Tensor, Tensor] both w/ shapes [..., 1] where the zipped last dimension contains
-                interleaved pairs of FP32 values to be packed together.
-
-        Returns:
-            Packed tensor with shape [...] (last dimension removed) where each
-            element is an int8 containing 2 FP4 values:
-            - First value of pair → low nibble (bits 0-3)
-            - Second value of pair → high nibble (bits 4-7)
-
-        Example:
-            Input: [128, 32, 2] containing FP32 pairs
-            Output: [128, 32] containing packed FP4 bytes
-
-        """
-
-        x_fp4x2 = tl.inline_asm_elementwise(
-            asm="""
-            {
-                .reg .b8 byte0, byte1, byte2, byte3;
-                cvt.rn.satfinite.e2m1x2.f32 byte0, $5, $1;
-                cvt.rn.satfinite.e2m1x2.f32 byte1, $6, $2;
-                cvt.rn.satfinite.e2m1x2.f32 byte2, $7, $3;
-                cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $4;
-                mov.b32 $0, {byte0, byte1, byte2, byte3};
-            }
-            """,
-            constraints=("=r,r,r,r,r,r,r,r,r"),
-            args=x_pairs,
-            dtype=tl.uint8,
-            is_pure=True,
-            pack=4,
-        )
-
-        return x_fp4x2
-
-    # Sauce: https://github.com/gau-nernst/quantized-training
-    @triton.jit
-    def quantize_nvfp4_triton_kernel(
-        x_ptr,
-        tensor_scale_ptr,
-        q_ptr,
-        s_ptr,
-        stride_xm,
-        stride_xn,
-        M,
-        N,
-        USE_TENSOR_SCALE: tl.constexpr,
-        MASK_SCALES: tl.constexpr,
-    ):
-        F4_E2M1_MAX = 6.0
-        F8E4M3_MAX = 448.0
-        E4M3_EPS = 1.5258789e-05
-
-        pid_m = tl.program_id(1)
-        pid_n = tl.program_id(0)
-
-        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-        offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
-        if MASK_SCALES:
-            mask = (offs_m < M) & (offs_n < N)
-            other = 0.0
-        else:
-            mask = None
-            other = None
-        x = tl.load(
-            x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
-        )  # [128, 64]
-        x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
-
-        # Compute block-wise scales
-        block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
-
-        if USE_TENSOR_SCALE:
-            # Two-level scaling: quantize block scales with per-tensor scale
-            tensor_scale = tl.load(tensor_scale_ptr)
-
-            # First compute block scales
-            block_scale_f32 = (block_amax / F4_E2M1_MAX).to(tl.float32)
-
-            # Quantize the block scales with per-tensor scale
-            scaled_block_scales = block_scale_f32 / tensor_scale
-            scaled_block_scales = tl.clamp(scaled_block_scales, E4M3_EPS, F8E4M3_MAX)
-            scales = scaled_block_scales.to(tl.float8e4nv)
-
-            # Apply combined scale to data: per_tensor_scale * quantized_block_scale
-            total_scale = tensor_scale * scales.to(tl.float32)[:, :, None]
-            x_blocks = tl.div_rn(x_blocks, total_scale)
-        else:
-            # Single-level scaling: use block scales directly
-            scales_f32 = block_amax / F4_E2M1_MAX
-            scales_f32 = tl.clamp(scales_f32, E4M3_EPS, F8E4M3_MAX)
-            scales = scales_f32.to(tl.float8e4nv)
-
-            # Apply block scale to data
-            total_scale = scales.to(tl.float32)[:, :, None]
-            x_blocks = tl.div_rn(x_blocks, total_scale)
-
-        # NVIDIA layout for scales
-        if MASK_SCALES:
-            # Create offsets for the scale dimensions (4 blocks per row)
-            scale_offs_n = pid_n * 4 + tl.arange(0, 4)[None, :]
-
-            # Mask out scales to 0 if we are not aligned to 128 x 64
-            scales = tl.where(
-                (offs_m < M) & (scale_offs_n < N // 16),
-                scales,
-                0.0,
-            )
-        packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
-        offs_m = tl.arange(0, 32)[:, None]
-        offs_n = tl.arange(0, 16)[None, :]
-        tl.store(
-            s_ptr
-            + (pid_m * tl.num_programs(0) + pid_n) * (32 * 16)
-            + offs_m * 16
-            + offs_n,
-            packed_scales,
-        )
-
-        # Convert to FP4
-        x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-        offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-        offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
-        if MASK_SCALES:
-            mask = (offs_m < M) & (offs_n < N // 2)
-        else:
-            mask = None
-        tl.store(q_ptr + offs_m * (N // 2) + offs_n, x_fp4x2, mask=mask)
-
-    @torch.library.custom_op("ao::triton_quantize_nvfp4", mutates_args=())
-    def triton_quantize_nvfp4(
-        x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Quantize a tensor to NVFP4 format.
-
-        Args:
-            x (torch.Tensor): Input tensor to be quantized.
-            per_tensor_scale (Optional[torch.Tensor]): Per-tensor scale for two-level quantization.
-                If None, uses single-level block-wise quantization only.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout.
-
-        Note:
-            Since vLLM does not use dynamo guards we need to make this a custom op
-            to avoid the triton kernel being invoked w/ the wrong value of `MASK_SCALES`
-        """
-        # reshape to 2d
-        orig_leading_dims, _orig_M, orig_N = x.shape[:-2], x.shape[-2], x.shape[-1]
-        x = x.reshape(-1, orig_N)
-
-        M, N = x.shape
-        # assert M % 128 == 0 and N % 64 == 0
-        assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"
-
-        # Calculate blocks needed
-        num_scales = N // 16
-        n_row_blocks = triton.cdiv(M, 128)
-        n_col_blocks = triton.cdiv(num_scales, 4)
-        padded_rows = n_row_blocks * 128
-        padded_cols = n_col_blocks * 4
-
-        # mask out scales to 0 if we are not aligned to 128 x 64
-        MASK_SCALES = M % 128 != 0 or N % 64 != 0
-
-        xq = x.new_empty(M, N // 2, dtype=torch.uint8)
-        scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
-
-        grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
-
-        if per_tensor_scale is None:
-            # Don't allocate a tensor; pass x as a placeholder since it won't be used in the kernel
-            tensor_scale_ptr = x
-            use_tensor_scale = False
-        else:
-            tensor_scale_ptr = per_tensor_scale
-            use_tensor_scale = True
-
-        quantize_nvfp4_triton_kernel[grid](
-            x,
-            tensor_scale_ptr,
-            xq,
-            scales,
-            x.stride(0),
-            x.stride(1),
-            M,
-            N,
-            USE_TENSOR_SCALE=use_tensor_scale,
-            MASK_SCALES=MASK_SCALES,
-        )
-
-        # reshape back to original shape
-        scales = scales.view(*orig_leading_dims, -1, padded_cols)
-        xq = xq.view(*orig_leading_dims, -1, N // 2)
-
-        return scales, xq.view(torch.uint8)
-
-    @triton_quantize_nvfp4.register_fake
-    def _(x, per_tensor_scale=None):
-        M, N = x.shape
-        num_scales = N // 16
-        n_row_blocks = triton.cdiv(M, 128)
-        n_col_blocks = triton.cdiv(num_scales, 4)
-        padded_rows = n_row_blocks * 128
-        padded_cols = n_col_blocks * 4
-
-        scales = torch.empty(
-            padded_rows, padded_cols, device=x.device, dtype=torch.float8_e4m3fn
-        )
-        xq = torch.empty(M, N // 2, device=x.device, dtype=torch.uint8)
-        return scales, xq
-
     @triton_mx_block_rearrange.register_fake
     def _(scale_tensor):
         rows, cols = scale_tensor.shape
@@ -666,11 +446,6 @@ def triton_to_mxfp8_dim1_reference(
     def triton_mx_block_rearrange(scale_tensor: torch.Tensor) -> torch.Tensor:
         raise AssertionError("needs torch version 2.8+ and triton")
 
-    def triton_quantize_nvfp4(
-        x: torch.Tensor, tensor_scale: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise AssertionError("needs torch version 2.8+ and triton")
-
     def triton_mxfp8_dequant_dim0(
         e4m3_data: torch.Tensor,
         e8m0_scales: torch.Tensor,
```
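For reference, below is a minimal pure-PyTorch sketch of the NVFP4 scheme the deleted kernel implemented: per-16-element block scales stored as `float8_e4m3fn`, an optional second-level per-tensor scale, and two E2M1 values packed per `uint8` with the first value in the low nibble. The constants (`F4_E2M1_MAX = 6.0`, `F8E4M3_MAX = 448.0`, `E4M3_EPS`) come from the deleted kernel; the function names (`nvfp4_quantize_reference`, `nvfp4_dequantize_reference`, `_to_e2m1_codes`) are hypothetical helpers for illustration, not torchao or `mslk` APIs. The swizzled 128x4 scale layout and 128x64 tiling are omitted, and the nearest-value rounding here may break ties differently than the PTX `cvt.rn.satfinite.e2m1x2.f32` instruction, so treat this as a sketch rather than a bit-exact reimplementation.

```python
from typing import Optional, Tuple

import torch

# The eight magnitudes representable by FP4 E2M1; the sign bit covers negatives.
_E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def _to_e2m1_codes(x: torch.Tensor) -> torch.Tensor:
    # Round to the nearest E2M1 value and return 4-bit codes (bit 3 = sign).
    # Note: argmin breaks ties toward the smaller magnitude, which may differ
    # from the round-to-nearest-even behavior of PTX cvt.rn.satfinite.e2m1x2.f32.
    sign = (x < 0).to(torch.uint8)
    mag = x.abs().clamp(max=6.0)  # "satfinite": saturate at the E2M1 max
    idx = (mag.unsqueeze(-1) - _E2M1_VALUES.to(x.device)).abs().argmin(dim=-1)
    return (sign << 3) | idx.to(torch.uint8)


def nvfp4_quantize_reference(
    x: torch.Tensor, per_tensor_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Returns (scales, packed): E4M3 block scales of shape [M, N // 16] in plain
    # row-major order (no swizzle), and uint8 data of shape [M, N // 2] with the
    # first value of each pair in the low nibble, as in the deleted kernel.
    F4_E2M1_MAX, F8E4M3_MAX, E4M3_EPS = 6.0, 448.0, 1.5258789e-05
    M, N = x.shape
    assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"

    blocks = x.float().reshape(M, N // 16, 16)
    block_scale = blocks.abs().amax(dim=-1) / F4_E2M1_MAX  # [M, N // 16]
    if per_tensor_scale is not None:
        # Two-level scaling: quantize the block scales by the per-tensor scale.
        block_scale = block_scale / per_tensor_scale
    scales = block_scale.clamp(E4M3_EPS, F8E4M3_MAX).to(torch.float8_e4m3fn)

    total_scale = scales.float().unsqueeze(-1)  # [M, N // 16, 1]
    if per_tensor_scale is not None:
        total_scale = total_scale * per_tensor_scale
    codes = _to_e2m1_codes(blocks / total_scale).reshape(M, N // 2, 2)
    packed = codes[..., 0] | (codes[..., 1] << 4)  # low nibble = first value
    return scales, packed


def nvfp4_dequantize_reference(
    scales: torch.Tensor,
    packed: torch.Tensor,
    per_tensor_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Inverse of the sketch above: unpack the nibbles and reapply the scales.
    M = packed.shape[0]
    codes = torch.stack([packed & 0xF, packed >> 4], dim=-1)  # [M, N // 2, 2]
    mags = _E2M1_VALUES.to(packed.device)[(codes & 0x7).long()]
    vals = torch.where((codes & 0x8) != 0, -mags, mags).reshape(M, -1, 16)
    total_scale = scales.float().unsqueeze(-1)
    if per_tensor_scale is not None:
        total_scale = total_scale * per_tensor_scale
    return (vals * total_scale).reshape(M, -1)
```

A quick round-trip check such as `nvfp4_dequantize_reference(*nvfp4_quantize_reference(torch.randn(128, 64)))` should land within one quantization step of the input; bit-exact agreement with the deleted Triton kernel (or with the `mslk` replacement from #4031) is not guaranteed because of the tie-breaking caveat noted above.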
