# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
# 1. No encodings are reserved for special values (+/-inf, NaN).
# 2. When downcasting from FP32 to Floatx,
#    - Rounding mode is round to nearest, ties to even (default).
#    - Values outside the representable range of Floatx after rounding are clamped to the maximum Floatx
#      magnitude (sign is preserved).

15+ from enum import Enum
16+
1517import torch
1618from torch import Tensor
1719
1820
class RoundingMode(Enum):
    """Rounding strategies for FP32 -> sub-byte floating-point conversion.

    Attributes:
        RN: Round to nearest, ties to even (the default behavior).
        RS: Stochastic rounding.
    """

    RN = "round_nearest"
    RS = "round_stochastic"
30+
1931def _n_ones (n : int ) -> int :
2032 return (1 << n ) - 1
2133
@@ -24,7 +36,9 @@ def _n_ones(n: int) -> int:
# Exponent bias of IEEE-754 float32: 2**(EBITS_F32 - 1) - 1
# (presumably 127, assuming EBITS_F32 == 8 — EBITS_F32 is defined in an
# elided portion of this file; verify there).
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
2537
2638
27- def _f32_to_floatx_unpacked (x : Tensor , ebits : int , mbits : int ) -> Tensor :
39+ def _f32_to_floatx_unpacked (
40+ x : Tensor , ebits : int , mbits : int , rounding_mode : RoundingMode = RoundingMode .RN
41+ ) -> Tensor :
2842 """Convert FP32 numbers to sub-byte floating point numbers with the given
2943 number of exponent and mantissa bits.
3044
@@ -38,6 +52,12 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
3852 outside the representable range of Floatx after rounding are clamped to the
3953 maximum Floatx magnitude (sign is preserved).
4054
55+ Args:
56+ x: Input tensor of dtype torch.float
57+ ebits: Number of exponent bits
58+ mbits: Number of mantissa bits
59+ rounding_mode: Rounding mode to use (RN, RS)
60+
4161 Code below is an adaptation of https://fburl.com/code/ciwofcg4
4262
4363 Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501
@@ -111,13 +131,28 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
111131 # branch 3: stay in normal range, adjust the exponent and round
112132 #
113133 normal_x = x .view (torch .int32 )
114- # resulting mantissa is odd
115- mant_odd = (normal_x >> (MBITS_F32 - mbits )) & 1
116- # update exponent, rounding bias part 1
117- val_to_add = ((exp_bias - F32_EXP_BIAS ) << MBITS_F32 ) + magic_adder
118- normal_x += val_to_add
119- # rounding bias part 2
120- normal_x += mant_odd
134+ val_to_add = (exp_bias - F32_EXP_BIAS ) << MBITS_F32
135+
136+ if rounding_mode == RoundingMode .RN :
137+ # Round to nearest, ties to even
138+ # resulting mantissa is odd
139+ mant_odd = (normal_x >> (MBITS_F32 - mbits )) & 1
140+ # update exponent, rounding bias part 1
141+ val_to_add += magic_adder
142+ normal_x += val_to_add
143+ # rounding bias part 2
144+ normal_x += mant_odd
145+ elif rounding_mode == RoundingMode .RS :
146+ # Stochastic rounding
147+ # Add random bits to the discarded precision
148+ rnd = torch .randint_like (normal_x , 0 , 1 << (MBITS_F32 - mbits ), dtype = torch .int32 )
149+ # update exponent
150+ normal_x += val_to_add
151+ # add randomness
152+ normal_x += rnd
153+ else :
154+ raise ValueError (f"Unsupported rounding mode: { rounding_mode } " )
155+
121156 # take the bits!
122157 normal_x = normal_x >> (MBITS_F32 - mbits )
123158 normal_x = normal_x .to (torch .uint8 )