
Commit 0ff45d6

Add CuTeDSL kernel for 3D tensor quantization to MXFP8

1 parent 960f307 commit 0ff45d6

4 files changed
Lines changed: 1091 additions & 17 deletions
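Based on the call sites added in this diff, the new CuTeDSL path is selected via a `backend` argument on `mxfp8_quantize_cuda_3d`, alongside a `stage_count` pipelining knob. A minimal usage sketch (shape and defaults taken from the benchmark script below; not an authoritative API reference):

import torch
from torchao.prototype.moe_training.kernels.mxfp8.quant import mxfp8_quantize_cuda_3d

# (E, N, K) expert-grouped tensor; quantization runs along dim1 (N)
# in blocks of 32, matching the benchmark and test in this commit.
x = torch.randn(8, 7168, 2048, device="cuda", dtype=torch.bfloat16)
y, s = mxfp8_quantize_cuda_3d(
    x,
    block_size=32,
    scaling_mode="floor",  # or "rceil"
    backend="cutedsl",     # new in this commit; "cuda" selects the C++ kernel
    stage_count=2,         # CuTeDSL pipeline stages (benchmark default)
)
# y: float8_e4m3fn data; s: float8_e8m0fnu scales, one per 32 elements of N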

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Dict, Tuple

import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import mxfp8_quantize_cuda_3d


def _parse_backends(s: str) -> list[str]:
    out = [x.strip() for x in s.split(",") if x.strip()]
    for b in out:
        if b not in {"cuda", "cutedsl"}:
            raise ValueError(f"Unsupported backend={b!r}, expected cuda/cutedsl")
    return out


def _dtype_from_str(s: str) -> torch.dtype:
    if s == "bf16":
        return torch.bfloat16
    if s == "fp32":
        return torch.float32
    raise ValueError(f"Unsupported dtype={s}")


def _tbps(num_bytes: int, ms: float) -> float:
    # Effective memory bandwidth in TB/s for num_bytes moved in ms milliseconds.
    return num_bytes / (ms / 1e3) / 1e12


def _benchmark(fn, warmup: int, iters: int) -> float:
    # Average milliseconds per call, measured with CUDA events.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    return (a.float() - b.float()).abs().max().item()


def _run_3d(args) -> None:
    dtype = _dtype_from_str(args.dtype)
    backends = _parse_backends(args.backends)
    E, N, K = args.E, args.N, args.K

    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    cc = torch.cuda.get_device_capability()
    print(f"GPU: {props.name}")
    print(f"CC: {cc}")
    print(
        f"shape=(E,N,K)=({E},{N},{K}) dtype={dtype} scaling_mode={args.scaling_mode} "
        f"stage_count={args.stage_count}"
    )

    x = torch.randn((E, N, K), device="cuda", dtype=dtype) * 1000
    # Memory traffic model: high-precision input + fp8 output + e8m0 scales.
    bytes_moved = (
        x.numel() * x.element_size()  # input
        + x.numel()
        * torch.tensor([], dtype=torch.float8_e4m3fn).element_size()  # output
        + (E * (N // 32) * K)
        * torch.tensor([], dtype=torch.float8_e8m0fnu).element_size()  # scales
    )

    outs: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    for b in backends:
        try:
            fn = lambda: mxfp8_quantize_cuda_3d(  # noqa: E731
                x,
                block_size=32,
                scaling_mode=args.scaling_mode,
                backend=b,
                stage_count=args.stage_count,
            )
            ms = _benchmark(fn, args.warmup, args.iters)
            y, s = fn()
            outs[b] = (y, s)
            print(
                f"[{b:<10}] {ms:.3f} ms {_tbps(bytes_moved, ms):.3f} TB/s "
                f"y_stride={tuple(y.stride())} s_shape={tuple(s.shape)}"
            )
        except Exception as e:
            print(f"[{b:<10}] FAILED: {type(e).__name__}: {e}")

    if args.check_results and "cuda" in outs:
        y_ref, s_ref = outs["cuda"]
        for b in backends:
            if b == "cuda" or b not in outs:
                continue
            y, s = outs[b]
            dy = _max_abs_diff(y_ref, y)
            ds = _max_abs_diff(s_ref, s)
            print(f"diff(cuda vs {b}): y_max_abs={dy} s_max_abs={ds}")
            ok = dy <= args.atol and ds <= args.atol
            print(f"check(cuda vs {b}): {'PASS' if ok else 'FAIL'} (atol={args.atol})")
            if not ok:
                raise RuntimeError(
                    f"Result mismatch for backend={b}: y_diff={dy}, s_diff={ds}, atol={args.atol}"
                )


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", choices=("bf16", "fp32"), default="bf16")
    parser.add_argument("--scaling-mode", choices=("floor", "rceil"), default="floor")
    parser.add_argument("--backends", default="cuda,cutedsl")
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--check-results", action="store_true")
    parser.add_argument("--atol", type=float, default=0.0)
    parser.add_argument("--stage-count", type=int, default=2)

    parser.add_argument("--E", type=int, default=8)
    parser.add_argument("--N", type=int, default=7168)
    parser.add_argument("--K", type=int, default=2048)

    args = parser.parse_args()
    _run_3d(args)


if __name__ == "__main__":
    main()
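As a sanity check on the bandwidth model above: for the default shape (E=8, N=7168, K=2048) in bf16, the traffic counted by `bytes_moved` works out as follows (a worked example of the arithmetic, not output from the script):

# Worked example of the bytes_moved model for the default benchmark shape.
E, N, K = 8, 7168, 2048
input_bytes = E * N * K * 2          # bf16 input:      234,881,024 B
output_bytes = E * N * K * 1         # fp8 e4m3 output: 117,440,512 B
scale_bytes = E * (N // 32) * K * 1  # e8m0 scales:       3,670,016 B
total = input_bytes + output_bytes + scale_bytes  # 355,991,552 B, ~0.356 GB
# At, say, 0.5 ms per call, _tbps gives 355_991_552 / 0.5e-3 / 1e12 ~= 0.712 TB/s.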

test/prototype/moe_training/test_kernels.py

Lines changed: 15 additions & 3 deletions
@@ -28,7 +28,6 @@
     triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.kernels.mxfp8 import (
-    _mxfp8_cuda_kernels_available,
     fused_pad_token_groups_cuda,
     fused_unpad_token_groups_cuda,
     mx_block_rearrange_2d_M_groups_cuda,
@@ -42,6 +41,10 @@
     triton_mx_block_rearrange_2d_M_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
+from torchao.prototype.moe_training.kernels.mxfp8.quant import (
+    _mxfp8_cuda_kernels_available,
+    _mxfp8_cutedsl_kernels_available,
+)
 from torchao.prototype.moe_training.utils import (
     _is_column_major,
     generate_jagged_offs,
@@ -375,7 +378,13 @@ def test_triton_mx_block_rearrange_2d_K_groups(
 @pytest.mark.parametrize("K", (32, 1536, 5120, 7168, 8192))
 @pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
 @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,))
-def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
+@pytest.mark.parametrize("backend", ("cuda", "cutedsl"))
+def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode, backend):
+    if backend == "cuda" and not _mxfp8_cuda_kernels_available:
+        pytest.skip("CUDA C++ mxfp8_quantize_3d backend is unavailable")
+    if backend == "cutedsl" and not _mxfp8_cutedsl_kernels_available:
+        pytest.skip("CuTeDSL mxfp8_quantize_3d backend is unavailable")
+
     scaling_mode_str = (
         "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
     )
@@ -402,7 +411,10 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
     s_d1_ref = s_d1_ref.transpose(-2, -1)
 
     y_d1, s_d1 = mxfp8_quantize_cuda_3d(
-        x, block_size=block_size, scaling_mode=scaling_mode_str
+        x,
+        block_size=block_size,
+        scaling_mode=scaling_mode_str,
+        backend=backend,
     )
     # Check scales
     torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)