Skip to content

Commit 707bee8

Browse files
[mxfp8 training] add cutedsl kernel for 32x1 mxfp8 quantization on 2d tensors (#4239)
* cutedsl mxfp8 dim1 kernel * work on blocked layout * integrate new kernel, row major scales * temporarily support both cuda and cutedsl * blocked layout working * blocked layout bench script * linear tests passing * add bench script comparing to cuda * cleanup * make naming consistent * add test skips based on sm version and kernel availability
1 parent e94f638 commit 707bee8

File tree

11 files changed

+1597
-31
lines changed

11 files changed

+1597
-31
lines changed

benchmarks/mx_formats/cast_bench.py

Lines changed: 84 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ def run(
128128
"dim1_mxfp8_cuda_rceil",
129129
"dim0_mxfp8_cutedsl_2d_floor",
130130
"dim0_mxfp8_cutedsl_2d_rceil",
131+
"dim1_mxfp8_cutedsl_2d_floor",
132+
"dim1_mxfp8_cutedsl_2d_rceil",
131133
)
132134

133135
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
@@ -474,17 +476,21 @@ def run(
474476
bps = (bytes_r + bytes_w) / (time_us / 1e6)
475477

476478
elif mode == "dim0_mxfp8_cutedsl_2d_floor":
477-
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_2d
479+
from torchao.prototype.moe_training.kernels.mxfp8 import (
480+
mxfp8_quantize_2d_1x32_cutedsl,
481+
)
478482

479-
y_d0, s_d0 = mxfp8_quantize_cuda_2d(
483+
y_d0, s_d0 = mxfp8_quantize_2d_1x32_cutedsl(
480484
x, block_size=BLOCK_SIZE, scaling_mode="floor"
481485
)
482486

483487
for _ in range(2):
484-
__ = mxfp8_quantize_cuda_2d(x, block_size=BLOCK_SIZE, scaling_mode="floor")
488+
__ = mxfp8_quantize_2d_1x32_cutedsl(
489+
x, block_size=BLOCK_SIZE, scaling_mode="floor"
490+
)
485491

486492
time_us = benchmark_cuda_function_in_microseconds(
487-
lambda x: mxfp8_quantize_cuda_2d(
493+
lambda x: mxfp8_quantize_2d_1x32_cutedsl(
488494
x, block_size=BLOCK_SIZE, scaling_mode="floor"
489495
),
490496
x,
@@ -498,17 +504,21 @@ def run(
498504
bps = (bytes_r + bytes_w) / (time_us / 1e6)
499505

500506
elif mode == "dim0_mxfp8_cutedsl_2d_rceil":
501-
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_2d
507+
from torchao.prototype.moe_training.kernels.mxfp8 import (
508+
mxfp8_quantize_2d_1x32_cutedsl,
509+
)
502510

503-
y_d0, s_d0 = mxfp8_quantize_cuda_2d(
511+
y_d0, s_d0 = mxfp8_quantize_2d_1x32_cutedsl(
504512
x, block_size=BLOCK_SIZE, scaling_mode="rceil"
505513
)
506514

507515
for _ in range(2):
508-
__ = mxfp8_quantize_cuda_2d(x, block_size=BLOCK_SIZE, scaling_mode="rceil")
516+
__ = mxfp8_quantize_2d_1x32_cutedsl(
517+
x, block_size=BLOCK_SIZE, scaling_mode="rceil"
518+
)
509519

510520
time_us = benchmark_cuda_function_in_microseconds(
511-
lambda x: mxfp8_quantize_cuda_2d(
521+
lambda x: mxfp8_quantize_2d_1x32_cutedsl(
512522
x, block_size=BLOCK_SIZE, scaling_mode="rceil"
513523
),
514524
x,
@@ -517,10 +527,76 @@ def run(
517527
assert y_d0.dtype == torch.float8_e4m3fn
518528
assert s_d0.dtype == torch.float8_e8m0fnu
519529

530+
bytes_r = x.numel() * bytes_per_el_bf16
531+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
532+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
533+
elif mode == "dim1_mxfp8_cutedsl_2d_floor":
534+
from torchao.prototype.moe_training.kernels.mxfp8 import (
535+
mxfp8_quantize_2d_32x1_cutedsl,
536+
)
537+
538+
y_d0, s_d0 = mxfp8_quantize_2d_32x1_cutedsl(
539+
x, block_size=BLOCK_SIZE, scaling_mode="floor", blocked_scale_output=True
540+
)
541+
542+
for _ in range(2):
543+
__ = mxfp8_quantize_2d_32x1_cutedsl(
544+
x,
545+
block_size=BLOCK_SIZE,
546+
scaling_mode="floor",
547+
blocked_scale_output=True,
548+
)
549+
550+
time_us = benchmark_cuda_function_in_microseconds(
551+
lambda x: mxfp8_quantize_2d_32x1_cutedsl(
552+
x,
553+
block_size=BLOCK_SIZE,
554+
scaling_mode="floor",
555+
blocked_scale_output=True,
556+
),
557+
x,
558+
)
559+
560+
assert y_d0.dtype == torch.float8_e4m3fn
561+
assert s_d0.dtype == torch.float8_e8m0fnu
562+
520563
bytes_r = x.numel() * bytes_per_el_bf16
521564
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
522565
bps = (bytes_r + bytes_w) / (time_us / 1e6)
523566

567+
elif mode == "dim1_mxfp8_cutedsl_2d_rceil":
568+
from torchao.prototype.moe_training.kernels.mxfp8 import (
569+
mxfp8_quantize_2d_32x1_cutedsl,
570+
)
571+
572+
y_d0, s_d0 = mxfp8_quantize_2d_32x1_cutedsl(
573+
x, block_size=BLOCK_SIZE, scaling_mode="rceil", blocked_scale_output=True
574+
)
575+
576+
for _ in range(2):
577+
__ = mxfp8_quantize_2d_32x1_cutedsl(
578+
x,
579+
block_size=BLOCK_SIZE,
580+
scaling_mode="rceil",
581+
blocked_scale_output=True,
582+
)
583+
584+
time_us = benchmark_cuda_function_in_microseconds(
585+
lambda x: mxfp8_quantize_2d_32x1_cutedsl(
586+
x,
587+
block_size=BLOCK_SIZE,
588+
scaling_mode="rceil",
589+
blocked_scale_output=True,
590+
),
591+
x,
592+
)
593+
594+
assert y_d0.dtype == torch.float8_e4m3fn
595+
assert s_d0.dtype == torch.float8_e8m0fnu
596+
597+
bytes_r = x.numel() * bytes_per_el_bf16
598+
bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
599+
bps = (bytes_r + bytes_w) / (time_us / 1e6)
524600
else:
525601
raise AssertionError(f"unknown mode {mode}")
526602

benchmarks/prototype/moe_training/mxfp8/bench_cutedsl_quantize_2d.py renamed to benchmarks/prototype/moe_training/mxfp8/bench_cutedsl_quantize_2d_1x32.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from torchao.prototype.moe_training.kernels.mxfp8 import (
1818
mx_block_rearrange_2d_M_groups_cuda,
1919
)
20-
from torchao.prototype.moe_training.kernels.mxfp8.cutedsl_quantize_2d import (
21-
mxfp8_quantize_cutedsl_2d,
20+
from torchao.prototype.moe_training.kernels.mxfp8.cutedsl_quantize_2d_1x32 import (
21+
mxfp8_quantize_cutedsl_2d_1x32,
2222
)
2323
from torchao.prototype.moe_training.utils import generate_jagged_offs
2424
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
@@ -99,14 +99,14 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
9999
)
100100

101101
# Benchmark 1: CuTeDSL kernel with blocked scale output
102-
data_cutedsl, scales_cutedsl = mxfp8_quantize_cutedsl_2d(
102+
data_cutedsl, scales_cutedsl = mxfp8_quantize_cutedsl_2d_1x32(
103103
input_tensor,
104104
block_size=block_size,
105105
scaling_mode=scaling_mode,
106106
blocked_scale_output=True,
107107
)
108108
cutedsl_blocked_time_us = benchmark_cuda_function_in_microseconds(
109-
mxfp8_quantize_cutedsl_2d,
109+
mxfp8_quantize_cutedsl_2d_1x32,
110110
input_tensor,
111111
block_size=block_size,
112112
scaling_mode=scaling_mode,
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
16+
from benchmarks.utils import benchmark_cuda_function_in_microseconds
17+
from torchao.prototype.moe_training.kernels.mxfp8.cutedsl_quantize_2d_32x1 import (
18+
mxfp8_quantize_cutedsl_2d_32x1,
19+
)
20+
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
21+
triton_mx_block_rearrange_2d_K_groups,
22+
)
23+
from torchao.prototype.moe_training.utils import generate_jagged_offs
24+
from torchao.prototype.mx_formats.kernels import mxfp8_quantize_cuda
25+
26+
# All benchmarks in this script target the current CUDA device.
device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000
30+
31+
32+
@dataclass(frozen=True)
class ExperimentConfig:
    """One benchmark configuration for the 32x1 mxfp8 quantization comparison."""

    # (M, K) shape of the bf16 input tensor to quantize.
    input_shape: tuple[int, int]
    # Scaling mode forwarded to the quantization kernels (e.g. "rceil").
    scaling_mode: str
    # Number of jagged groups along the K dimension.
    num_groups: int
37+
38+
39+
@dataclass(frozen=True)
class ExperimentResult:
    """Measured timings and derived memory bandwidth for one config."""

    # time (microseconds)
    cutedsl_blocked_us: float
    cuda_plus_rearrange_us: float
    # mem bw (GB/s, derived from the same read/write byte counts for both paths)
    cutedsl_blocked_gbps: float
    cuda_plus_rearrange_gbps: float
47+
48+
49+
@dataclass(frozen=True)
class Experiment:
    """Pairs a benchmark configuration with its measured result."""

    config: ExperimentConfig
    result: ExperimentResult
53+
54+
55+
def get_configs() -> List[ExperimentConfig]:
    """Return the full cartesian product of shapes, scaling modes and group counts."""
    # DeepSeekV3 671b shapes
    shapes = [
        (8192, 2048),
        (8192, 7168),
        (32768, 2048),
        (32768, 7168),
        (131072, 2048),
        (131072, 7168),
    ]
    modes = ["rceil"]
    group_counts = [4, 8]
    return [
        ExperimentConfig(
            input_shape=shape,
            scaling_mode=mode,
            num_groups=groups,
        )
        for shape, mode, groups in itertools.product(shapes, modes, group_counts)
    ]
79+
80+
81+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    """Benchmark CuTeDSL 32x1 blocked-scale quantization against the CUDA
    quantize + Triton scale-rearrange path for one configuration.

    Returns an ExperimentResult with wall times (us) and derived memory
    bandwidth (GB/s) for both paths.
    """
    block_size = 32
    input_shape = config.input_shape
    scaling_mode = config.scaling_mode
    num_groups = config.num_groups

    input_tensor = torch.randn(
        *input_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    # M is unused below; only K feeds the jagged group offsets.
    M, K = input_shape

    # Generate jagged offsets with multiples of 128 for K dimension
    # TODO: we use multiple of 128 here to avoid per-group padding requirement in blocked scales layout, which cutedsl doesn't support yet.
    group_end_offsets = generate_jagged_offs(
        num_groups, K, multiple_of=128, device=device
    )

    # Benchmark 1: CuTeDSL kernel with blocked scale output
    # First call warms up / compiles the kernel before timing.
    data_cutedsl, scales_cutedsl = mxfp8_quantize_cutedsl_2d_32x1(
        input_tensor,
        block_size=block_size,
        scaling_mode=scaling_mode,
        blocked_scale_output=True,
    )
    cutedsl_blocked_time_us = benchmark_cuda_function_in_microseconds(
        mxfp8_quantize_cutedsl_2d_32x1,
        input_tensor,
        block_size=block_size,
        scaling_mode=scaling_mode,
        blocked_scale_output=True,
    )

    # Benchmark 2: CUDA quantization + CUDA scale rearrangement
    def cuda_plus_rearrange(x, group_offs):
        # Quantize with 32x1 (column-wise) scaling: rowwise=False, colwise=True
        _, output_colwise, _, scales_colwise = mxfp8_quantize_cuda(
            x,
            rowwise=False,
            colwise=True,
            scaling_mode=scaling_mode,
        )
        # Convert scales to blocked layout for K groups
        # (offsets are divided by 32 to index scale rows, one per 32-element block)
        scales_blocked = triton_mx_block_rearrange_2d_K_groups(
            scales_colwise.view(torch.uint8), group_offs // 32
        )
        return output_colwise, scales_blocked

    # Warmup call; outputs are unused beyond triggering compilation/caching.
    data_cuda, scales_cuda = cuda_plus_rearrange(input_tensor, group_end_offsets)
    cuda_plus_rearrange_time_us = benchmark_cuda_function_in_microseconds(
        cuda_plus_rearrange,
        input_tensor,
        group_end_offsets,
    )

    # Memory bandwidth calculations
    bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
    bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

    read_bytes = input_tensor.numel() * bytes_per_input_el
    # NOTE(review): the CuTeDSL output sizes are reused for the CUDA path's
    # bandwidth number as well — assumes both paths write the same totals.
    write_bytes = (
        data_cutedsl.numel() * bytes_per_output_el
        + scales_cutedsl.numel() * bytes_per_scale_el
    )

    cutedsl_blocked_gbps = ((read_bytes + write_bytes) / 1e9) / (
        cutedsl_blocked_time_us / 1e6
    )
    cuda_plus_rearrange_gbps = ((read_bytes + write_bytes) / 1e9) / (
        cuda_plus_rearrange_time_us / 1e6
    )

    return ExperimentResult(
        cutedsl_blocked_us=cutedsl_blocked_time_us,
        cuda_plus_rearrange_us=cuda_plus_rearrange_time_us,
        cutedsl_blocked_gbps=cutedsl_blocked_gbps,
        cuda_plus_rearrange_gbps=cuda_plus_rearrange_gbps,
    )
162+
163+
164+
def print_results(experiments: List[Experiment]):
    """Print a tabulated comparison of CuTeDSL vs CUDA+rearrange results."""
    headers = [
        "input_shape",
        "scaling_mode",
        "num_groups",
        "cutedsl_blocked_us",
        "cuda+rearrange_us",
        "speedup",
        "cutedsl_gbps",
        "cuda+rearrange_gbps",
    ]
    rows = []
    for exp in experiments:
        cfg = exp.config
        res = exp.result
        # Speedup > 1 means the CuTeDSL path is faster.
        ratio = res.cuda_plus_rearrange_us / res.cutedsl_blocked_us
        row = [
            str(cfg.input_shape),
            cfg.scaling_mode,
            cfg.num_groups,
            f"{res.cutedsl_blocked_us:.2f}",
            f"{res.cuda_plus_rearrange_us:.2f}",
            f"{ratio:.2f}x",
            f"{res.cutedsl_blocked_gbps:.1f}",
            f"{res.cuda_plus_rearrange_gbps:.1f}",
        ]
        rows.append(row)
    print(tabulate(rows, headers=headers))
194+
195+
196+
def main():
    """Run every benchmark configuration and print the summary table."""
    torch.random.manual_seed(123)
    results = [
        Experiment(config=cfg, result=run_experiment(cfg))
        for cfg in tqdm(get_configs())
    ]

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)