Skip to content

Commit 4ae435e

Browse files
Expand Triton autotune configs for MoE FP8 kernels to improve AMD GPU performance (#3952)
* Expand Triton autotune configs for MoE FP8 rowwise and jagged scales kernels The existing autotune configs for the MoE training FP8 kernels use a single configuration each (e.g., num_warps=4, num_stages=4, one block size), which prevents Triton's autotuner from finding better configs for different hardware targets. Expand the search space to cover: - Multiple num_warps values (4, 8) to better saturate both NVIDIA (warp size 32) and AMD (wavefront size 64) GPU compute units - Multiple num_stages values for software pipelining flexibility across different cache hierarchies - Multiple block sizes to adapt to varying matrix dimensions This is complementary to PR #3945 (relaxed atomics on AMDGPU) and targets the same kernels. * Gate expanded autotune configs to AMD only, preserve original NVIDIA configs H100 benchmarks showed ~18% regression on the atomic kernel with the expanded search space. The autotuner appears to pick suboptimal configs from the larger candidate set on NVIDIA. Gate the expanded configs behind torch.version.hip so AMD gets the broader search (4-7% faster on MI250X) while NVIDIA keeps the original tuned configs. * Widen autotune search space and add N_GROUPS to scales kernel autotuning key Two improvements based on MI300X (gfx942) benchmarking: 1. float8_rowwise.py: Widen block size search space for AMD GPUs. - Atomic configs: add BLOCK_SIZE_N=256 and BLOCK_SIZE_K=64 - Reduction configs: add BLOCK_SIZE_N=128, BLOCK_SIZE_K=64, and num_stages=2,4 - Yields 1.5-2.2x speedup on MI300X for the atomic kernel and 1.05-1.25x for the reduction kernel across Llama4 MoE shapes. 2. jagged_float8_scales.py: Add N_GROUPS to autotuning key for both rowwise and colwise scales kernels. The previous key (M or K only) caused the autotuner to cache a single config across all n_groups values, but optimal tile sizes differ significantly by n_groups. This eliminates cross-n_groups interference and allows each n_groups value to independently find its best config.
1 parent 9960da8 commit 4ae435e

2 files changed

Lines changed: 66 additions & 49 deletions

File tree

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,26 @@
3333
torch.float64: tl.float64,
3434
}
3535

36-
block_sizes_n = [128] # large dim (output_features)
37-
block_sizes_k = [128] # small dim (input_features)
38-
num_warps = [4]
39-
num_stages = [4]
40-
atomic_kernel_configs_2D = [
41-
triton.Config(
42-
{"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
43-
num_warps=warps,
44-
num_stages=stages,
45-
)
46-
for block_size_n in block_sizes_n
47-
for block_size_k in block_sizes_k
48-
for warps in num_warps
49-
for stages in num_stages
50-
]
36+
if torch.version.hip is not None:
    # AMD (ROCm/HIP): search a wider config space. Benchmarking on MI300X
    # showed 1.5-2.2x speedups for the atomic kernel across Llama4 MoE shapes.
    atomic_kernel_configs_2D = []
    for _bn in (64, 128, 256):  # BLOCK_SIZE_N candidates (output_features dim)
        for _bk in (64, 128):  # BLOCK_SIZE_K candidates (input_features dim)
            for _warps in (4, 8):  # 8 warps better saturate wavefront-64 CUs
                for _stages in (2, 4):  # pipelining depth
                    atomic_kernel_configs_2D.append(
                        triton.Config(
                            {"BLOCK_SIZE_N": _bn, "BLOCK_SIZE_K": _bk},
                            num_warps=_warps,
                            num_stages=_stages,
                        )
                    )
else:
    # NVIDIA: keep the single hand-tuned config. A wider search space
    # regressed H100 by ~18% (autotuner picked suboptimal candidates).
    atomic_kernel_configs_2D = [
        triton.Config(
            {"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
            num_warps=4,
            num_stages=4,
        )
    ]
5156

5257
@torch.library.custom_op(
5358
"torchao::triton_fp8_rowwise_transpose_rhs", mutates_args={}
@@ -283,23 +288,26 @@ def _triton_fp8_rowwise_3d_transpose_cast_rhs_kernel(
283288
output_mask = (n_offs[:, None] < N) & (k_offs[None, :] < K)
284289
tl.store(output_ptr + output_offs, output_data, mask=output_mask)
285290

286-
block_sizes_n = [
287-
64,
288-
] # large dim (output_features)
289-
block_sizes_k = [128] # small dim (input_features)
290-
num_warps = [8]
291-
num_stages = [6]
292-
reduction_kernel_configs_2D = [
293-
triton.Config(
294-
{"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
295-
num_warps=warps,
296-
num_stages=stages,
297-
)
298-
for block_size_n in block_sizes_n
299-
for block_size_k in block_sizes_k
300-
for warps in num_warps
301-
for stages in num_stages
302-
]
291+
if torch.version.hip is not None:
    # AMD (ROCm/HIP): broader autotune search space for the reduction kernel.
    # Benchmarked 1.05-1.25x faster on MI300X across Llama4 MoE shapes.
    reduction_kernel_configs_2D = []
    for _bn in (32, 64, 128):  # BLOCK_SIZE_N candidates (output_features dim)
        for _bk in (64, 128):  # BLOCK_SIZE_K candidates (input_features dim)
            for _warps in (4, 8):
                for _stages in (2, 4, 6):
                    reduction_kernel_configs_2D.append(
                        triton.Config(
                            {"BLOCK_SIZE_N": _bn, "BLOCK_SIZE_K": _bk},
                            num_warps=_warps,
                            num_stages=_stages,
                        )
                    )
else:
    # NVIDIA: preserve the original tuned config to avoid the regressions
    # observed on H100 with a larger candidate set.
    reduction_kernel_configs_2D = [
        triton.Config(
            {"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128},
            num_warps=8,
            num_stages=6,
        )
    ]
303311

304312
@triton.autotune(configs=reduction_kernel_configs_2D, key=["K", "N"])
305313
@triton.jit

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,26 @@
3838
torch.float64: tl.float64,
3939
}
4040

41-
block_sizes = [32] # [16, 32, 64]
42-
block_sizes_iter = [128] # [64, 128, 256]
43-
num_warps = [4]
44-
num_stages = [3]
45-
kernel_configs_2D = [
46-
triton.Config(
47-
{"BLOCK_SIZE": block_size, "BLOCK_SIZE_ITER": block_size_iter},
48-
num_warps=warps,
49-
num_stages=stages,
50-
)
51-
for block_size in block_sizes
52-
for block_size_iter in block_sizes_iter
53-
for warps in num_warps
54-
for stages in num_stages
55-
]
41+
if torch.version.hip is not None:
    # AMD (ROCm/HIP): widened autotune search for the jagged-scales kernels.
    # num_warps=8 helps saturate wavefront-64 compute units.
    kernel_configs_2D = []
    for _bs in (32, 64):  # BLOCK_SIZE candidates
        for _bs_iter in (64, 128):  # BLOCK_SIZE_ITER candidates
            for _warps in (4, 8):
                for _stages in (2, 3):
                    kernel_configs_2D.append(
                        triton.Config(
                            {"BLOCK_SIZE": _bs, "BLOCK_SIZE_ITER": _bs_iter},
                            num_warps=_warps,
                            num_stages=_stages,
                        )
                    )
else:
    # NVIDIA: keep the original single tuned config.
    kernel_configs_2D = [
        triton.Config(
            {"BLOCK_SIZE": 32, "BLOCK_SIZE_ITER": 128},
            num_warps=4,
            num_stages=3,
        )
    ]
5661

5762
@torch.library.custom_op(
5863
"torchao::triton_fp8_per_group_rowwise_scales", mutates_args={}
@@ -108,6 +113,7 @@ def triton_fp8_per_group_rowwise_scales(
108113
scales_buffer,
109114
m,
110115
k,
116+
n_groups,
111117
hp_tensor.stride(0),
112118
hp_tensor.stride(1),
113119
output_buffer.stride(0),
@@ -147,7 +153,7 @@ def _fake_triton_fp8_per_group_rowwise_scales_kernel(
147153
# so the kernel is easily interpretable in a standalone fashion.
148154
# The tokens per expert will vary per iteration, so don't want
149155
# to recompile on `token` dim (K, in this case) changes.
150-
@triton.autotune(configs=kernel_configs_2D, key=["M"])
156+
@triton.autotune(configs=kernel_configs_2D, key=["M", "N_GROUPS"])
151157
@triton.jit
152158
def _triton_fp8_per_group_rowwise_scales_kernel(
153159
input_ptr,
@@ -156,6 +162,7 @@ def _triton_fp8_per_group_rowwise_scales_kernel(
156162
scales_ptr,
157163
M: int,
158164
K: int,
165+
N_GROUPS: int,
159166
stride_input_row: int,
160167
stride_input_col: int,
161168
stride_output_row: int,
@@ -299,6 +306,7 @@ def triton_fp8_per_group_colwise_scales(
299306
scales_buffer,
300307
k,
301308
n,
309+
n_groups,
302310
hp_tensor.stride(0),
303311
hp_tensor.stride(1),
304312
output_buffer.stride(0),
@@ -336,7 +344,7 @@ def _fake_triton_fp8_per_group_colwise_scales(
336344
# before the calculation `grad_B = grad_output_t @ input`.
337345
# The tokens per expert will vary per iteration, so don't want
338346
# to recompile on `token` dim (M) changes.
339-
@triton.autotune(configs=kernel_configs_2D, key=["K"])
347+
@triton.autotune(configs=kernel_configs_2D, key=["K", "N_GROUPS"])
340348
@triton.jit
341349
def _triton_fp8_per_group_colwise_scales_kernel(
342350
input_ptr,
@@ -345,6 +353,7 @@ def _triton_fp8_per_group_colwise_scales_kernel(
345353
scales_ptr,
346354
K: int,
347355
N: int,
356+
N_GROUPS: int,
348357
stride_input_row: int,
349358
stride_input_col: int,
350359
stride_output_row: int,

0 commit comments

Comments
 (0)