Skip to content

Commit d8792dd

Browse files
committed
Implements rounding mode for NVFP4 tensor
1 parent 257d18a commit d8792dd

File tree

3 files changed

+59
-12
lines changed

3 files changed

+59
-12
lines changed

torchao/prototype/custom_fp_utils.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,26 @@
88
# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
99
# 1. No encodings are reserved for special values (+/-inf, NaN).
1010
# 2. When downcasting from FP32 to Floatx,
11-
# - Rounding mode is round to nearest, ties to even.
11+
# - Rounding mode is round to nearest, ties to even (default).
1212
# - Values outside the representable range of Floatx after rounding are clamped to the maximum Floatx
1313
# magnitude (sign is preserved).
1414

15+
from enum import Enum
16+
1517
import torch
1618
from torch import Tensor
1719

1820

21+
class RoundingMode(Enum):
    """Rounding modes used when quantizing FP32 values to sub-byte floats.

    RN: Round to nearest, ties to even (the default).
    RS: Stochastic rounding — random bits are added to the discarded
        mantissa precision before truncation, so the rounding direction
        is probabilistic (non-deterministic across calls).
    """

    RN = "round_nearest"
    RS = "round_stochastic"
29+
30+
1931
def _n_ones(n: int) -> int:
2032
return (1 << n) - 1
2133

@@ -24,7 +36,9 @@ def _n_ones(n: int) -> int:
2436
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
2537

2638

27-
def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
39+
def _f32_to_floatx_unpacked(
40+
x: Tensor, ebits: int, mbits: int, rounding_mode: RoundingMode = RoundingMode.RN
41+
) -> Tensor:
2842
"""Convert FP32 numbers to sub-byte floating point numbers with the given
2943
number of exponent and mantissa bits.
3044
@@ -38,6 +52,12 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
3852
outside the representable range of Floatx after rounding are clamped to the
3953
maximum Floatx magnitude (sign is preserved).
4054
55+
Args:
56+
x: Input tensor of dtype torch.float
57+
ebits: Number of exponent bits
58+
mbits: Number of mantissa bits
59+
rounding_mode: Rounding mode to use (RN, RS)
60+
4161
Code below is an adaptation of https://fburl.com/code/ciwofcg4
4262
4363
Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501
@@ -111,13 +131,28 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
111131
# branch 3: stay in normal range, adjust the exponent and round
112132
#
113133
normal_x = x.view(torch.int32)
114-
# resulting mantissa is odd
115-
mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
116-
# update exponent, rounding bias part 1
117-
val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
118-
normal_x += val_to_add
119-
# rounding bias part 2
120-
normal_x += mant_odd
134+
val_to_add = (exp_bias - F32_EXP_BIAS) << MBITS_F32
135+
136+
if rounding_mode == RoundingMode.RN:
137+
# Round to nearest, ties to even
138+
# resulting mantissa is odd
139+
mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
140+
# update exponent, rounding bias part 1
141+
val_to_add += magic_adder
142+
normal_x += val_to_add
143+
# rounding bias part 2
144+
normal_x += mant_odd
145+
elif rounding_mode == RoundingMode.RS:
146+
# Stochastic rounding
147+
# Add random bits to the discarded precision
148+
rnd = torch.randint_like(normal_x, 0, 1 << (MBITS_F32 - mbits), dtype=torch.int32)
149+
# update exponent
150+
normal_x += val_to_add
151+
# add randomness
152+
normal_x += rnd
153+
else:
154+
raise ValueError(f"Unsupported rounding mode: {rounding_mode}")
155+
121156
# take the bits!
122157
normal_x = normal_x >> (MBITS_F32 - mbits)
123158
normal_x = normal_x.to(torch.uint8)

torchao/prototype/mx_formats/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
NVFP4MMConfig,
1212
)
1313

14+
from torchao.prototype.custom_fp_utils import RoundingMode
15+
1416
# import mx_linear here to register the quantize_ transform logic
1517
# ruff: noqa: I001
1618
import torchao.prototype.mx_formats.mx_linear # noqa: F401
@@ -22,4 +24,5 @@
2224
"MXFPInferenceConfig",
2325
"NVFP4InferenceConfig",
2426
"NVFP4MMConfig",
27+
"RoundingMode",
2528
]

torchao/prototype/mx_formats/nvfp4_tensor.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
triton_quantize_nvfp4,
2121
unpack_uint4,
2222
)
23+
from torchao.prototype.custom_fp_utils import RoundingMode
2324
from torchao.prototype.mx_formats.mx_tensor import (
2425
tensor_size_fp4x2_to_hp,
2526
tensor_size_hp_to_fp4x2,
@@ -158,6 +159,7 @@ def to_nvfp4(
158159
is_swizzled_scales: bool = False,
159160
use_triton_kernel: bool = False,
160161
act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
162+
rounding_mode: RoundingMode = RoundingMode.RN,
161163
):
162164
"""Convert high precision tensor to NVFP4 format.
163165
@@ -171,6 +173,7 @@ def to_nvfp4(
171173
is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
172174
use_triton_kernel: If True, use Triton kernel for quantization
173175
act_quant_kwargs: If specified, config for quantizing the activation
176+
rounding_mode: Rounding mode to use (RN for round-nearest, RS for stochastic)
174177
175178
Returns:
176179
NVFP4Tensor: Quantized tensor in NVFP4 format
@@ -183,10 +186,14 @@ def to_nvfp4(
183186
assert K % 16 == 0, (
184187
f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
185188
)
186-
blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
189+
# Convert RoundingMode enum to boolean for triton kernel
190+
use_stochastic_rounding = rounding_mode == RoundingMode.RS
191+
blockwise_scales, data_lp = triton_quantize_nvfp4(
192+
data_hp, per_tensor_scale, use_stochastic_rounding
193+
)
187194
else:
188195
blockwise_scales, data_lp = nvfp4_quantize(
189-
data_hp, block_size, per_tensor_scale
196+
data_hp, block_size, per_tensor_scale, rounding_mode
190197
)
191198
if is_swizzled_scales:
192199
scale_shape = (math.prod(leading_dims) * M, K // block_size)
@@ -677,6 +684,7 @@ def nvfp4_quantize(
677684
data_hp: torch.Tensor,
678685
block_size: int = 16,
679686
per_tensor_scale: Optional[torch.Tensor] = None,
687+
rounding_mode: RoundingMode = RoundingMode.RN,
680688
) -> tuple[torch.Tensor, torch.Tensor]:
681689
"""NVIDIA FP4 quantization with UE4M3 scales.
682690
@@ -688,6 +696,7 @@ def nvfp4_quantize(
688696
block_size: Block size for quantization (must be 16)
689697
per_tensor_scale: Optional pre-computed per-tensor scale for calibration.
690698
If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
699+
rounding_mode: Rounding mode to use (RN or RS)
691700
692701
Returns:
693702
tuple: A tuple containing:
@@ -742,7 +751,7 @@ def nvfp4_quantize(
742751

743752
data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
744753
data_scaled = data_scaled.view(orig_shape)
745-
data_lp = f32_to_f4_unpacked(data_scaled)
754+
data_lp = f32_to_f4_unpacked(data_scaled, rounding_mode)
746755
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
747756
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
748757
data_lp = pack_uint4(data_lp)

0 commit comments

Comments
 (0)