Skip to content

Commit 1ea5dfb

Browse files
committed
Add offset to RS quantization
1 parent c6b10ff commit 1ea5dfb

1 file changed

Lines changed: 15 additions & 4 deletions

File tree

torchao/prototype/mx_formats/kernels.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ def quantize_nvfp4_triton_kernel(
529529
M,
530530
N,
531531
seed_ptr,
532+
offset_ptr,
532533
USE_TENSOR_SCALE: tl.constexpr,
533534
MASK_SCALES: tl.constexpr,
534535
ROUNDING_MODE: tl.constexpr, # 0=RN, 1=RS
@@ -616,6 +617,9 @@ def quantize_nvfp4_triton_kernel(
616617
else:
617618
# Stochastic rounding (RS) via hardware cvt.rs.satfinite.e2m1x4.f32
618619
seed = tl.load(seed_ptr)
620+
offset_base = tl.load(offset_ptr)
621+
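# Place local index in upper 32 bits and offset_base in lower 32 bits. NOTE(review): in the lines shown here, `offset` is not passed to tl.randint below (it still receives out_offs) -- confirm it is consumed; this packing also presumes offset_base fits in 32 bits.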
622+
offset = (tl.cast(out_offs, tl.int64) << 32) | offset_base
619623
rbits = tl.randint(seed, out_offs)
620624
x_fp4x2 = convert_fp32_to_fp4_packed_rs(x_pairs, rbits)
621625
if MASK_SCALES:
@@ -629,7 +633,8 @@ def triton_quantize_nvfp4(
629633
x: torch.Tensor,
630634
per_tensor_scale: Optional[torch.Tensor] = None,
631635
rounding_mode: int = 0,
632-
seed: Optional[torch.Tensor] = None,
636+
seed: torch.Tensor | None = None,
637+
offset: torch.Tensor | None = None,
633638
) -> Tuple[torch.Tensor, torch.Tensor]:
634639
"""Quantize a tensor to NVFP4 format.
635640
@@ -638,8 +643,12 @@ def triton_quantize_nvfp4(
638643
per_tensor_scale (Optional[torch.Tensor]): Per-tensor scale for two-level quantization.
639644
If None, uses single-level block-wise quantization only.
640645
rounding_mode (int): 0 for round-to-nearest, 1 for stochastic rounding.
641-
seed (Optional[torch.Tensor]): Seed tensor for stochastic rounding RNG.
642-
Should be a single-element int32 tensor on the same device as x.
646+
seed (torch.Tensor | None): Seed tensor for stochastic rounding RNG.
647+
Should be a single-element int64 tensor on the same device as x.
648+
When None, stochastic rounding uses a dummy seed (caller should
649+
only pass None when rounding_mode=0).
650+
offset (torch.Tensor | None): Offset tensor for stochastic rounding RNG.
651+
Should be a single-element int64 tensor on the same device as x.
643652
When None, stochastic rounding uses a dummy offset (caller should
644653
only pass None when rounding_mode=0).
645654
@@ -689,6 +698,7 @@ def triton_quantize_nvfp4(
689698
# For seed_ptr: if seed is None (RN mode), reuse x as dummy pointer
690699
# (kernel won't read it when ROUNDING_MODE=0)
691700
seed_ptr = seed if seed is not None else x
701+
offset_ptr = offset if offset is not None else x
692702

693703
quantize_nvfp4_triton_kernel[grid](
694704
x,
@@ -700,6 +710,7 @@ def triton_quantize_nvfp4(
700710
M,
701711
N,
702712
seed_ptr,
713+
offset_ptr,
703714
USE_TENSOR_SCALE=use_tensor_scale,
704715
MASK_SCALES=MASK_SCALES,
705716
ROUNDING_MODE=rounding_mode,
@@ -712,7 +723,7 @@ def triton_quantize_nvfp4(
712723
return scales, xq.view(torch.uint8)
713724

714725
@triton_quantize_nvfp4.register_fake
715-
def _(x, per_tensor_scale=None, rounding_mode=RoundingMode.RN, seed=None):
726+
def _(x, per_tensor_scale=None, rounding_mode=RoundingMode.RN, seed=None, offset=None):
716727
M, N = x.shape
717728
num_scales = N // 16
718729
n_row_blocks = triton.cdiv(M, 128)

0 commit comments

Comments
 (0)