
Commit 7bb7f06

lizamd authored
Optimize FP8 colwise scales kernel for AMD GPUs in MoE backward pass (#3972)
Two interdependent changes that together yield a ~4.3x end-to-end training throughput improvement on MI300X for DeepSeek-MoE-16B:

1. Remove redundant .t().contiguous().t() memory copies before calling triton_fp8_per_group_colwise_scales in the backward pass. The kernel already handles arbitrary strides via its stride parameters, so these full-tensor copies are unnecessary.
2. Use larger Triton autotune configs (BLOCK_SIZE=128/256, BLOCK_SIZE_ITER=128/256) for the colwise scales kernel on AMD GPUs. With row-major input (from change 1), larger block sizes enable contiguous column access patterns, reducing the grid block count by 4-8x.

Benchmarked on 8x MI300X with DeepSeek-MoE-16B (EP=8, seq_len=4096):

- Batch size 1: 136 TPS -> 642 TPS (4.7x)
- Batch size 4: 500 TPS -> 2153 TPS (4.3x)

Co-authored-by: Li <lizli102@ctr2-alola-login-01.amd.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
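The first change leans on the fact that a stride-parameterized kernel addresses elements as base + row * stride_row + col * stride_col, so it can read a row-major or column-major tensor equally well. A minimal pure-Python sketch of that addressing scheme (names here are illustrative, not from the kernel):

```python
def load(buf, r, c, stride_r, stride_c):
    # Stride-aware addressing, the way Triton pointer arithmetic works:
    # flat offset = r * stride_r + c * stride_c.
    return buf[r * stride_r + c * stride_c]

rows, cols = 3, 4
vals = list(range(rows * cols))
row_major = vals  # element (r, c) lives at r * cols + c; strides (cols, 1)
col_major = [vals[r * cols + c] for c in range(cols) for r in range(rows)]  # strides (1, rows)

# The same (r, c) element is reachable in either layout just by passing the
# right strides, so no .t().contiguous().t() round-trip copy is needed.
for r in range(rows):
    for c in range(cols):
        assert load(row_major, r, c, cols, 1) == load(col_major, r, c, 1, rows)
```

This is exactly why the copies below could be dropped: the layout transformation changed only the strides, never the logical values.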
1 parent 4ae435e commit 7bb7f06

2 files changed: 16 additions & 14 deletions

torchao/prototype/moe_training/fp8_grouped_mm.py

Lines changed: 2 additions & 6 deletions

@@ -214,9 +214,7 @@ def backward(ctx, grad_output: torch.Tensor):
         # needed for grad_B: grad_output_t @ A
         # Use transpose method to avoid uncoalesced memory accesses.
         grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
-            grad_output.t()
-            .contiguous()
-            .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
+            grad_output,
             offs,
             float8_dtype,
             round_scales_to_power_of_2=True,
@@ -225,9 +223,7 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_output_t_scales = grad_out_scales.t()

         A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
-            A.t()
-            .contiguous()
-            .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
+            A,
             offs,
             float8_dtype,
             round_scales_to_power_of_2=True,
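Each removed .t().contiguous() materialized a full extra copy of the tensor before the kernel ran: one complete read plus one complete write of its data. A back-of-the-envelope sketch of that traffic, using an assumed shape and bf16 dtype (neither is stated in the commit), for one of the two call sites:

```python
# Hypothetical shape and dtype for illustration only (not from the commit).
M, N = 4096, 2048      # assumed grad_output shape
bytes_per_el = 2       # bf16

# t().contiguous() reads every element once and writes every element once.
copy_traffic = 2 * M * N * bytes_per_el
print(copy_traffic // 2**20, "MiB of extra memory traffic per call")  # prints 32
```

The backward pass did this twice per step (once for grad_output, once for A), so the saving compounds across calls in addition to the kernel-side speedup from the new autotune configs.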

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 14 additions & 8 deletions
@@ -41,14 +41,20 @@
 if torch.version.hip is not None:
     kernel_configs_2D = [
         triton.Config(
-            {"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
-            num_warps=warps,
-            num_stages=stages,
-        )
-        for block_size in [32, 64]
-        for block_size_iter in [64, 128]
-        for warps in [4, 8]
-        for stages in [2, 3]
+            {"BLOCK_SIZE": 128, "BLOCK_SIZE_ITER": 128},
+            num_warps=8,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_SIZE": 128, "BLOCK_SIZE_ITER": 256},
+            num_warps=8,
+            num_stages=2,
+        ),
+        triton.Config(
+            {"BLOCK_SIZE": 256, "BLOCK_SIZE_ITER": 128},
+            num_warps=8,
+            num_stages=2,
+        ),
     ]
 else:
     kernel_configs_2D = [
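The claimed 4-8x grid reduction falls out of ceiling-division grid sizing. As a sketch, assume a launch dimension tiled into ceil(n / BLOCK_SIZE) programs (the kernel's actual grid function is not shown in this diff); moving from the old minimum BLOCK_SIZE of 32 to the new 128/256 then gives:

```python
def cdiv(a, b):
    # Ceiling division, equivalent to triton.cdiv, used for grid sizing.
    return -(-a // b)

n = 4096  # hypothetical launch dimension, not taken from the commit
assert cdiv(n, 32) // cdiv(n, 128) == 4   # 128 programs shrink to 32
assert cdiv(n, 32) // cdiv(n, 256) == 8   # 128 programs shrink to 16
```

Fewer, larger programs mean each one covers a contiguous run of columns in the row-major input, which is what makes the larger blocks pay off only after the copies in change 1 are removed.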
