
Commit ae6cdd1

claude authored and committed
[moe training] add benchmark for dual colwise FP8 scales kernel
Benchmarks triton_fp8_per_group_colwise_scales_dual vs two sequential triton_fp8_per_group_colwise_scales calls, across representative MoE backward shapes.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent a7529f5 commit ae6cdd1

1 file changed

Lines changed: 185 additions & 0 deletions
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
#
# Benchmarks triton_fp8_per_group_colwise_scales_dual vs two sequential
# triton_fp8_per_group_colwise_scales calls, mirroring the MoE backward pass
# where both padded_grad_output and padded_A are quantized colwise.

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
    triton_fp8_per_group_colwise_scales,
    triton_fp8_per_group_colwise_scales_dual,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

device = torch.device("cuda")


@dataclass(frozen=True)
class ExperimentConfig:
    high_precision_dtype: torch.dtype
    M: int  # total tokens (rows shared by both tensors)
    N1: int  # cols of tensor 1 (grad_output hidden dim)
    N2: int  # cols of tensor 2 (A hidden dim)
    n_groups: int


@dataclass(frozen=True)
class ExperimentResult:
    two_calls_time_us: float
    dual_time_us: float
    speedup: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    # Representative MoE backward shapes:
    #   M        = total padded tokens across all experts
    #   N1       = grad_output hidden dim (output of expert)
    #   N2       = A hidden dim (input to expert)
    #   n_groups = num experts
    shapes = [
        # (M, N1, N2, n_groups)
        (16640, 2048, 2048, 64),
        (16640, 5120, 2048, 64),
        (16640, 5120, 5120, 64),
        (16640, 8192, 2048, 64),
        (16640, 2048, 2048, 128),
        (16640, 5120, 2048, 128),
        (16640, 5120, 5120, 128),
        (32768, 5120, 5120, 128),
    ]
    configs = []
    for (M, N1, N2, n_groups), dtype in itertools.product(shapes, [torch.bfloat16]):
        configs.append(
            ExperimentConfig(
                high_precision_dtype=dtype,
                M=M,
                N1=N1,
                N2=N2,
                n_groups=n_groups,
            )
        )
    return configs
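

# Note on generate_jagged_offs (semantics assumed from its use in
# run_experiment below, not confirmed against the utils source): it is
# expected to return a 1-D tensor of cumulative group-end row offsets that
# sums to M, with each offset aligned to multiple_of=16, so expert group i
# covers rows offs[i-1]:offs[i] of the padded token dimension.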


def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
    # do_bench reports milliseconds; convert the median to microseconds.
    return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    dtype = config.high_precision_dtype
    offs = generate_jagged_offs(config.n_groups, config.M, multiple_of=16)

    # tensor 1: padded_grad_output shape (M, N1), row-major
    t1 = torch.randn(config.M, config.N1, dtype=dtype, device=device)
    # tensor 2: padded_A shape (M, N2), row-major
    t2 = torch.randn(config.M, config.N2, dtype=dtype, device=device)

    fp8_dtype = torch.float8_e4m3fn

    # --- Baseline: two sequential calls (before optimization) ---
    def run_two_calls():
        triton_fp8_per_group_colwise_scales(
            t1, offs, output_dtype=fp8_dtype, round_scales_to_power_of_2=True
        )
        triton_fp8_per_group_colwise_scales(
            t2, offs, output_dtype=fp8_dtype, round_scales_to_power_of_2=True
        )

    # --- Optimized: single dual call ---
    def run_dual():
        triton_fp8_per_group_colwise_scales_dual(
            t1, t2, offs, output_dtype=fp8_dtype, round_scales_to_power_of_2=True
        )

    # Warmup
    for _ in range(10):
        run_two_calls()
    for _ in range(10):
        run_dual()

    two_calls_time_us = benchmark_cuda_function_in_microseconds(run_two_calls)
    dual_time_us = benchmark_cuda_function_in_microseconds(run_dual)

    return ExperimentResult(
        two_calls_time_us=two_calls_time_us,
        dual_time_us=dual_time_us,
        speedup=two_calls_time_us / dual_time_us,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "M",
        "N1",
        "N2",
        "n_groups",
        "dtype",
        "two calls (us)",
        "dual (us)",
        "speedup",
    ]
    rows = []
    for e in experiments:
        c, r = e.config, e.result
        rows.append(
            [
                c.M,
                c.N1,
                c.N2,
                c.n_groups,
                str(c.high_precision_dtype).split(".")[-1],
                f"{r.two_calls_time_us:.1f}",
                f"{r.dual_time_us:.1f}",
                f"{r.speedup:.2f}x",
            ]
        )
    print(tabulate(rows, headers=headers))
    print()
    speedups = [e.result.speedup for e in experiments]
    print(
        f"dual vs two calls — avg: {sum(speedups) / len(speedups):.2f}x "
        f"min: {min(speedups):.2f}x max: {max(speedups):.2f}x"
    )


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    print()
    print("=" * 70)
    print("Dual Colwise FP8 Scales Kernel Benchmark")
    print()
    print("  two calls : triton_fp8_per_group_colwise_scales called twice")
    print("              (baseline — backward pass before optimization)")
    print("  dual      : triton_fp8_per_group_colwise_scales_dual — single launch")
    print("              merges row loops for both tensors")
    print("=" * 70)
    print()
    print_results(results)


if __name__ == "__main__":
    main()
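
For context on what the kernels under test compute, here is a minimal eager-PyTorch sketch of per-group colwise FP8 scaling, written from the comments and call signatures above. It is an assumed reference, not the Triton kernels' actual implementation: the function name, the amax-based scale formula, and the treatment of round_scales_to_power_of_2 are all illustrative.

    import torch

    def per_group_colwise_fp8_reference(x, offs, fp8_dtype=torch.float8_e4m3fn):
        # Illustrative reference (assumed semantics, not the Triton kernel's
        # code): for each expert group's row slice, derive one scale per column
        # from that slice's absolute max, then quantize the slice column-wise.
        fp8_max = torch.finfo(fp8_dtype).max  # 448.0 for e4m3fn
        out = torch.empty_like(x, dtype=fp8_dtype)
        group_scales = []
        start = 0
        for end in offs.tolist():
            group = x[start:end].float()
            amax = group.abs().amax(dim=0, keepdim=True).clamp(min=1e-12)
            scale = fp8_max / amax  # quantization multiplier, one per column
            # round_scales_to_power_of_2=True is assumed to floor each scale
            # to a power of two, so scaling only shifts the exponent and does
            # not perturb the mantissa:
            scale = torch.exp2(torch.floor(torch.log2(scale)))
            out[start:end] = (group * scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
            group_scales.append(scale)
            start = end
        return out, group_scales

    # Example: quantize a (16, 8) bf16 tensor with 4 groups of 4 rows each.
    x = torch.randn(16, 8, dtype=torch.bfloat16)
    offs = torch.tensor([4, 8, 12, 16])
    fp8, scales = per_group_colwise_fp8_reference(x, offs)

The dual variant then presumably produces these same outputs for both t1 and t2 from a single kernel launch, walking the shared per-group row ranges once instead of twice, which is the effect the speedup column measures.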
