|
8 | 8 | import torch |
9 | 9 | from torch.utils._triton import has_triton |
10 | 10 |
|
| 11 | +from torchao.prototype.custom_fp_utils import RoundingMode |
11 | 12 | from torchao.prototype.mx_formats.constants import ( |
12 | 13 | DTYPE_FP6_E2M3, |
13 | 14 | DTYPE_FP6_E3M2, |
|
54 | 55 | if not torch_version_at_least("2.8.0"): |
55 | 56 | pytest.skip("Unsupported PyTorch version", allow_module_level=True) |
56 | 57 |
|
if has_triton() and torch.cuda.is_available() and is_sm_at_least_100():
    import triton
    import triton.language as tl
    from torchao.prototype.mx_formats.kernels import (
        convert_fp32_to_fp4_packed,
        convert_fp32_to_fp4_packed_rs,
    )

    @triton.jit
    def _triton_f4_pack_kernel(
        x_ptr, out_ptr, N, seed, ROUNDING_MODE: tl.constexpr,
    ):
        """Thin wrapper to test convert_fp32_to_fp4_packed{,_rs} in isolation.

        Each program instance reads 64 fp32 inputs and writes 32 packed
        fp4x2 bytes. ROUNDING_MODE == 0 selects the plain packing path;
        any other value selects the `_rs` path driven by per-output
        random bits derived from `seed`.
        """
        pid = tl.program_id(0)
        # 64 input elements per program -> 32 packed output bytes.
        offs = pid * 64 + tl.arange(0, 64)
        mask = offs < N
        x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
        # Reshape to 32 consecutive pairs and split into the two halves of
        # each pair, the input layout the packing helpers expect.
        x_pairs = x.reshape(32, 2).split()
        if ROUNDING_MODE == 0:
            x_fp4x2 = convert_fp32_to_fp4_packed(x_pairs)
        else:
            # One random word per packed output byte, keyed on the global
            # output offset so results are deterministic for a fixed seed.
            out_offs = pid * 32 + tl.arange(0, 32)
            rbits = tl.randint(seed, out_offs)
            x_fp4x2 = convert_fp32_to_fp4_packed_rs(x_pairs, rbits)
        out_offs = pid * 32 + tl.arange(0, 32)
        tl.store(out_ptr + out_offs, x_fp4x2, mask=out_offs < N // 2)

    def triton_f4_pack(x, rounding_mode=RoundingMode.RN, seed=0):
        """Pack FP32 values to FP4 using Triton convert_fp32_to_fp4_packed{,_rs}."""
        # NOTE(review): assumes x.numel() is even — the output holds exactly
        # N // 2 packed bytes, so an odd trailing element would be dropped.
        # Confirm callers always pass even-sized inputs.
        N = x.numel()
        out = torch.empty(N // 2, dtype=torch.uint8, device=x.device)
        grid = (triton.cdiv(N, 64),)
        _triton_f4_pack_kernel[grid](
            x, out, N, seed, ROUNDING_MODE=rounding_mode.value,
        )
        return out
| 94 | + |
# (input value, expected round-to-nearest result) pairs: a constant tensor
# filled with `value` should dequantize to `rn_expected` everywhere under RN,
# so the mean of the quantized tensor equals `rn_expected` exactly.
FP4_RN_EXPECTED = [(5.2, 6.0), (-5.2, -6.0)]

# use_triton parametrization: the PyTorch path (False) is gated only by the
# test's own CUDA skipif; the Triton path (True) additionally requires triton
# and an SM 10.0+ device.
_triton_kernel_params = [
    False,
    pytest.param(
        True,
        marks=pytest.mark.skipif(
            not (has_triton() and torch.cuda.is_available() and is_sm_at_least_100()),
            reason="Triton FP4 kernel requires CUDA capability 10.0 or greater",
        ),
    ),
]
| 107 | + |
| 108 | + |
def _f4_quantize(x, rounding_mode, use_triton, seed=None):
    """Quantize FP32 to FP4 and dequantize, using either PyTorch or Triton kernel.

    Args:
        x: fp32 input tensor.
        rounding_mode: ``RoundingMode.RN`` (round-to-nearest) or
            ``RoundingMode.RS`` (stochastic rounding).
        use_triton: if True, pack via the Triton kernel wrapper; otherwise use
            the pure-PyTorch reference path.
        seed: optional RNG seed so stochastic rounding is reproducible on
            both paths; a random seed is drawn for the Triton path if omitted.

    Returns:
        An fp32 tensor of dequantized FP4 values.

    Raises:
        ValueError: if ``rounding_mode`` is not a ``RoundingMode`` member.
    """
    # Use isinstance, not `in`: Enum membership tests with a non-member value
    # (e.g. the literal 99 used in the tests) raise TypeError on Python < 3.12
    # instead of returning False, which would mask the intended ValueError.
    if not isinstance(rounding_mode, RoundingMode):
        raise ValueError(
            f"Unknown rounding_mode: {rounding_mode}. "
            f"Expected RoundingMode.RN or RoundingMode.RS."
        )
    if seed is not None and not use_triton:
        # The PyTorch path draws its random bits from the global RNG.
        torch.manual_seed(seed)
    if use_triton:
        if seed is None:
            # The Triton kernel needs an explicit seed; draw one at random.
            seed = torch.randint(2**31, (1,)).item()
        xq = triton_f4_pack(x.flatten(), rounding_mode=rounding_mode, seed=seed)
        return f4_unpacked_to_f32(unpack_uint4(xq))
    else:
        return f4_unpacked_to_f32(f32_to_f4_unpacked(x, rounding_mode=rounding_mode))
57 | 126 |
|
58 | 127 | # TODO: shared utils file for benchmarking and testing |
| 128 | + |
| 129 | + |
59 | 130 | def to_mx_dim1_reference(x_hp, block_size, scaling_mode): |
60 | 131 | x_hp = x_hp.t().contiguous() |
61 | 132 | scale_d1, data_d1 = to_mx( |
@@ -624,3 +695,38 @@ def test_cuda_mx_dim0_not_supported(): |
624 | 695 | rowwise=True, |
625 | 696 | colwise=False, |
626 | 697 | ) |
| 698 | + |
| 699 | + |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("use_triton", _triton_kernel_params)
@pytest.mark.parametrize("rounding_mode", [RoundingMode.RN, RoundingMode.RS, 99])
@pytest.mark.parametrize("seed_a,seed_b", [(42, 42), (42, 123)])
@pytest.mark.parametrize("shape", [(1024, 128)])
@pytest.mark.parametrize("value,rn_expected", FP4_RN_EXPECTED)
def test_f4_rounding(value, rn_expected, shape, seed_a, seed_b, rounding_mode, use_triton):
    """Test FP4 rounding: RN is biased, RS is unbiased, RS respects seed, invalid raises."""
    x = torch.ones(*shape, device="cuda", dtype=torch.bfloat16) * value

    # Use isinstance, not `in`: Enum membership with the non-member 99 raises
    # TypeError on Python < 3.12, which would fail this parametrization here
    # instead of exercising the ValueError path under pytest.raises.
    if not isinstance(rounding_mode, RoundingMode):
        with pytest.raises(ValueError, match="Unknown rounding_mode"):
            _f4_quantize(x.float(), rounding_mode, use_triton, seed=seed_a)
        return

    rtol = 1e-2
    r1 = _f4_quantize(x.float(), rounding_mode, use_triton, seed=seed_a)
    r2 = _f4_quantize(x.float(), rounding_mode, use_triton, seed=seed_b)

    # RN maps the constant input to the same representable value everywhere,
    # so the mean equals rn_expected; RS should be unbiased, so the mean of
    # the quantized tensor approaches the input mean.
    r1_mean = torch.mean(r1)
    if rounding_mode == RoundingMode.RN:
        torch.testing.assert_close(r1_mean.item(), rn_expected, rtol=rtol, atol=rtol)
    else:
        input_mean = torch.mean(x.float())
        torch.testing.assert_close(r1_mean, input_mean, rtol=rtol, atol=rtol)

    # Same seed must reproduce bit-identical output; different seeds must
    # change RS output (RN is deterministic regardless of seed, so no
    # inequality check applies there).
    if seed_a == seed_b:
        torch.testing.assert_close(r1, r2, atol=0, rtol=0)
    elif rounding_mode == RoundingMode.RS:
        assert not torch.allclose(r1, r2)
| 732 | + |
0 commit comments