Skip to content

Commit b669575

Browse files
committed
[mxfp8 moe training] add fused dim0+dim1 mxfp8 quantize kernel for backward pass
Adds ``triton_mxfp8_quantize_dim0_dim1`` (new file ``kernels/mxfp8/triton_grad_quantize.py``): a single-pass Triton kernel that reads the bf16 ``grad_out`` tile once and emits four outputs in one launch: * ``qdata_dim0`` - ``(M, N)`` e4m3 row-major, rowwise scales along N * ``qdata_dim1_t`` - ``(N, M)`` e4m3 row-major (logical transpose of the colwise-quantized tile), colwise scales along M * both e8m0 scale tensors already in the tcgen05 blocked layout (``triton_mx_block_rearrange`` output, no separate rearrange pass needed) Replaces the current backward-pass sequence of ``triton_to_mxfp8_dim0 + triton_to_mxfp8_dim1 + 2x triton_mx_block_rearrange_2d_M_groups`` (four Triton launches, bf16 read twice, scale tensor written twice). Key implementation points: * Single bf16 tile load; both rowwise and colwise scales are computed via pure 3D reshapes of the SAME tile (no bf16 transpose). * No ``tl.trans`` on fp8 either: the dim1 output is stored with ``offset = n * M + m`` directly on the (BLOCK_M, BLOCK_N) tile, matching how the existing ``to_mxfp8_dim1_kernel`` handles its col-major write. * Scales are emitted directly in tcgen05 blocked layout by computing the ``(r % 32) * 16 + (r // 32) * 4 + c`` super-tile offset for every scale element and issuing one ``tl.store`` per scale tensor. * Autotuned over BLOCK_N in {128, 256, 512} x num_warps in {4, 8, 16} x num_stages in {2, 3, 4} with an ``early_config_prune`` that drops configs where BLOCK_N > N. Numerics: new ``test_triton_mxfp8_quantize_dim0_dim1_numerics`` parametrized over 7 shapes x {rceil, floor} asserts bit-exact parity of both fp8 data tensors and both blocked scales against the 4-kernel reference pipeline. 14/14 pass on B200 (SM 10.0). 
Benchmark (``bench_triton_grad_quantize.py``; do_bench median, B200, rceil): M N 4k_us fused_us 4k_GB/s fused_GB/s speedup 16384 2048 340.0 120.9 401 1128 2.81x 4096 2048 271.8 39.0 125 873 6.96x 8192 2048 291.4 65.6 234 1038 4.44x 32768 2048 451.5 233.5 604 1167 1.93x 16384 5120 459.6 292.8 742 1164 1.57x 16384 7168 547.8 405.5 871 1177 1.35x 8192 5120 358.4 151.6 475 1124 2.36x 8192 7168 394.2 209.9 605 1136 1.88x 32768 5120 761.7 574.6 895 1186 1.33x v1 lands ~1.1-1.2 TB/s (~20-32% of B200 bf16 memcpy, which measures ~5 TB/s on this rig). That is a consistent 1.33-6.96x over the 4-kernel baseline but still short of the "90% memcpy BW" bar -- the kernel is correct and drop-in-compatible, and the remaining headroom is in TMA / warp-specialized stores which are worth their own follow-up diff. Made-with: Cursor
1 parent 77da665 commit b669575

4 files changed

Lines changed: 671 additions & 0 deletions

File tree

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Benchmark: fused ``triton_mxfp8_quantize_dim0_dim1`` vs. today's 4-kernel
8+
backward-pass sequence (dim0 quantize + dim0 rearrange + dim1 quantize + dim1
9+
rearrange).
10+
11+
Target (review feedback on ao#4293): close the gap to the B200 bf16->fp8
12+
memcpy ceiling at the DeepSeek-V3-like backward-pass shape
13+
``(num_groups=4, M_per_group=4096, N=2048)`` -> ``(total_M=16384, N=2048)``
14+
and the adjacent sweep. Landing bar is >= 5 TB/s on the realistic shapes.
15+
Measured B200 bf16 memcpy is ~5 TB/s on this rig (the ``_memcpy_bf16_bw_gbps``
16+
helper below measures it live), so ``% memcpy BW`` is the relevant metric.
17+
"""
18+
19+
import argparse
20+
from dataclasses import dataclass
21+
from typing import List
22+
23+
import torch
24+
from tabulate import tabulate
25+
from tqdm import tqdm
26+
27+
from benchmarks.utils import benchmark_cuda_function_in_microseconds
28+
from torchao.prototype.moe_training.kernels.mxfp8 import (
29+
triton_mx_block_rearrange_2d_M_groups,
30+
triton_mxfp8_quantize_dim0_dim1,
31+
)
32+
from torchao.prototype.mx_formats.kernels import (
33+
triton_to_mxfp8_dim0,
34+
triton_to_mxfp8_dim1,
35+
)
36+
from torchao.utils import ceil_div
37+
38+
# All benchmarks in this file run on the current CUDA device.
device = torch.device("cuda")

# NOTE(review): nothing in this file calls torch.compile, so the raised
# dynamo cache limit looks like parity with sibling benchmark scripts —
# confirm it is still needed here.
torch._dynamo.config.cache_size_limit = 1000
41+
42+
43+
@dataclass(frozen=True)
class ExperimentConfig:
    """Problem shape for one benchmark case: the tensor being quantized is
    (M, N) bf16, where M is the flattened total_M across expert groups."""

    # Number of rows (flattened total_M across groups).
    M: int
    # Number of columns (hidden dim).
    N: int
47+
48+
49+
@dataclass(frozen=True)
class ExperimentResult:
    """Timings and derived bandwidth metrics for one (M, N) case."""

    # Latency of the 4-kernel reference pipeline, microseconds.
    four_kernel_us: float
    # Latency of the fused dim0+dim1 kernel, microseconds.
    fused_us: float
    # Effective bandwidth of the reference pipeline, GB/s.
    four_kernel_gbps: float
    # Effective bandwidth of the fused kernel, GB/s.
    fused_gbps: float
    # four_kernel_us / fused_us.
    speedup: float
    # Fused bandwidth as a percentage of the live-measured bf16 memcpy BW.
    memcpy_bw_pct: float
57+
58+
59+
@dataclass(frozen=True)
class Experiment:
    """Pairs one benchmark configuration with its measured result."""

    # The (M, N) shape that was benchmarked.
    config: ExperimentConfig
    # Timings and bandwidths measured for that shape.
    result: ExperimentResult
63+
64+
65+
def get_configs() -> List[ExperimentConfig]:
    """Return the benchmark shape sweep.

    The primary landing target (total_M=16384, N=2048) comes first, then a
    sweep of realistic backward-pass sizes: (num_groups, M_per_group, N)
    combinations such as (4, 4096, 2048) and (8, 4096, 2048), flattened to
    (total_M, N).
    """
    shape_sweep = (
        (16384, 2048),  # primary landing target
        # DeepSeek-V3-like sweeps
        (4096, 2048),
        (8192, 2048),
        (32768, 2048),
        (16384, 5120),
        (16384, 7168),
        (8192, 5120),
        (8192, 7168),
        (32768, 5120),
    )
    return [ExperimentConfig(M=total_m, N=n) for total_m, n in shape_sweep]
83+
84+
85+
def _four_kernel_reference(x: torch.Tensor):
    """Reference backward-pass path: dim0 quantize, dim1 quantize, then one
    blocked-scale rearrange per scale tensor — the same four Triton kernels
    ``_MXFP8GroupedMM.backward`` launches today (minus cross-group
    bookkeeping, which is identical constant overhead on both sides)."""
    num_rows, num_cols = x.shape
    data_dim0, rowwise_scales = triton_to_mxfp8_dim0(
        x, inner_block_size=32, scaling_mode="rceil"
    )
    data_dim1_t, colwise_scales = triton_to_mxfp8_dim1(
        x, inner_block_size=32, scaling_mode="rceil"
    )
    # Single-group offsets keep the rearrange kernel on its standard launch
    # path, matching the production pad-and-blocked-scale flow.
    offsets_m = torch.tensor([num_rows], dtype=torch.int32, device=x.device)
    offsets_n = torch.tensor([num_cols], dtype=torch.int32, device=x.device)
    blocked_scales_dim0 = triton_mx_block_rearrange_2d_M_groups(
        rowwise_scales, offsets_m
    )
    blocked_scales_dim1 = triton_mx_block_rearrange_2d_M_groups(
        colwise_scales, offsets_n
    )
    return data_dim0, data_dim1_t, blocked_scales_dim0, blocked_scales_dim1
104+
105+
106+
def _bytes_touched(M: int, N: int) -> int:
107+
"""Bytes of HBM that MUST flow for this problem: one bf16 read of the
108+
input, two fp8 writes (row-major + transposed), and two e8m0 scale
109+
writes in blocked layout. This is the memcpy lower bound."""
110+
scale_cols_n = ceil_div(N // 32, 4) * 4
111+
scale_cols_m = ceil_div(M // 32, 4) * 4
112+
return (
113+
M * N * 2 # bf16 read
114+
+ M * N * 1 # fp8 dim0 write
115+
+ N * M * 1 # fp8 dim1_t write
116+
+ M * scale_cols_n * 1 # e8m0 dim0 scales (blocked)
117+
+ N * scale_cols_m * 1 # e8m0 dim1 scales (blocked)
118+
)
119+
120+
121+
def _memcpy_bf16_bw_gbps(M: int, N: int) -> float:
    """Measure sustained bf16 memcpy bandwidth (GB/s) for an (M, N) tensor:
    one full read plus one full write via ``Tensor.clone``. Measured
    in-process so the ceiling reflects the current machine rather than a
    datasheet constant."""
    src = torch.randn(M, N, dtype=torch.bfloat16, device=device)

    def do_copy():
        return src.clone()

    # Warm up before timing.
    for _ in range(5):
        do_copy()

    elapsed_us = benchmark_cuda_function_in_microseconds(do_copy)
    # clone reads and writes the full tensor -> 2 * (M * N * 2) bytes.
    total_bytes = 2 * (M * N * 2)
    return (total_bytes / 1e9) / (elapsed_us / 1e6)
136+
137+
138+
def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    """Time the 4-kernel reference and the fused kernel on one (M, N) shape
    and derive bandwidth / speedup metrics.

    ``args`` is accepted for interface symmetry with the CLI entry point
    and is currently unused in the body.
    """
    M, N = config.M, config.N

    torch.manual_seed(42)
    grad = torch.randn(M, N, dtype=torch.bfloat16, device=device)

    def reference_path():
        return _four_kernel_reference(grad)

    def fused_path():
        return triton_mxfp8_quantize_dim0_dim1(grad, scaling_mode="rceil")

    # Warm up both paths (autotuning, allocator) before timing.
    for _ in range(5):
        reference_path()
        fused_path()

    reference_us = benchmark_cuda_function_in_microseconds(reference_path)
    fused_time_us = benchmark_cuda_function_in_microseconds(fused_path)

    problem_bytes = _bytes_touched(M, N)
    reference_gbps = (problem_bytes / 1e9) / (reference_us / 1e6)
    fused_bw_gbps = (problem_bytes / 1e9) / (fused_time_us / 1e6)

    # Express fused BW as a % of the bf16 memcpy ceiling at the same input
    # size — the "90%+ memcpy BW" landing metric.
    memcpy_gbps = _memcpy_bf16_bw_gbps(M, N)
    pct_of_memcpy = (
        100.0 * fused_bw_gbps / memcpy_gbps if memcpy_gbps > 0 else 0.0
    )

    return ExperimentResult(
        four_kernel_us=reference_us,
        fused_us=fused_time_us,
        four_kernel_gbps=reference_gbps,
        fused_gbps=fused_bw_gbps,
        speedup=reference_us / fused_time_us,
        memcpy_bw_pct=pct_of_memcpy,
    )
176+
177+
178+
def print_results(experiments: List[Experiment]):
    """Render the benchmark sweep as a table on stdout."""
    headers = [
        "M",
        "N",
        "4k_us",
        "fused_us",
        "4k_GB/s",
        "fused_GB/s",
        "speedup",
        "% memcpy BW",
    ]
    # One formatted row per experiment, in sweep order.
    rows = [
        [
            exp.config.M,
            exp.config.N,
            f"{exp.result.four_kernel_us:.1f}",
            f"{exp.result.fused_us:.1f}",
            f"{exp.result.four_kernel_gbps:.0f}",
            f"{exp.result.fused_gbps:.0f}",
            f"{exp.result.speedup:.2f}x",
            f"{exp.result.memcpy_bw_pct:.1f}%",
        ]
        for exp in experiments
    ]
    print(tabulate(rows, headers=headers))
204+
205+
206+
def main(args: argparse.Namespace):
    """Run every configured shape and print the results table."""
    torch.random.manual_seed(123)
    experiments = []
    for cfg in tqdm(get_configs()):
        measured = run_experiment(cfg, args)
        experiments.append(Experiment(config=cfg, result=measured))
    print_results(experiments)
214+
215+
216+
if __name__ == "__main__":
217+
parser = argparse.ArgumentParser()
218+
parser.add_argument(
219+
"--profile", action="store_true", help="Enable profiling with PyTorch profiler"
220+
)
221+
args = parser.parse_args()
222+
main(args)

test/prototype/moe_training/test_kernels.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def _is_sm_10x() -> bool:
5151
triton_mx_block_rearrange_per_group_3d,
5252
triton_mxfp8_dispatch_and_quantize,
5353
triton_mxfp8_pad_and_quantize,
54+
triton_mxfp8_quantize_dim0_dim1,
5455
)
5556
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
5657
_mxfp8_cuda_kernels_available,
@@ -885,6 +886,95 @@ def test_triton_mxfp8_dispatch_and_quantize_numerics(
885886
)
886887

887888

889+
@pytest.mark.skipif(
    not _is_sm_10x(),
    reason="requires CUDA SM 10.x (blocked scale GEMM hw)",
)
@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize(
    "M,N",
    [
        (128, 128),
        (256, 512),
        (1024, 2048),
        (4096, 2048),
        (8192, 2048),
        (16384, 2048),
        (2048, 5120),
    ],
)
@pytest.mark.parametrize("scaling_mode_str", ["rceil", "floor"])
def test_triton_mxfp8_quantize_dim0_dim1_numerics(
    M: int, N: int, scaling_mode_str: str
):
    """Fused dim0+dim1 MXFP8 quantization with blocked scales should be
    bit-exactly equivalent to the decoupled 4-kernel reference pipeline:

        qdata0, scales0_rm = triton_to_mxfp8_dim0(x, 32, mode)
        qdata1_t, scales1_rm = triton_to_mxfp8_dim1(x, 32, mode)
        scales0_blocked = triton_mx_block_rearrange_2d_M_groups(scales0_rm, [M])
        scales1_blocked = triton_mx_block_rearrange_2d_M_groups(scales1_rm, [N])
    """
    # Imported locally so collection does not require the mx_formats kernels.
    from torchao.prototype.mx_formats.kernels import (
        triton_to_mxfp8_dim0,
        triton_to_mxfp8_dim1,
    )

    device = "cuda"
    # Fixed seed so both the fused path and the reference see identical data.
    torch.manual_seed(2024)
    x = torch.randn(M, N, dtype=torch.bfloat16, device=device)

    # Fused kernel under test.
    qdata0_fused, qdata1_t_fused, scales0_fused, scales1_fused = (
        triton_mxfp8_quantize_dim0_dim1(x, scaling_mode=scaling_mode_str)
    )

    # Reference dim0 pipeline: quantize along N, then blocked-rearrange.
    qdata0_ref, scales0_ref_rm = triton_to_mxfp8_dim0(
        x, inner_block_size=32, scaling_mode=scaling_mode_str
    )
    one_group_offsets_m = torch.tensor([M], dtype=torch.int32, device=device)
    scales0_blocked_ref = triton_mx_block_rearrange_2d_M_groups(
        scales0_ref_rm, one_group_offsets_m
    )

    # Reference dim1 pipeline: quantize along M (returns transposed data),
    # then blocked-rearrange.
    qdata1_t_ref, scales1_ref_rm = triton_to_mxfp8_dim1(
        x, inner_block_size=32, scaling_mode=scaling_mode_str
    )
    one_group_offsets_n = torch.tensor([N], dtype=torch.int32, device=device)
    scales1_blocked_ref = triton_mx_block_rearrange_2d_M_groups(
        scales1_ref_rm, one_group_offsets_n
    )

    # qdata_dim0: (M, N) row-major e4m3 - bit-exact parity.
    # (uint8 view compares raw bits; fp8 == would treat NaN encodings as unequal)
    assert qdata0_fused.shape == (M, N)
    assert qdata0_fused.dtype == torch.float8_e4m3fn
    assert torch.equal(
        qdata0_fused.view(torch.uint8), qdata0_ref.view(torch.uint8)
    ), "fused dim0 fp8 data differs from triton_to_mxfp8_dim0 reference"

    # qdata_dim1_t: (N, M) row-major e4m3 - bit-exact parity vs.
    # transposed dim1 reference. ``triton_to_mxfp8_dim1`` returns its data
    # as a ``.t()`` view of an (N, M) row-major column-major-of-x tensor,
    # so calling ``.t()`` on the reference peels the view back to the raw
    # (N, M) row-major storage we produce.
    assert qdata1_t_fused.shape == (N, M)
    assert qdata1_t_fused.dtype == torch.float8_e4m3fn
    qdata1_t_ref_rowmajor = qdata1_t_ref.t().contiguous()
    assert torch.equal(
        qdata1_t_fused.view(torch.uint8),
        qdata1_t_ref_rowmajor.view(torch.uint8),
    ), "fused dim1-transpose fp8 data differs from triton_to_mxfp8_dim1 reference"

    # Blocked scales for dim0: compare via from_blocked canonical view
    # (128x4 blocks may have gap cells that are legitimately uninitialized).
    _assert_blocked_scales_equal(scales0_fused, scales0_blocked_ref, M, N // 32)
    # Blocked scales for dim1 live in an (N, M/32) logical tensor.
    _assert_blocked_scales_equal(scales1_fused, scales1_blocked_ref, N, M // 32)
976+
977+
888978
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
889979
@pytest.mark.parametrize(
890980
"m,k",

torchao/prototype/moe_training/kernels/mxfp8/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@
1919
triton_mxfp8_dispatch_and_quantize, # noqa: F401
2020
triton_mxfp8_pad_and_quantize, # noqa: F401
2121
)
22+
from torchao.prototype.moe_training.kernels.mxfp8.triton_grad_quantize import (
23+
triton_mxfp8_quantize_dim0_dim1, # noqa: F401
24+
)

0 commit comments

Comments
 (0)