
Commit ba33d18

Author: shekhar.pandey@amd.com
[ROCm] MXFP8 MoE: persistent grouped kernel, F.scaled_mm dense, correctness + tests
- Persistent grouped-MM kernel (grid = num_CUs * ctas_per_cu): walks experts in-kernel via a global tile counter. This avoids the silent row-dropping that per-expert (M + E - 1) // E tile bounds can cause on jagged groups, and keeps the dispatcher torch.compile-clean.
- Dense MXFP8 path: dispatch to F.scaled_mm with BlockWise1x32 scaling.
- Wgrad: retune the default tile to (BN=256, BK=256, BM=64, nw=8).
- K-tail and scale-tail masking; m_mask is bounded by both group_end and the global M.
- torch.compile: register the pad/unpad helpers as torch.library.custom_op (with fake impls); skip nonstrict_trace on ROCm.
- mx_linear / MXFP8TrainingOpConfig: drop the is_ROCM() auto-switch; expose mxfp8_dim1_cast_kernel_choice as an explicit argument (CUDA default).
- bench_2d_3d_grouped_gemm.py: run on MI350+ via bench_mxfp8_grouped_mm_rocm; fix the FLOPs formula to 2 * M * N * K (M is already the total token count across all experts, so the old factor of E double-counted the work).

Tested on MI355X / gfx950 / ROCm 7.1 / Triton 3.7:
- Accuracy: test/prototype/moe_training/test_mxfp8_grouped_mm.py -> 129 passed, 16 skipped. SQNR margins: out >= 27.6 (threshold 27), in_grad >= 25.2 (threshold 25), w_grad >= 25.5 (threshold 24).
- Perf: benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py
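A host-side sketch of the tile walk described above (plain Python with hypothetical names: persistent_tile_walk, group_ends, BLOCK_M/BLOCK_N, num_programs; the real implementation is a Triton kernel, and the stride-by-grid claiming below stands in for its global atomic tile counter):

import itertools

def persistent_tile_walk(group_ends, N, BLOCK_M, BLOCK_N, num_programs):
    """Yield (program_id, expert, m_tile, n_tile) claims for a persistent grid
    of num_programs = num_CUs * ctas_per_cu programs. Every tile of every
    expert is visited exactly once, including ragged tail tiles, because the
    walk is driven by exact per-group tile counts rather than a uniform
    (M + E - 1) // E per-expert bound."""
    starts = [0] + list(group_ends[:-1])
    m_tiles = [(end - start + BLOCK_M - 1) // BLOCK_M
               for start, end in zip(starts, group_ends)]
    n_tiles = (N + BLOCK_N - 1) // BLOCK_N
    # Exclusive prefix sum of tiles per expert, for flat-id -> expert lookup.
    tile_starts, total = [], 0
    for mt in m_tiles:
        tile_starts.append(total)
        total += mt * n_tiles
    for pid in range(num_programs):
        tile = pid                       # first tile claimed by this program
        while tile < total:
            expert = max(e for e, ts in enumerate(tile_starts) if ts <= tile)
            local = tile - tile_starts[expert]
            yield pid, expert, local // n_tiles, local % n_tiles
            tile += num_programs         # stride to this program's next tile

# Sanity check: jagged groups of 100/150/50 rows are fully covered by 8 programs.
claims = list(persistent_tile_walk([100, 250, 300], N=256, BLOCK_M=64, BLOCK_N=128, num_programs=8))
assert len(claims) == len(set((c[1], c[2], c[3]) for c in claims))  # no tile claimed twice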
1 parent 42ff8ed commit ba33d18

13 files changed

Lines changed: 774 additions & 30 deletions


benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py

Lines changed: 35 additions & 7 deletions
@@ -23,6 +23,7 @@
 )
 from torchao.prototype.moe_training.utils import generate_jagged_offs
 from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.utils import is_MI350

 device = torch.device("cuda")

@@ -115,13 +116,17 @@ def run_experiment(
     fp8_rowwise_us = bench_fp8_rowwise_grouped_mm(A, B_t, offs)

     # benchmark mxfp8 grouped mm
-    if torch.cuda.get_device_capability() != (10, 0):
+    if torch.cuda.get_device_capability() == (10, 0):
+        mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)
+    elif is_MI350():
+        mxfp8_us = bench_mxfp8_grouped_mm_rocm(A, B_t, offs)
+    else:
         logging.warning(
-            f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
+            f"Skipping MXFP8 benchmarks, only supported on CUDA SM 10.0 or MI350+ "
+            f"(found device_capability={torch.cuda.get_device_capability()}, "
+            f"hip={torch.version.hip})"
         )
         mxfp8_us = float("inf")
-    else:
-        mxfp8_us = bench_mxfp8_grouped_mm(A, B_t, offs)

     return ExperimentResult(
         bf16_us=round(bf16_us, 3),
@@ -148,13 +153,12 @@ def print_results(experiments: List[Experiment]):
     rows = []
     for experiment in experiments:
         # calculate tflops
-        e, m, n, k = (
-            experiment.config.e,
+        m, n, k = (
             experiment.config.m,
             experiment.config.n,
             experiment.config.k,
         )
-        flops = 2 * e * m * n * k
+        flops = 2 * m * n * k
         bf16_tflops = (flops / 1e12) / (experiment.result.bf16_us / 1e6)
         fp8_rowwise_tflops = (flops / 1e12) / (experiment.result.fp8_rowwise_us / 1e6)
         mxfp8_tflops = (flops / 1e12) / (experiment.result.mxfp8_us / 1e6)
@@ -247,6 +251,30 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:
     return mxfp8_us


+def bench_mxfp8_grouped_mm_rocm(A, B_t, offs, block_size=32) -> float:
+    from torchao.prototype.moe_training.kernels.mxfp8.rocm_mxfp8_mm import (
+        triton_mxfp8_grouped_mm,
+    )
+
+    A_scales, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+    B_nkK = B_t.transpose(-2, -1).contiguous()
+    B_scales, B_fp8 = to_mx(B_nkK, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+
+    E = offs.shape[0]
+    Mg = A.shape[0]
+    offs_mxfp8 = generate_jagged_offs(E, Mg, multiple_of=block_size)
+
+    mxfp8_us = benchmark_cuda_function_in_microseconds(
+        triton_mxfp8_grouped_mm,
+        A_fp8,
+        B_fp8,
+        A_scales,
+        B_scales,
+        offs_mxfp8,
+    )
+    return mxfp8_us
+
+
 def main(args: argparse.Namespace):
     torch.random.manual_seed(123)
     configs = get_configs()
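The FLOPs fix above follows from the shapes the benchmark uses: A stacks every expert's tokens into one 2D matrix, so its row count already sums over experts. A small sanity check (shape values are illustrative, not the benchmark's defaults):

import torch

E, total_M, K, N = 8, 4096, 1024, 2048   # total_M = tokens across ALL experts
A = torch.randn(total_M, K)              # 2D activations, experts stacked
B_t = torch.randn(E, K, N)               # 3D per-expert weights

# Each of the total_M rows hits exactly one expert's (K, N) weight, so the
# grouped GEMM does 2 * total_M * N * K FLOPs; multiplying by E again, as the
# old formula did, overcounts the work E-fold.
flops = 2 * total_M * N * K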

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 6 additions & 2 deletions
@@ -249,9 +249,13 @@ def main(args: argparse.Namespace):
         elif config.recipe in (
             MXFP8TrainingRecipe.MXFP8_RCEIL,
             MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
-        ) and torch.cuda.get_device_capability() != (10, 0):
+        ) and not (
+            torch.cuda.get_device_capability() == (10, 0) or is_MI350()
+        ):
             logging.warning(
-                f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
+                f"Skipping MXFP8 benchmarks, only supported on CUDA SM 10.0 or MI350+ "
+                f"(found device_capability={torch.cuda.get_device_capability()}, "
+                f"hip={torch.version.hip})"
             )
             continue

test/prototype/moe_training/test_mxfp8_grouped_mm.py

Lines changed: 1 addition & 8 deletions
@@ -42,13 +42,11 @@
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
 from torchao.quantization.quantize_.common import KernelPreference
-from torchao.testing.utils import skip_if_rocm

 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 1000


-@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
@@ -80,7 +78,6 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
     assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"


-@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("M", (1024, 4096))
 @pytest.mark.parametrize("N", (1024, 4096))
 @pytest.mark.parametrize("num_experts", (8, 16))
@@ -128,7 +125,6 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
     assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"


-@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("M,K,N", [(32768, 5120, 8192), (16640, 7168, 2048)])
 @pytest.mark.parametrize("num_experts", (1, 8))
 @pytest.mark.parametrize("wgrad_with_hp", (True, False))
@@ -152,7 +148,7 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
     pad_token_groups_for_grouped_mm,
 ):
     # MXFP8 hardware path requires SM100
-    if kernel_preference != KernelPreference.EMULATED and not is_sm_version(10, 0):
+    if kernel_preference != KernelPreference.EMULATED and not (is_sm_version(10, 0) or is_MI350()):
         pytest.skip(
             f"Skipping MXFP8 hardware mode tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
         )
@@ -225,7 +221,6 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
     )


-@skip_if_rocm("ROCm not supported")
 def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
     block_size = 32
     M, K, N, num_experts = 4096, 1024, 2048, 8
@@ -298,7 +293,6 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
     )


-@skip_if_rocm("ROCm not supported")
 def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
     block_size = 32
     M, K, N, num_experts = 4096, 1024, 2048, 8
@@ -352,7 +346,6 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
     )


-@skip_if_rocm("ROCm not supported")
 def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp():
     block_size = 32
     M, K, N, num_experts = 1024, 1024, 2048, 4

torchao/prototype/moe_training/config.py

Lines changed: 14 additions & 1 deletion
@@ -12,7 +12,10 @@
 from torch import nn

 from torchao.core.config import AOBaseConfig
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.config import (
+    MXFP8Dim1CastKernelChoice,
+    ScaleCalculationMode,
+)
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.transform_module import register_quantize_module_handler
 from torchao.utils import is_MI300, register_as_pytree_constant
@@ -131,6 +134,13 @@ class MXFP8TrainingOpConfig(TrainingOpBaseConfig):
     # Whether to pad the token group sizes to multiples of 32 (MXFP8 scaling block size).
     pad_token_groups_for_grouped_mm: bool = False

+    # Kernel used for the MXFP8 dim1 cast in backward (wgrad path). Default is
+    # CUDA (best on CUDA SM100+). On backends without the CUDA kernel (e.g.
+    # ROCm), set to MXFP8Dim1CastKernelChoice.TRITON.
+    mxfp8_dim1_cast_kernel_choice: MXFP8Dim1CastKernelChoice = (
+        MXFP8Dim1CastKernelChoice.CUDA
+    )
+
     @classmethod
     def from_recipe(
         cls,
@@ -173,6 +183,8 @@ def __eq__(self, other):
             and self.scale_calculation_mode == other.scale_calculation_mode
             and self.pad_token_groups_for_grouped_mm
             == other.pad_token_groups_for_grouped_mm
+            and self.mxfp8_dim1_cast_kernel_choice
+            == other.mxfp8_dim1_cast_kernel_choice
         )
         return NotImplemented
@@ -184,6 +196,7 @@ def __hash__(self):
                 self.wgrad_with_hp,
                 self.scale_calculation_mode,
                 self.pad_token_groups_for_grouped_mm,
+                self.mxfp8_dim1_cast_kernel_choice,
             )
         )
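A usage sketch for the new field, assuming the config's other fields keep the defaults shown in this diff:

from torchao.prototype.moe_training.config import MXFP8TrainingOpConfig
from torchao.prototype.mx_formats.config import MXFP8Dim1CastKernelChoice

# On ROCm there is no CUDA dim1-cast kernel, so opt into the Triton one
# explicitly; on CUDA SM100+ the default (CUDA) is the fast path.
config = MXFP8TrainingOpConfig(
    mxfp8_dim1_cast_kernel_choice=MXFP8Dim1CastKernelChoice.TRITON,
)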

torchao/prototype/moe_training/kernels/mxfp8/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -15,3 +15,7 @@
     triton_mx_block_rearrange_2d_M_groups,  # noqa: F401
     triton_mx_block_rearrange_per_group_3d,  # noqa: F401
 )
+from torchao.prototype.moe_training.kernels.mxfp8.rocm_mxfp8_mm import (
+    triton_mxfp8_grouped_mm,  # noqa: F401
+    triton_mxfp8_wgrad,  # noqa: F401
+)

torchao/prototype/moe_training/kernels/mxfp8/quant.py

Lines changed: 34 additions & 0 deletions
@@ -260,6 +260,7 @@ def compute_blocked_scale_offsets_for_K_groups(
     return group_sizes, starting_col_after_padding


+@torch.library.custom_op("torchao::torch_pad_token_groups", mutates_args=())
 def torch_pad_token_groups(
     inputs: torch.Tensor, group_offsets: torch.Tensor, alignment_size: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -323,6 +324,27 @@ def torch_pad_token_groups(
     return padded_tokens, padded_start_offsets, padded_offsets


+@torch_pad_token_groups.register_fake
+def _torch_pad_token_groups_fake(
+    inputs: torch.Tensor, group_offsets: torch.Tensor, alignment_size: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    num_tokens, dim = inputs.shape
+    num_groups = group_offsets.shape[0]
+    output_rows = num_tokens + num_groups * alignment_size
+    output_rows = (
+        (output_rows + alignment_size - 1) // alignment_size
+    ) * alignment_size
+    padded_tokens = inputs.new_empty((output_rows, dim))
+    padded_group_start_offsets = torch.empty(
+        (num_groups,), dtype=torch.int32, device=inputs.device
+    )
+    padded_group_end_offsets = torch.empty(
+        (num_groups,), dtype=torch.int32, device=inputs.device
+    )
+    return padded_tokens, padded_group_start_offsets, padded_group_end_offsets
+
+
+@torch.library.custom_op("torchao::torch_unpad_token_groups", mutates_args=())
 def torch_unpad_token_groups(
     padded_inputs: torch.Tensor,
     group_offsets: torch.Tensor,
@@ -373,6 +395,18 @@ def torch_unpad_token_groups(
     return unpadded_tokens


+@torch_unpad_token_groups.register_fake
+def _torch_unpad_token_groups_fake(
+    padded_inputs: torch.Tensor,
+    group_offsets: torch.Tensor,
+    padded_group_start_offsets: torch.Tensor,
+    num_tokens: int,
+    alignment_size: int,
+) -> torch.Tensor:
+    dim = padded_inputs.shape[1]
+    return padded_inputs.new_empty((num_tokens, dim))
+
+
 if torch_version_at_least("2.7.0") and has_triton():
     import triton
     import triton.language as tl
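The register_fake implementations above are what let torch.compile trace through these ops: at trace time dynamo never runs the real kernel, it only asks the fake for output shapes and dtypes. A minimal self-contained illustration of the same pattern (hypothetical demo::pad_rows op, not the torchao one):

import torch

@torch.library.custom_op("demo::pad_rows", mutates_args=())
def pad_rows(x: torch.Tensor, multiple: int) -> torch.Tensor:
    rows = ((x.shape[0] + multiple - 1) // multiple) * multiple
    out = x.new_zeros((rows, x.shape[1]))   # real eager implementation
    out[: x.shape[0]] = x
    return out

@pad_rows.register_fake
def _(x: torch.Tensor, multiple: int) -> torch.Tensor:
    # Shape-only computation; runs under FakeTensor during tracing.
    rows = ((x.shape[0] + multiple - 1) // multiple) * multiple
    return x.new_empty((rows, x.shape[1]))

@torch.compile(fullgraph=True)
def f(x):
    return pad_rows(x, 32).sum()

print(f(torch.randn(100, 8)))  # traces via the fake, executes the real op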
