Skip to content

Commit 70485c5

Browse files
committed
Implements rounding mode
1 parent 7bb7f06 commit 70485c5

6 files changed

Lines changed: 392 additions & 65 deletions

File tree

test/prototype/mx_formats/test_kernels.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from torch.utils._triton import has_triton
1010

11+
from torchao.prototype.custom_fp_utils import RoundingMode
1112
from torchao.prototype.mx_formats.constants import (
1213
DTYPE_FP6_E2M3,
1314
DTYPE_FP6_E3M2,
@@ -54,8 +55,78 @@
5455
if not torch_version_at_least("2.8.0"):
5556
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
5657

58+
if has_triton() and torch.cuda.is_available() and is_sm_at_least_100():
59+
import triton
60+
import triton.language as tl
61+
from torchao.prototype.mx_formats.kernels import (
62+
convert_fp32_to_fp4_packed,
63+
convert_fp32_to_fp4_packed_rs,
64+
)
65+
66+
@triton.jit
def _triton_f4_pack_kernel(
    x_ptr, out_ptr, N, seed, ROUNDING_MODE: tl.constexpr,
):
    """Thin wrapper to test convert_fp32_to_fp4_packed{,_rs} in isolation.

    Each program consumes 64 fp32 inputs and emits 32 packed fp4x2 bytes.
    ROUNDING_MODE == 0 selects round-to-nearest; any other value selects
    stochastic rounding driven by per-output random bits.
    """
    pid = tl.program_id(0)
    offs = pid * 64 + tl.arange(0, 64)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    # Split the 64 lanes into 32 (even, odd) pairs for packing.
    x_pairs = x.reshape(32, 2).split()
    # One output byte per input pair. Hoisted before the branch: the original
    # computed this inside the `else` arm and then recomputed it again
    # unconditionally afterwards.
    out_offs = pid * 32 + tl.arange(0, 32)
    if ROUNDING_MODE == 0:
        x_fp4x2 = convert_fp32_to_fp4_packed(x_pairs)
    else:
        # Deterministic per-output random bits derived from (seed, offset),
        # so the same seed reproduces the same stochastic rounding decisions.
        rbits = tl.randint(seed, out_offs)
        x_fp4x2 = convert_fp32_to_fp4_packed_rs(x_pairs, rbits)
    tl.store(out_ptr + out_offs, x_fp4x2, mask=out_offs < N // 2)
84+
85+
def triton_f4_pack(x, rounding_mode=RoundingMode.RN, seed=0):
    """Pack FP32 values to FP4 using Triton convert_fp32_to_fp4_packed{,_rs}.

    Args:
        x: flat CUDA tensor; ``x.numel()`` must be even so every value has
            a packing partner (two FP4 values per output byte).
        rounding_mode: RoundingMode.RN (round-to-nearest, the default) or
            RoundingMode.RS (stochastic rounding).
        seed: RNG seed forwarded to the stochastic-rounding kernel.

    Returns:
        uint8 tensor of ``x.numel() // 2`` packed fp4x2 bytes on x's device.

    Raises:
        ValueError: if x has an odd number of elements (the original code
            silently dropped the trailing element via ``N // 2`` truncation).
    """
    N = x.numel()
    if N % 2 != 0:
        raise ValueError(f"triton_f4_pack requires an even element count, got {N}")
    out = torch.empty(N // 2, dtype=torch.uint8, device=x.device)
    # One program per 64 inputs — must match the kernel's fixed block size.
    grid = (triton.cdiv(N, 64),)
    _triton_f4_pack_kernel[grid](
        x, out, N, seed, ROUNDING_MODE=rounding_mode.value,
    )
    return out
94+
95+
# (input value, expected round-to-nearest FP4 result) pairs. 5.2 falls between
# two representable FP4 magnitudes and rounds to 6.0 under RN — presumably the
# nearest representable value; confirm against the FP4 E2M1 value table.
FP4_RN_EXPECTED = [(5.2, 6.0), (-5.2, -6.0)]

# Parametrization for use_triton: the Triton path is skipped unless Triton,
# CUDA, and an SM 10.0+ device are all available (same guard as the kernel
# import at the top of this file).
_triton_kernel_params = [
    False,
    pytest.param(
        True,
        marks=pytest.mark.skipif(
            not (has_triton() and torch.cuda.is_available() and is_sm_at_least_100()),
            reason="Triton FP4 kernel requires CUDA capability 10.0 or greater",
        ),
    ),
]
107+
108+
109+
def _f4_quantize(x, rounding_mode, use_triton, seed=None):
    """Quantize FP32 to FP4 and dequantize, using either PyTorch or Triton kernel.

    Args:
        x: fp32 tensor to round-trip through FP4.
        rounding_mode: a RoundingMode member (RN or RS).
        use_triton: route through the Triton packing kernel when True,
            otherwise through the pure-PyTorch reference path.
        seed: optional RNG seed making stochastic rounding reproducible on
            either path; when None on the Triton path a random seed is drawn.

    Returns:
        Dequantized fp32 tensor (flattened on the Triton path).

    Raises:
        ValueError: if rounding_mode is not a RoundingMode member.
    """
    # isinstance, not `rounding_mode not in RoundingMode`: enum value-membership
    # tests raise TypeError for non-enum operands on Python < 3.12, which would
    # mask the intended ValueError for inputs like 99.
    if not isinstance(rounding_mode, RoundingMode):
        raise ValueError(
            f"Unknown rounding_mode: {rounding_mode}. "
            f"Expected RoundingMode.RN or RoundingMode.RS."
        )
    if seed is not None and not use_triton:
        # The PyTorch path draws its randomness from torch's global RNG.
        torch.manual_seed(seed)
    if use_triton:
        if seed is None:
            # The Triton kernel needs an explicit seed; draw one at random.
            seed = torch.randint(2**31, (1,)).item()
        xq = triton_f4_pack(x.flatten(), rounding_mode=rounding_mode, seed=seed)
        return f4_unpacked_to_f32(unpack_uint4(xq))
    else:
        return f4_unpacked_to_f32(f32_to_f4_unpacked(x, rounding_mode=rounding_mode))
125+
57126

58127
# TODO: shared utils file for benchmarking and testing
128+
129+
59130
def to_mx_dim1_reference(x_hp, block_size, scaling_mode):
60131
x_hp = x_hp.t().contiguous()
61132
scale_d1, data_d1 = to_mx(
@@ -624,3 +695,38 @@ def test_cuda_mx_dim0_not_supported():
624695
rowwise=True,
625696
colwise=False,
626697
)
698+
699+
700+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("use_triton", _triton_kernel_params)
@pytest.mark.parametrize("rounding_mode", [RoundingMode.RN, RoundingMode.RS, 99])
@pytest.mark.parametrize("seed_a,seed_b", [(42, 42), (42, 123)])
@pytest.mark.parametrize("shape", [(1024, 128)])
@pytest.mark.parametrize("value,rn_expected", FP4_RN_EXPECTED)
def test_f4_rounding(value, rn_expected, shape, seed_a, seed_b, rounding_mode, use_triton):
    """Test FP4 rounding: RN is biased, RS is unbiased, RS respects seed, invalid raises."""
    x = torch.ones(*shape, device="cuda", dtype=torch.bfloat16) * value

    # isinstance, not `rounding_mode not in RoundingMode`: on Python < 3.12 an
    # enum value-membership test raises TypeError for non-enum operands, so the
    # rounding_mode=99 case would error here instead of exercising the
    # ValueError path below.
    if not isinstance(rounding_mode, RoundingMode):
        with pytest.raises(ValueError, match="Unknown rounding_mode"):
            _f4_quantize(x.float(), rounding_mode, use_triton, seed=seed_a)
        return

    rtol = 1e-2
    r1 = _f4_quantize(x.float(), rounding_mode, use_triton, seed=seed_a)
    r2 = _f4_quantize(x.float(), rounding_mode, use_triton, seed=seed_b)

    # Rounding behavior, checked via the mean: RN maps every element to the
    # same nearest FP4 value (biased), while RS should preserve the input
    # mean in expectation (unbiased).
    r1_mean = torch.mean(r1)
    if rounding_mode == RoundingMode.RN:
        torch.testing.assert_close(r1_mean.item(), rn_expected, rtol=rtol, atol=rtol)
    else:
        input_mean = torch.mean(x.float())
        torch.testing.assert_close(r1_mean, input_mean, rtol=rtol, atol=rtol)

    # Seed determinism: identical seeds must be bit-identical; distinct seeds
    # must produce different results under stochastic rounding.
    if seed_a == seed_b:
        torch.testing.assert_close(r1, r2, atol=0, rtol=0)
    elif rounding_mode == RoundingMode.RS:
        assert not torch.allclose(r1, r2)
732+

0 commit comments

Comments
 (0)