# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from typing import Dict, Tuple

import torch

from torchao.prototype.moe_training.kernels.mxfp8.quant import mxfp8_quantize_cuda_3d


def _parse_backends(s: str) -> list[str]:
    out = [x.strip() for x in s.split(",") if x.strip()]
    for b in out:
        if b not in {"cuda", "cutedsl"}:
            raise ValueError(f"Unsupported backend={b!r}, expected cuda/cutedsl")
    return out


def _dtype_from_str(s: str) -> torch.dtype:
    if s == "bf16":
        return torch.bfloat16
    if s == "fp32":
        return torch.float32
    raise ValueError(f"Unsupported dtype={s}")


def _tbps(num_bytes: int, ms: float) -> float:
    # Convert bytes moved in `ms` milliseconds to achieved bandwidth in TB/s.
    return num_bytes / (ms / 1e3) / 1e12


def _benchmark(fn, warmup: int, iters: int) -> float:
    # Time `fn` with CUDA events and return the average latency per call in ms.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


def _max_abs_diff(a: torch.Tensor, b: torch.Tensor) -> float:
    return (a.float() - b.float()).abs().max().item()


def _run_3d(args) -> None:
    dtype = _dtype_from_str(args.dtype)
    backends = _parse_backends(args.backends)
    E, N, K = args.E, args.N, args.K

    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    cc = torch.cuda.get_device_capability()
    print(f"GPU: {props.name}")
    print(f"CC: {cc}")
    print(
        f"shape=(E,N,K)=({E},{N},{K}) dtype={dtype} scaling_mode={args.scaling_mode} "
        f"stage_count={args.stage_count}"
    )

    x = torch.randn((E, N, K), device="cuda", dtype=dtype) * 1000
    # Total memory traffic: read the high-precision input, write the fp8 output,
    # and write one e8m0 scale per 32-element block along dim 1.
    bytes_moved = (
        x.numel() * x.element_size()  # input
        + x.numel()
        * torch.tensor([], dtype=torch.float8_e4m3fn).element_size()  # output
        + (E * (N // 32) * K)
        * torch.tensor([], dtype=torch.float8_e8m0fnu).element_size()  # scale
    )

    outs: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    for b in backends:
        try:
            fn = lambda: mxfp8_quantize_cuda_3d(  # noqa: E731
                x,
                block_size=32,
                scaling_mode=args.scaling_mode,
                backend=b,
                stage_count=args.stage_count,
            )
            ms = _benchmark(fn, args.warmup, args.iters)
            y, s = fn()
            outs[b] = (y, s)
            print(
                f"[{b:<10}] {ms:.3f} ms {_tbps(bytes_moved, ms):.3f} TB/s "
                f"y_stride={tuple(y.stride())} s_shape={tuple(s.shape)}"
            )
        except Exception as e:
            print(f"[{b:<10}] FAILED: {type(e).__name__}: {e}")

    if args.check_results and "cuda" in outs:
        # Compare every other backend elementwise against the CUDA backend's output.
        y_ref, s_ref = outs["cuda"]
        for b in backends:
            if b == "cuda" or b not in outs:
                continue
            y, s = outs[b]
            dy = _max_abs_diff(y_ref, y)
            ds = _max_abs_diff(s_ref, s)
            print(f"diff(cuda vs {b}): y_max_abs={dy} s_max_abs={ds}")
            ok = dy <= args.atol and ds <= args.atol
            print(f"check(cuda vs {b}): {'PASS' if ok else 'FAIL'} (atol={args.atol})")
            if not ok:
                raise RuntimeError(
                    f"Result mismatch for backend={b}: y_diff={dy}, s_diff={ds}, atol={args.atol}"
                )


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", choices=("bf16", "fp32"), default="bf16")
    parser.add_argument("--scaling-mode", choices=("floor", "rceil"), default="floor")
    parser.add_argument("--backends", default="cuda,cutedsl")
    parser.add_argument("--warmup", type=int, default=20)
    parser.add_argument("--iters", type=int, default=100)
    parser.add_argument("--check-results", action="store_true")
    parser.add_argument("--atol", type=float, default=0.0)
    parser.add_argument("--stage-count", type=int, default=2)

    parser.add_argument("--E", type=int, default=8)
    parser.add_argument("--N", type=int, default=7168)
    parser.add_argument("--K", type=int, default=2048)

    args = parser.parse_args()
    _run_3d(args)


if __name__ == "__main__":
    main()
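
# Example invocation (a sketch; the script path/name below is illustrative, the flags
# are the ones defined in main() above):
#   python benchmark_mxfp8_quantize_3d.py --backends cuda,cutedsl --check-results --atol 0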