Commit ba33d18
shekhar.pandey@amd.com
[ROCm] MXFP8 MoE: persistent grouped kernel, F.scaled_mm dense, correctness + tests
- Persistent grouped-MM kernel (grid = num_CUs * ctas_per_cu); walks experts
  in-kernel with a global tile counter. Avoids silent row-dropping under
  (M+E-1)//E bounds and keeps the dispatcher torch.compile-clean.
- Dense MXFP8 path: dispatch to F.scaled_mm with BlockWise1x32.
- Wgrad: retune default tile to (BN=256, BK=256, BM=64, nw=8).
- K-tail and scale-tail masking; m_mask bounded by group_end and global M.
- torch.compile: register pad/unpad helpers as torch.library.custom_op;
  skip nonstrict_trace on ROCm.
- mx_linear / MXFP8TrainingOpConfig: drop is_ROCM() auto-switch; expose
  mxfp8_dim1_cast_kernel_choice as explicit arg (CUDA default).
- bench_2d_3d_grouped_gemm.py: run on MI350+ via bench_mxfp8_grouped_mm_rocm;
  fix flops formula = 2 * M * N * K.
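
The persistent-kernel bullet can be illustrated with a host-side sketch in plain Python. It is not the Triton kernel from this commit: the function names, the M-only tiling (N tiles are omitted), and the reading of "(M+E-1)//E bounds" as a per-expert row cap of ceil(M / E) are all assumptions made for illustration.

```python
from typing import List, Tuple


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def rows_covered_with_fixed_cap(group_sizes: List[int]) -> int:
    """Rows processed if every expert is capped at ceil(M / E) rows (assumed failure mode)."""
    total_rows, num_experts = sum(group_sizes), len(group_sizes)
    cap = ceil_div(total_rows, num_experts)
    return sum(min(g, cap) for g in group_sizes)


def tiles_for_program(
    pid: int, num_programs: int, group_sizes: List[int], block_m: int
) -> List[Tuple[int, int, int]]:
    """Tiles (expert, row_start, row_end) that one persistent program would visit.

    The program starts at global tile index `pid` and advances by `num_programs`,
    walking the experts in order to find which expert owns each tile.
    """
    tiles = []
    tile_base = 0    # global index of the current expert's first tile
    group_start = 0  # first row of the current expert
    next_tile = pid
    for expert, rows in enumerate(group_sizes):
        group_end = group_start + rows
        n_tiles = ceil_div(rows, block_m)
        while next_tile < tile_base + n_tiles:
            local = next_tile - tile_base
            row_start = group_start + local * block_m
            # Clamp to group_end so a tile never reads the next expert's rows.
            row_end = min(row_start + block_m, group_end)
            tiles.append((expert, row_start, row_end))
            next_tile += num_programs
        tile_base += n_tiles
        group_start = group_end
    return tiles


# Imbalanced routing: M = 32 rows over E = 4 experts, one of them empty.
group_sizes = [1, 7, 0, 24]
visited = [
    t for pid in range(3) for t in tiles_for_program(pid, 3, group_sizes, block_m=4)
]
rows_persistent = sum(end - start for _, start, end in visited)
rows_capped = rows_covered_with_fixed_cap(group_sizes)
```

With this routing, the capped scheme covers only 16 of the 32 rows because the largest expert exceeds ceil(32 / 4) = 8, while the persistent walk covers all 32 regardless of how the rows are distributed.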
Tested on MI355X / gfx950 / ROCm 7.1 / Triton 3.7:
  Accuracy: test/prototype/moe_training/test_mxfp8_grouped_mm.py
    -> 129 passed, 16 skipped.
  SQNR margins: out >= 27.6 (>= 27), in_grad >= 25.2 (>= 25),
    w_grad >= 25.5 (>= 24).
  Perf: benchmarks/prototype/moe_training/bench_2d_3d_grouped_gemm.py

1 parent 42ff8ed commit ba33d18
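
For readers unfamiliar with the SQNR figures quoted in the test summary, the following is a minimal sketch of how such a check could be computed. It assumes the values are in dB and uses the standard definition 10 * log10(signal power / noise power); the helper name `sqnr_db` and the dictionaries are hypothetical and are not taken from the test file.

```python
import math
from typing import Sequence


def sqnr_db(ref: Sequence[float], test: Sequence[float]) -> float:
    """Signal-to-quantization-noise ratio in dB between a reference and a test vector."""
    signal = sum(r * r for r in ref)
    noise = sum((r - t) ** 2 for r, t in zip(ref, test))
    if noise == 0.0:
        return math.inf  # identical vectors: no quantization noise
    return 10.0 * math.log10(signal / noise)


# Thresholds and minimum observed values as quoted in the commit message.
THRESHOLDS = {"out": 27.0, "in_grad": 25.0, "w_grad": 24.0}
MEASURED_MIN = {"out": 27.6, "in_grad": 25.2, "w_grad": 25.5}

# Headroom of each measured minimum over its threshold.
margins = {name: MEASURED_MIN[name] - THRESHOLDS[name] for name in THRESHOLDS}
all_pass = all(m >= 0.0 for m in margins.values())
```

On the quoted numbers the headroom is smallest for in_grad (about 0.2) and largest for w_grad (about 1.5).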
13 files changed
Lines changed: 774 additions & 30 deletions
File tree
- benchmarks/prototype/moe_training
- test/prototype/moe_training
- torchao/prototype
  - moe_training
    - kernels/mxfp8
  - mx_formats