
Commit e654d74

Add pre-quantized activation support to MXFP8 grouped GEMM (_to_mxfp8_then_scaled_grouped_mm) (#3961)
* moe: support MXTensor helper for prequantized mxfp8 grouped mm
  # Conflicts:
  #   torchao/prototype/moe_training/mxfp8_grouped_mm.py
* lint: apply ruff format in mxfp8 grouped mm test
* mx: derive helper dtype from qdata
1 parent 95d366c commit e654d74
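
As context for review, here is a condensed, self-contained sketch of the workflow this commit enables, distilled from the added tests and README example. The shapes, the hand-built uniform `offs` (the tests use a jagged helper instead), and the minimal argument list are illustrative, not prescriptive:

```python
import torch

from torchao.prototype.moe_training import _to_mxfp8_then_scaled_grouped_mm
from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx

M, K, N, num_experts = 4096, 1024, 2048, 8
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_t = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda").transpose(-2, -1)
# Uniform group-end offsets for illustration; each a multiple of the MX block size.
offs = torch.arange(M // num_experts, M + 1, M // num_experts, device="cuda", dtype=torch.int32)

# Quantize activations to MXFP8 once, up front (block_size is always 32 for MXFP8)...
x_scale, x_qdata = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=32)
# ...then wrap the raw (qdata, scale) pair so the grouped GEMM can consume it directly.
x_mx = MXTensor.from_qdata_and_scales(x_qdata, x_scale, orig_dtype=x.dtype, block_size=32)

# MXTensor inputs currently require the `wgrad_with_hp` recipe (asserted in forward).
out = _to_mxfp8_then_scaled_grouped_mm(x_mx, w_t, offs, wgrad_with_hp=True)
```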

5 files changed: 350 additions & 9 deletions


test/prototype/moe_training/test_mxfp8_grouped_mm.py

Lines changed: 174 additions & 1 deletion
@@ -40,7 +40,7 @@
     _to_mxfp8_per_group_rowwise,
     generate_jagged_offs,
 )
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.testing.utils import skip_if_rocm

@@ -225,3 +225,176 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
     assert sqnr >= min_weight_grad_sqnr, (
         f"Weight grad sqnr {sqnr} is too low, must be >= {min_weight_grad_sqnr}"
     )
+
+
+@skip_if_rocm("ROCm not supported")
+def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
+    block_size = 32
+    M, K, N, num_experts = 4096, 1024, 2048, 8
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    w = torch.randn(
+        num_experts,
+        N,
+        K,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    w_t = w.transpose(-2, -1).requires_grad_(True)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
+
+    x_ref = x.detach().clone().requires_grad_(True)
+    w_t_ref = w_t.detach().clone().requires_grad_(True)
+
+    x_scale, x_qdata = to_mx(
+        x.detach(),
+        elem_dtype=torch.float8_e4m3fn,
+        block_size=block_size,
+        scaling_mode=ScaleCalculationMode.RCEIL,
+    )
+    x_mx = MXTensor.from_qdata_and_scales(
+        x_qdata,
+        x_scale,
+        orig_dtype=x.dtype,
+        block_size=block_size,
+        is_swizzled_scales=False,
+    )
+    out = _to_mxfp8_then_scaled_grouped_mm(
+        x_mx,
+        w_t,
+        offs=offs,
+        block_size=block_size,
+        out_dtype=torch.bfloat16,
+        kernel_preference=KernelPreference.EMULATED,
+        wgrad_with_hp=True,
+        scale_calculation_mode=ScaleCalculationMode.RCEIL,
+    )
+    out_ref = _to_mxfp8_then_scaled_grouped_mm(
+        x_ref,
+        w_t_ref,
+        offs=offs,
+        block_size=block_size,
+        out_dtype=torch.bfloat16,
+        kernel_preference=KernelPreference.EMULATED,
+        wgrad_with_hp=True,
+        scale_calculation_mode=ScaleCalculationMode.RCEIL,
+    )
+
+    output_sqnr = compute_error(out_ref, out)
+    min_output_sqnr = 60.0
+    assert output_sqnr >= min_output_sqnr, (
+        f"Output sqnr {output_sqnr} is too low, must be >= {min_output_sqnr}"
+    )
+
+    labels = torch.ones_like(out_ref)
+    F.mse_loss(out_ref, labels).backward()
+    F.mse_loss(out, labels).backward()
+
+    assert x.grad is None, (
+        "MXTensor inputs are not connected back to the source HP tensor"
+    )
+
+    weight_grad_sqnr = compute_error(w_t_ref.grad, w_t.grad)
+    # MXTensor inputs dequantize for the `wgrad_with_hp` path, so the weight
+    # gradient is expected to be close to, but not identical to, the HP path.
+    min_weight_grad_sqnr = 30.0
+    assert weight_grad_sqnr >= min_weight_grad_sqnr, (
+        f"Weight grad sqnr {weight_grad_sqnr} is too low, must be >= {min_weight_grad_sqnr}"
+    )
+
+
+@skip_if_rocm("ROCm not supported")
+def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
+    block_size = 32
+    M, K, N, num_experts = 4096, 1024, 2048, 8
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    w = torch.randn(
+        num_experts,
+        N,
+        K,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    w_t = w.transpose(-2, -1)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
+
+    x_scale, x_qdata = to_mx(
+        x.detach(),
+        elem_dtype=torch.float8_e4m3fn,
+        block_size=block_size,
+        scaling_mode=ScaleCalculationMode.RCEIL,
+    )
+    x_mx = MXTensor.from_qdata_and_scales(
+        x_qdata,
+        x_scale,
+        orig_dtype=x.dtype,
+        block_size=block_size,
+        is_swizzled_scales=False,
+    )
+    out_mx = _to_mxfp8_then_scaled_grouped_mm(
+        x_mx,
+        w_t,
+        offs=offs,
+        block_size=block_size,
+        out_dtype=torch.bfloat16,
+        kernel_preference=KernelPreference.EMULATED,
+        wgrad_with_hp=True,
+        scale_calculation_mode=ScaleCalculationMode.RCEIL,
+    )
+    out_ref = _to_mxfp8_then_scaled_grouped_mm(
+        x,
+        w_t,
+        offs=offs,
+        block_size=block_size,
+        out_dtype=torch.bfloat16,
+        kernel_preference=KernelPreference.EMULATED,
+        wgrad_with_hp=True,
+        scale_calculation_mode=ScaleCalculationMode.RCEIL,
+    )
+
+    output_sqnr = compute_error(out_ref, out_mx)
+    min_output_sqnr = 60.0
+    assert output_sqnr >= min_output_sqnr, (
+        f"Output sqnr {output_sqnr} is too low, must be >= {min_output_sqnr}"
+    )
+
+
+@skip_if_rocm("ROCm not supported")
+def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp():
+    block_size = 32
+    M, K, N, num_experts = 1024, 1024, 2048, 4
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    w = torch.randn(
+        num_experts,
+        N,
+        K,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    w_t = w.transpose(-2, -1)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
+
+    x_scale, x_qdata = to_mx(
+        x,
+        elem_dtype=torch.float8_e4m3fn,
+        block_size=block_size,
+        scaling_mode=ScaleCalculationMode.RCEIL,
+    )
+    x_mx = MXTensor.from_qdata_and_scales(
+        x_qdata,
+        x_scale,
+        orig_dtype=x.dtype,
+        block_size=block_size,
+        is_swizzled_scales=False,
+    )
+
+    with pytest.raises(AssertionError, match="wgrad_with_hp"):
+        _to_mxfp8_then_scaled_grouped_mm(
+            x_mx,
+            w_t,
+            offs=offs,
+            block_size=block_size,
+            out_dtype=torch.bfloat16,
+            kernel_preference=KernelPreference.EMULATED,
+            wgrad_with_hp=False,
+            scale_calculation_mode=ScaleCalculationMode.RCEIL,
+        )
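
Note on the assertions above: `compute_error` reports SQNR in dB, so the 60.0/30.0 thresholds bound how far the quantized results may drift from the reference. A self-contained equivalent matching my understanding of torchao's helper (an assumption, shown only to make the thresholds concrete):

```python
import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: 20 * log10(||ref|| / ||ref - test||).
    # Higher is better; 60 dB means the error norm is ~1000x smaller than the signal norm.
    return 20 * torch.log10(torch.linalg.norm(ref) / torch.linalg.norm(ref - test))
```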

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 60 additions & 0 deletions
@@ -412,6 +412,66 @@ def test_block_sizes(elem_dtype, B):
     _test_mx(tensor_hp, elem_dtype, B)


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_from_qdata_and_scales_round_trip():
+    tensor_hp = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+    tensor_mx = MXTensor.to_mx(
+        tensor_hp,
+        torch.float8_e4m3fn,
+        32,
+        ScaleCalculationMode.RCEIL,
+    )
+    rebuilt = MXTensor.from_qdata_and_scales(
+        tensor_mx.qdata,
+        tensor_mx.scale,
+        orig_dtype=tensor_hp.dtype,
+        block_size=32,
+    )
+    torch.testing.assert_close(
+        rebuilt.dequantize(torch.float32),
+        tensor_mx.dequantize(torch.float32),
+    )
+    assert rebuilt.elem_dtype == tensor_mx.elem_dtype
+    assert rebuilt.block_size == tensor_mx.block_size
+    assert rebuilt.orig_dtype == tensor_mx.orig_dtype
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_from_qdata_and_scales_requires_float8_e8m0_scale_dtype():
+    tensor_hp = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+    tensor_mx = MXTensor.to_mx(
+        tensor_hp,
+        torch.float8_e4m3fn,
+        32,
+        ScaleCalculationMode.RCEIL,
+    )
+    with pytest.raises(AssertionError, match="scale.dtype"):
+        MXTensor.from_qdata_and_scales(
+            tensor_mx.qdata,
+            tensor_mx.scale.view(torch.uint8),
+            orig_dtype=tensor_hp.dtype,
+            block_size=32,
+        )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_from_qdata_and_scales_rejects_packed_uint8_qdata():
+    tensor_hp = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
+    tensor_mx = MXTensor.to_mx(
+        tensor_hp,
+        torch.float8_e4m3fn,
+        32,
+        ScaleCalculationMode.RCEIL,
+    )
+    with pytest.raises(AssertionError, match="typed MX qdata"):
+        MXTensor.from_qdata_and_scales(
+            torch.zeros_like(tensor_mx.qdata, dtype=torch.uint8),
+            tensor_mx.scale,
+            orig_dtype=tensor_hp.dtype,
+            block_size=32,
+        )
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_transpose(elem_dtype):
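
Background for the `scale.dtype` test: MX block scales are E8M0 values, i.e. pure powers of two with an 8-bit exponent and no mantissa, so viewing them as `uint8` loses the dtype tag the constructor checks for. A hedged sketch of the decoding, assuming the standard E8M0 bias of 127:

```python
import torch

def decode_e8m0(scale_bits: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper for illustration only; production code keeps scales in
    # torch.float8_e8m0fnu, which is exactly what from_qdata_and_scales asserts.
    assert scale_bits.dtype == torch.uint8
    return torch.exp2(scale_bits.to(torch.float32) - 127.0)

print(decode_e8m0(torch.tensor([126, 127, 128], dtype=torch.uint8)))  # tensor([0.5, 1., 2.])
```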

torchao/prototype/moe_training/README.md

Lines changed: 5 additions & 0 deletions
@@ -63,6 +63,7 @@ Training and model configurations for this run:
 ```python
 import torch
 from torch.nn import functional as F
+from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
 from torchao.prototype.moe_training import (
     _to_mxfp8_then_scaled_grouped_mm,
 )
@@ -83,6 +84,10 @@ out = _to_mxfp8_then_scaled_grouped_mm(
     B.transpose(-2, -1),
     offs,
 )
+# Optional: if you already have raw MXFP8 qdata/scales, wrap them as an MXTensor:
+# A_scale, A_qdata = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=32)
+# A_mx = MXTensor.from_qdata_and_scales(A_qdata, A_scale, orig_dtype=A.dtype)
+# out = _to_mxfp8_then_scaled_grouped_mm(A_mx, B.transpose(-2, -1), offs, wgrad_with_hp=True)

 # (Fake labels for demonstration purposes)
 labels = torch.ones_like(out)
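
One layout detail worth restating next to this example: the grouped GEMM docstring requires `B_t` in "per group column-major memory" layout, which is exactly what transposing a contiguous weight tensor produces. A quick sanity check, with illustrative shapes:

```python
import torch

G, K, N = 8, 1024, 2048
B = torch.randn(G, N, K, dtype=torch.bfloat16)  # contiguous (G, N, K) weights
B_t = B.transpose(-2, -1)                       # view with shape (G, K, N), no copy
assert B_t.shape == (G, K, N)
# Per-group column-major: within each group, consecutive elements along the
# K axis are contiguous in memory (stride 1 on dim 1).
assert B_t.stride(1) == 1
```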

torchao/prototype/moe_training/mxfp8_grouped_mm.py

Lines changed: 49 additions & 8 deletions
@@ -49,6 +49,33 @@
 )


+def _validate_grouped_mm_input_act(
+    input_act: torch.Tensor,
+    block_size: int,
+) -> None:
+    if not isinstance(input_act, MXTensor):
+        return
+
+    assert input_act.elem_dtype == torch.float8_e4m3fn, (
+        f"Expected MXTensor with elem_dtype float8_e4m3fn, but got {input_act.elem_dtype}"
+    )
+    assert input_act.block_size == block_size, (
+        f"Expected MXTensor block_size={block_size}, but got {input_act.block_size}"
+    )
+    assert not input_act.is_swizzled_scales, (
+        "MXTensor input scales must be unswizzled for grouped GEMM"
+    )
+    assert input_act.qdata.ndim == 2, "MXTensor input_act data must be 2D"
+    assert input_act.scale.ndim == 2, "MXTensor input_act scale must be 2D"
+    assert input_act.scale.shape == (
+        input_act.shape[0],
+        input_act.shape[1] // block_size,
+    ), (
+        "MXTensor input scales must be rowwise with shape "
+        f"({input_act.shape[0]}, {input_act.shape[1] // block_size})"
+    )
+
+
 # Aliases for convenience/clarity
 # @conditional_nostrict_trace
 def _to_mxfp8_then_scaled_grouped_mm(
@@ -66,9 +93,11 @@ def _to_mxfp8_then_scaled_grouped_mm(
     Differentiable mxfp8 grouped gemm with dynamic mxfp8 quantization.

     Args:
-        A (bf16/float32 torch.Tensor): The first high-precision input tensor,
-            which must be a 2D tensor of shape (M * num_groups, K)
-            and in row-major memory layout.
+        A (torch.Tensor): Input activations. May be a high-precision 2D tensor of
+            shape (M * num_groups, K) in row-major memory layout, or an `MXTensor`
+            carrying pre-quantized MXFP8 activations. If you already have raw
+            `(qdata, scale)` tensors, wrap them first with
+            `MXTensor.from_qdata_and_scales(...)`.
         B_t (bf16/float32 torch.Tensor): The second high-precision input tensor
             which must be 3D, which must be shape (G, K, N)
             and in "per group column-major memory" layout (i.e., strides of (N*K, 1, N)).
@@ -85,6 +114,7 @@ def _to_mxfp8_then_scaled_grouped_mm(
     """
     # block_size is always 32 for MXFP8
     block_size = 32
+    _validate_grouped_mm_input_act(A, block_size)
     return _MXFP8GroupedMM.apply(
         A,
         B_t,
@@ -103,8 +133,8 @@ class _MXFP8GroupedMM(torch.autograd.Function):
     Differentiable implementation of grouped GEMM with dynamic MXFP8 quantization.

     This autograd function performs grouped matrix multiplication with MXFP8 quantization
-    for efficient MoE training. It supports both pre-quantized (MXTensor) and high-precision
-    inputs, with configurable quantization and layout conversion options.
+    for efficient MoE training. It supports both pre-quantized (`MXTensor`) and
+    high-precision inputs, with configurable quantization and layout conversion options.
     """

     @staticmethod
@@ -161,7 +191,9 @@ def forward(
         ), "out_dtype must be bfloat16 or float32"
         if isinstance(input_act, MXTensor):
             assert wgrad_with_hp, (
-                "only `wgrad_with_hp` recipe is supported for pre-quantized inputs, support for other recipes is still in progress"
+                "only `wgrad_with_hp` recipe is supported for MXTensor inputs because "
+                "backward needs the high-precision activations to quantize along dim1 "
+                "for weight gradients"
             )

         # Save original group_end_offsets and num_tokens before padding
@@ -347,8 +379,17 @@ def backward(ctx, grad_output: torch.Tensor):
             wgrad_with_hp,
             kernel_preference,
         )
-
-        return grad_input, grad_weight_t, None, None, None, None, None, None, None
+        return (
+            grad_input,
+            grad_weight_t,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 def _compute_dgrad(
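
To make the new `_validate_grouped_mm_input_act` contract concrete, here is a minimal standalone illustration of the invariants it asserts for MXTensor activations (shapes are hypothetical):

```python
import torch

M, K, block_size = 4096, 1024, 32
# Typed MX qdata (float8_e4m3fn), never packed uint8 bytes...
qdata = torch.zeros(M, K, dtype=torch.float8_e4m3fn)
# ...and one rowwise, unswizzled E8M0 scale per 32-element block along K.
scale = torch.zeros(M, K // block_size, dtype=torch.float8_e8m0fnu)

assert qdata.ndim == 2 and scale.ndim == 2
assert scale.shape == (qdata.shape[0], qdata.shape[1] // block_size)
```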
