
Commit 4cc0095

Update

[ghstack-poisoned]

1 parent c18ec8e commit 4cc0095

7 files changed: 328 additions & 204 deletions

benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py

Lines changed: 40 additions & 16 deletions

@@ -31,7 +31,7 @@
 class ExperimentConfig:
     input_shape: tuple[int]
     scaling_mode: ScaleCalculationMode
-    scale_block_k: int
+    variant: str


 @dataclass(frozen=True)
@@ -63,41 +63,60 @@ def get_configs() -> List[ExperimentConfig]:
         (32, 8192, 5120),
     ]
     round_modes = [ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL]
-    scale_block_ks = [1, 32]
+    variants = ["32x1_t", "32x1_n", "32x32_n"]
     configs = []
-    for shape, scaling_mode, scale_block_k in itertools.product(
-        input_shapes, round_modes, scale_block_ks
+    for shape, scaling_mode, variant in itertools.product(
+        input_shapes, round_modes, variants
     ):
         configs.append(
             ExperimentConfig(
                 input_shape=shape,
                 scaling_mode=scaling_mode,
-                scale_block_k=scale_block_k,
+                variant=variant,
             )
         )
     return configs


 def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     block_size = 32
-    scale_block_k = config.scale_block_k
+    variant = config.variant
     input_shape = config.input_shape
     input_tensor = torch.randn(
         *input_shape,
         dtype=torch.bfloat16,
         device=device,
     )

+    def get_quant_input(x: torch.Tensor) -> torch.Tensor:
+        # The "32x1_t" benchmark row is the reviewer-requested
+        # contract: feed (E, K, N) K-major expert weights directly into the
+        # existing 3D 32x1 kernel.
+        if variant == "32x1_t":
+            return x.transpose(-2, -1)
+        return x
+
     def using_to_mx(x: torch.Tensor) -> torch.Tensor:
-        if scale_block_k == 1:
+        if variant == "32x1_t":
+            x_t = x.transpose(-2, -1)
+            s_ref, y_ref = to_mx(
+                x_t.transpose(-2, -1).contiguous(),
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+                scaling_mode=config.scaling_mode,
+            )
+            return y_ref.transpose(-2, -1), s_ref.transpose(-2, -1)
+
+        if variant == "32x1_n":
             s_ref, y_ref = to_mx(
                 x.transpose(-2, -1).contiguous(),
                 elem_dtype=torch.float8_e4m3fn,
                 block_size=block_size,
+                scaling_mode=config.scaling_mode,
             )
             return y_ref.transpose(-2, -1), s_ref.transpose(-2, -1)

-        assert scale_block_k == 32
+        assert variant == "32x32_n"
         E, N, K = x.shape
         x_tiles = (
             x.view(E, N // block_size, block_size, K // block_size, block_size)
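Both reference branches above delegate to torchao's to_mx. For intuition on what a per-32-element MX scale is, here is a minimal pure-PyTorch sketch of FLOOR-mode scale selection, assuming the OCP MX convention that the shared power-of-two exponent is floor(log2(amax)) minus the e4m3 element emax of 8 (RCEIL instead rounds up from the ratio amax / max_e4m3). This is an illustration only, not torchao's implementation, and mx_scale_floor is a hypothetical helper name.

    import torch

    def mx_scale_floor(block: torch.Tensor) -> torch.Tensor:
        # One shared power-of-two (e8m0-style) scale per block:
        # exponent = floor(log2(amax)) - emax(e4m3), with emax(e4m3) = 8.
        # Assumes a nonzero block; real implementations special-case amax == 0.
        amax = block.abs().amax().float()
        return torch.exp2(torch.floor(torch.log2(amax)) - 8)

    block = torch.randn(32)
    s = mx_scale_floor(block)
    q = (block / s).to(torch.float8_e4m3fn)  # all 32 elements share the scale s
    recon = q.to(torch.float32) * s          # dequantize to inspect rounding error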
@@ -109,6 +128,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
             x_tiles,
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size * block_size,
+            scaling_mode=config.scaling_mode,
         )
         y_ref = (
             y_tiles_ref.view(
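The 32x32_n reference flattens each 32x32 tile into one 1024-element block so to_mx assigns a single scale per tile. A standalone sketch of that tiling with small illustrative shapes (the benchmark's exact permute/reshape continues past the hunk above, so the middle steps here are an assumption):

    import torch

    E, N, K, b = 2, 64, 96, 32
    x = torch.randn(E, N, K)
    # (E, N, K) -> (E, N//b, b, K//b, b) -> (E, N//b, K//b, b, b): group rows
    # and columns into b-by-b tiles, then flatten each tile into one block.
    x_tiles = (
        x.view(E, N // b, b, K // b, b)
        .permute(0, 1, 3, 2, 4)
        .reshape(E, N // b, K // b, b * b)
    )
    assert x_tiles.shape == (2, 2, 3, 1024)
    # One amax, hence one shared scale, per 32x32 tile:
    tile_amax = x_tiles.abs().amax(dim=-1)
    assert tile_amax.shape == (2, 2, 3)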
@@ -129,7 +149,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
         input_tensor,
     )

-    if scale_block_k == 1:
+    if variant == "32x1_n":
         # bench 2d dim1 kernel then transforming to col major
         using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
         using_cuda_2d_c(input_tensor)
@@ -142,19 +162,23 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
     else:
         time_cuda_2d_us = float("nan")

+    quant_input = get_quant_input(input_tensor)
+    scale_block_n = block_size
+    scale_block_k = 1 if variant in ("32x1_t", "32x1_n") else block_size
+
     # bench 3d CuTeDSL kernel
     data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(
-        input_tensor,
+        quant_input,
         block_size=block_size,
-        scale_block_n=block_size,
+        scale_block_n=scale_block_n,
         scale_block_k=scale_block_k,
         scaling_mode=str(config.scaling_mode.value),
     )
     time_cutedsl_3d_us = benchmark_cuda_function_in_microseconds(
         mxfp8_quantize_cuda_3d,
-        input_tensor,
+        quant_input,
         block_size=block_size,
-        scale_block_n=block_size,
+        scale_block_n=scale_block_n,
         scale_block_k=scale_block_k,
         scaling_mode=str(config.scaling_mode.value),
     )
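For reference, timing helpers like benchmark_cuda_function_in_microseconds are conventionally built on CUDA events; a minimal sketch of that pattern (an assumption about the helper's approach, not its actual code):

    import torch

    def time_cuda_us(fn, *args, warmup: int = 3, iters: int = 10) -> float:
        for _ in range(warmup):
            fn(*args)  # exclude one-time costs (caching, autotuning)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn(*args)
        end.record()
        torch.cuda.synchronize()  # wait until both events have completed
        return start.elapsed_time(end) * 1e3 / iters  # elapsed_time() is in ms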
@@ -164,7 +188,7 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
     bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
     bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

-    read_bytes = input_tensor.numel() * bytes_per_input_el
+    read_bytes = quant_input.numel() * bytes_per_input_el
     write_bytes = (
         data_cuda_3d.numel() * bytes_per_output_el
         + scales_cuda_3d.numel() * bytes_per_scale_el
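These byte counts feed the achieved-bandwidth estimate (bytes moved over kernel time). A worked example for the (32, 8192, 5120) shape with the 32x1 scale layout; the kernel time below is made up purely to show the arithmetic:

    E, N, K, block = 32, 8192, 5120, 32
    read_bytes = E * N * K * 2                      # bf16 input: 2 bytes/element
    write_bytes = E * N * K + E * N * (K // block)  # fp8 data + e8m0 scales, 1 byte each
    total_gb = (read_bytes + write_bytes) / 1e9     # ~4.07 GB moved
    kernel_time_us = 950.0                          # hypothetical measurement
    gbps = total_gb / (kernel_time_us * 1e-6)       # ~4280 GB/s at that time
    print(f"{gbps:.0f} GB/s")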
@@ -189,7 +213,7 @@ def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
         "scaling_mode",
-        "scale_block_k",
+        "variant",
         "cuda_2d_us",
         "cutedsl_3d_us",
         "to_mx_us",
@@ -203,7 +227,7 @@ def print_results(experiments: List[Experiment]):
         [
             str(experiment.config.input_shape),
             str(experiment.config.scaling_mode),
-            str(experiment.config.scale_block_k),
+            experiment.config.variant,
             experiment.result.cuda_2d_us,
             experiment.result.cutedsl_3d_us,
             experiment.result.to_mx_us,

benchmarks/prototype/moe_training/mxfp8/roofline_unified.py

Lines changed: 25 additions & 35 deletions

@@ -21,7 +21,6 @@
 from torchao.prototype.moe_training.kernels.mxfp8 import (
     mx_block_rearrange_2d_M_groups_cuda,
     torch_to_blocked_2d_M_groups,
-    torch_to_blocked_per_group_3d,
     triton_mx_block_rearrange_2d_K_groups,
     triton_mx_block_rearrange_per_group_3d,
 )
@@ -238,33 +237,27 @@ def compute_mxfp8_2d_2d_gemm_time(self, N, M, K):
         return time_s

     def compute_mxfp8_fwd_bwd_time(self, M, K, N, G):
-        """Compute time for MXFP8 forward + backward pass including scale rearrangement overhead"""
+        """Compute time for MXFP8 forward + backward pass."""
         block_size = 32

         # Forward: (M, K) @ (G, K, N)^T -> (M, N) [2D-3D]
         fwd_quant_time = self.compute_mxfp8_fwd_quant_time(M, K, G, N)
         # Forward scale rearrangement:
         # - Input scales (M, K//32) -> M-groups rearrangement
-        # - Weight scales (G, N, K//32) -> 3D per-group rearrangement
+        # - Weight scales are emitted directly in blocked layout by the 3D kernel
         fwd_input_scale_rearrange_time = self.compute_rearrange_2d_M_groups_time(
             M, K // block_size
         )
-        fwd_weight_scale_rearrange_time = self.compute_rearrange_3d_per_group_time(
-            G, N, K // block_size
-        )
         fwd_gemm_time = self.compute_mxfp8_2d_3d_gemm_time(M, K, N)

         # Backward input: (M, N) @ (G, N, K) -> (M, K) [2D-3D]
         bwd_input_quant_time = self.compute_mxfp8_bwd_input_quant_time(M, K, G, N)
         # Backward input scale rearrangement:
         # - grad_output scales (M, N//32) -> M-groups rearrangement
-        # - Weight scales (G, K, N//32) -> 3D per-group rearrangement (transposed weight)
+        # - Weight scales are emitted directly in blocked layout by the 3D kernel
         bwd_input_grad_scale_rearrange_time = self.compute_rearrange_2d_M_groups_time(
             M, N // block_size
         )
-        bwd_input_weight_scale_rearrange_time = (
-            self.compute_rearrange_3d_per_group_time(G, K, N // block_size)
-        )
         bwd_input_gemm_time = self.compute_mxfp8_2d_3d_gemm_time(M, N, K)

         # Backward weight: (N, M) @ (M, K) -> G separate (N, K) [2D-2D]
@@ -283,11 +276,9 @@ def compute_mxfp8_fwd_bwd_time(self, M, K, N, G):
         total_time = (
             fwd_quant_time
             + fwd_input_scale_rearrange_time
-            + fwd_weight_scale_rearrange_time
             + fwd_gemm_time
             + bwd_input_quant_time
             + bwd_input_grad_scale_rearrange_time
-            + bwd_input_weight_scale_rearrange_time
             + bwd_input_gemm_time
             + bwd_weight_quant_time
             + bwd_weight_grad_scale_rearrange_time
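Net effect of this pair of hunks: the roofline total no longer charges separate 3D weight-scale rearrangement in either direction. A toy version of the accounting with hypothetical stage times (the real method sums a few more terms than the hunks show):

    stage_times_s = {  # all values hypothetical
        "fwd_quant": 120e-6,
        "fwd_input_scale_rearrange": 15e-6,
        "fwd_gemm": 800e-6,
        "bwd_input_quant": 120e-6,
        "bwd_input_grad_scale_rearrange": 15e-6,
        "bwd_input_gemm": 800e-6,
        "bwd_weight_quant": 150e-6,
        "bwd_weight_grad_scale_rearrange": 20e-6,
    }
    # No fwd/bwd *_weight_scale_rearrange entries: that work now happens
    # inside the 3D quantizer itself.
    total_time_s = sum(stage_times_s.values())
    print(f"total: {total_time_s * 1e6:.0f} us")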
@@ -441,8 +432,6 @@ def benchmark_mxfp8_grouped_mm_fwd_bwd(x, w_t, offs, labels):
     x_clone = x.clone().requires_grad_(True)
     w_t_clone = w_t.clone().requires_grad_(True)

-    fn = torch.compile(_to_mxfp8_then_scaled_grouped_mm, fullgraph=True)
-
     # Set all parameters explicitly as variables for positional args
     A = x_clone
     B_t = w_t_clone
@@ -453,7 +442,7 @@ def benchmark_mxfp8_grouped_mm_fwd_bwd(x, w_t, offs, labels):
     scale_calculation_mode = MoEScaleCalculationMode.RCEIL

     def wrapper():
-        out = fn(
+        out = _to_mxfp8_then_scaled_grouped_mm(
             A,
             B_t,
             offs_arg,
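These two hunks switch the benchmark from a torch.compile-wrapped call to the eager function. If the compiled path were timed instead, the first call would have to be excluded because it pays one-time compilation cost; a generic sketch with a stand-in op (not the benchmark's real function or arguments):

    import torch

    def work(x: torch.Tensor) -> torch.Tensor:  # stand-in for the op under test
        return torch.nn.functional.gelu(x) * 2.0

    compiled = torch.compile(work, fullgraph=True)
    x = torch.randn(4096, 4096, device="cuda")
    compiled(x)                # warm-up call triggers (and caches) compilation
    torch.cuda.synchronize()
    # Only calls after this point reflect steady-state kernel time.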
@@ -492,7 +481,7 @@ def benchmark_to_mxfp8_dim1_cuda(tensor, block_size=32):


 def benchmark_mxfp8_quantize_cuda_3d(tensor, block_size=32):
-    """Benchmark mxfp8_quantize_cuda_3d kernel"""
+    """Benchmark the 3D 32x1 quantizer on its input tensor."""
     return benchmark_cuda_function_in_microseconds(
         lambda: mxfp8_quantize_cuda_3d(
             tensor,
@@ -715,7 +704,7 @@ def run(
     # 3. 3D Quantization Kernel Analysis
     # =============================================================================
     print("\n" + "=" * 80)
-    print("3D QUANTIZATION KERNELS (Backward Pass - Weight Quantization)")
+    print("3D QUANTIZATION KERNELS (Direct Transposed-Weight Quantization)")
     print("=" * 80)

     quant_3d_results = []
quant_3d_results = []
@@ -741,8 +730,10 @@ def run(
741730

742731
print(f"\nBenchmarking {desc}...")
743732

744-
# Create test tensor
745-
tensor = torch.randn(G_val, N_val, K_val, dtype=torch.bfloat16, device="cuda")
733+
# Benchmark the direct grouped-GEMM weight contract: w_t has shape
734+
# (G, K, N), and the existing 3D 32x1 kernel quantizes it directly.
735+
weight = torch.randn(G_val, N_val, K_val, dtype=torch.bfloat16, device="cuda")
736+
tensor = weight.transpose(-2, -1)
746737

747738
# Benchmark mxfp8_quantize_cuda_3d
748739
cuda_3d_time_us = benchmark_mxfp8_quantize_cuda_3d(tensor)
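The transpose above is metadata-only: the kernel receives the same contiguous (G, N, K) buffer viewed with K-major strides, so the benchmark includes no extra copy. A quick verification:

    import torch

    G, N, K = 8, 8192, 5120
    weight = torch.randn(G, N, K, dtype=torch.bfloat16, device="cuda")
    w_t = weight.transpose(-2, -1)              # shape (G, K, N)
    assert w_t.data_ptr() == weight.data_ptr()  # same storage, zero copy
    assert w_t.stride() == (N * K, 1, K)        # unit stride along K (dim 1)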
@@ -1030,21 +1021,25 @@ def run(
             f" BF16 Grouped GEMM: Roofline={model.bf16_tflops:.1f} TFLOPS, Actual={bf16_actual_tflops:.1f} TFLOPS, Efficiency={result_dict['bf16_tflops_efficiency_pct']:.1f}%"
         )

-        # Convert to MXFP8 format using triton_to_mxfp8_dim0 (blocks along dim0)
+        # Convert activations to MXFP8 format using triton_to_mxfp8_dim0
         x_fp8, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=32)
-        w_fp8, w_scales = triton_to_mxfp8_dim0(
-            w_t.transpose(-2, -1), inner_block_size=32
+        w_fp8, w_scales_blocked = mxfp8_quantize_cuda_3d(
+            w_t,
+            block_size=32,
+            scale_block_n=32,
+            scale_block_k=1,
+            scaling_mode="rceil",
         )

-        # Convert scales to blocked format
+        # Convert only activation scales to blocked format. Weight scales are
+        # already produced in blocked layout by mxfp8_quantize_cuda_3d.
         x_scales_blocked, _ = torch_to_blocked_2d_M_groups(
             x_scales, offs, block_size=32
         )
-        w_scales_blocked = torch_to_blocked_per_group_3d(w_scales)

         # Benchmark the MXFP8 grouped GEMM kernel
         mxfp8_gemm_time_us = benchmark_mxfp8_grouped_gemm(
-            x_fp8, w_fp8.transpose(-2, -1), x_scales_blocked, w_scales_blocked, offs
+            x_fp8, w_fp8, x_scales_blocked, w_scales_blocked, offs
         )

         # Calculate MXFP8 actual TFLOPS
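benchmark_mxfp8_grouped_gemm takes offs alongside the quantized operands. Assuming the usual grouped-GEMM convention in these benchmarks (not shown in this diff), offs holds the cumulative end row of each expert's token group along M; a sketch:

    import torch

    # Hypothetical: 4 experts with uneven token counts summing to M = 16384.
    group_sizes = torch.tensor([4096, 6144, 2048, 4096], device="cuda")
    offs = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)
    # offs == [4096, 10240, 12288, 16384]; group i spans rows
    # [offs[i - 1], offs[i]) of x, with an implicit 0 before the first group.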
@@ -1068,7 +1063,7 @@ def run(
         grouped_gemm_results.append(result_dict)

         # Clean up tensors to free GPU memory
-        del x, w, w_t, offs, x_fp8, x_scales, w_fp8, w_scales
+        del x, w, w_t, offs, x_fp8, x_scales, w_fp8
         del x_scales_blocked, w_scales_blocked
         torch.cuda.empty_cache()

@@ -1436,7 +1431,8 @@ def run(
     # Input quantization: use triton_to_mxfp8_dim0 for (M, K)
     fwd_input_quant_ms = df_quant_2d.loc[idx_large, "triton_to_mxfp8_dim0_us"] / 1000

-    # Weight quantization: use mxfp8_quantize_cuda_3d for (G, N, K)
+    # Weight quantization: use mxfp8_quantize_cuda_3d directly on w_t, shape
+    # (G, K, N), with no separate 3D scale rearrangement step.
     idx_3d_large = df_quant_3d[df_quant_3d["description"] == f"M={M_large}"].index[0]
     fwd_weight_quant_ms = (
         df_quant_3d.loc[idx_3d_large, "mxfp8_quantize_cuda_3d_us"] / 1000
@@ -1450,14 +1446,8 @@ def run(
         df_rearrange.loc[idx_m_groups, "mx_block_rearrange_2d_M_groups_cuda_us"] / 1000
     )

-    # Weight scale rearrangement: 3D per-group for (G, N, K//32)
-    idx_3d_rearrange = df_rearrange_3d[df_rearrange_3d["M"] == M_large].index[0]
-    fwd_weight_scale_rearrange_ms = (
-        df_rearrange_3d.loc[
-            idx_3d_rearrange, "triton_mx_block_rearrange_per_group_3d_us"
-        ]
-        / 1000
-    )
+    # Weight scales are emitted directly in blocked layout by the 3D quantizer.
+    fwd_weight_scale_rearrange_ms = 0.0

     # GEMM: use actual MXFP8 2D/3D grouped GEMM time
     idx_gemm = df_grouped_gemm[df_grouped_gemm["M"] == M_large].index[0]
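With the rearrangement rows gone, the forward weight path is charged a single kernel. Toy numbers (hypothetical milliseconds) to make the accounting concrete:

    fwd_input_quant_ms = 0.08            # hypothetical
    fwd_weight_quant_ms = 0.35           # hypothetical; emits blocked scales too
    fwd_input_scale_rearrange_ms = 0.02  # hypothetical
    fwd_weight_scale_rearrange_ms = 0.0  # folded into the 3D quantizer
    fwd_gemm_ms = 1.10                   # hypothetical
    fwd_total_ms = (
        fwd_input_quant_ms
        + fwd_weight_quant_ms
        + fwd_input_scale_rearrange_ms
        + fwd_weight_scale_rearrange_ms
        + fwd_gemm_ms
    )
    print(f"forward total: {fwd_total_ms:.2f} ms")  # 1.55 ms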
