
Commit 5e79e5e

[mxfp8 moe training] update readme (pytorch#4084)
1 parent c0da952 commit 5e79e5e

1 file changed

Lines changed: 31 additions & 6 deletions

torchao/prototype/moe_training/README.md

@@ -17,7 +17,7 @@
## Overview
This prototype provides:

-1. Quantized building block for low precision MoE training: [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L677). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
+1. Quantized building block for low precision MoE training: [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/main/torchao/prototype/moe_training/mxfp8_grouped_mm.py). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#_to_mxfp8_then_scaled_grouped_mm-usage) of a forward and backward pass below.
- Using MXFP8 on a B200 GPU, this provides:
  - **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes
  - **~1.19 - 1.6x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes

@@ -26,12 +26,12 @@ This prototype provides:

2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding the flag to your training command: `--model.converters="quantize.grouped_mm.mx" --quantize.grouped_mm.mx.fqns="experts"`

-3. Model conversion API to swap all `torch._grouped_mm` ops in your model definition to use torchao `_quantize_then_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example-end-to-end-training) below).
+3. Model conversion API to swap all `torch._grouped_mm` ops in your model definition to use torchao `_quantize_then_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example) below).


-## Case study training at scale: 1.2x e2e speedup with equivalent convergence versus to bf16
+## Case study training at scale: 1.3x e2e speedup with equivalent convergence versus bf16

-Training runs on 64 node GB200 cluster with TorchTitan Llama4 Scout demonstrated a 1.2x e2e training speedup with equivalent convergence to bfloat16 training baseline. Infact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/).
+Training runs on a 64-node GB200 cluster with TorchTitan Llama4 Scout demonstrated a [1.3x e2e training speedup with equivalent convergence to the bfloat16](https://pytorch.org/blog/mxfp8-training-for-moes-1-3x-training-speedup-vs-bf16-for-llama4-scout-on-gb200-cluster-using-torchao-and-torchtitan/) training baseline. In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/).

<img alt="Image" src="../../../docs/static/mxfp8_with_loss.png" />

@@ -59,7 +59,7 @@ Training and model configurations for this run:


## Examples
-#### _to_mxfp8_and_scaled_grouped_mm usage
+#### _to_mxfp8_then_scaled_grouped_mm usage
```python
import torch
from torch.nn import functional as F

@@ -95,6 +95,31 @@ loss = F.mse_loss(out, labels)
loss.backward()
```
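For orientation, the snippet below sketches the drop-in pattern this usage example demonstrates. It is illustrative only: the import path, the argument names (which mirror the `torch._grouped_mm` convention described above), and the shapes are assumptions, and the README's runnable example in this section remains the authoritative version.

```python
# Hedged sketch: import path, argument names, and shapes are assumptions.
import torch

from torchao.prototype.moe_training.mxfp8_grouped_mm import (  # assumed location, per the file linked above
    _to_mxfp8_then_scaled_grouped_mm,
)

device = "cuda"
num_experts, dim, hidden = 4, 512, 1024
# Cumulative end offsets of each expert's token group (int32). Group sizes are
# kept as multiples of 32 to line up with the MXFP8 block size (assumption).
offs = torch.tensor([256, 512, 768, 1024], dtype=torch.int32, device=device)

x = torch.randn(1024, dim, dtype=torch.bfloat16, device=device, requires_grad=True)
w = torch.randn(num_experts, dim, hidden, dtype=torch.bfloat16, device=device, requires_grad=True)

# Baseline bf16 grouped GEMM.
out_ref = torch._grouped_mm(x, w, offs=offs)

# Drop-in replacement: dynamically quantize inputs to MXFP8, run the scaled
# grouped GEMM, and return the result in the original precision.
out = _to_mxfp8_then_scaled_grouped_mm(x, w, offs=offs)
out.sum().backward()  # differentiable: gradients flow to x and w
```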

+#### Model conversion API example
+```python
+import torch
+from torch import nn
+from torchao.quantization.quant_api import quantize_
+from torchao.prototype.moe_training.config import (
+    MXFP8TrainingOpConfig,
+    MXFP8TrainingRecipe,
+)
+
+# Example: MoE model with experts that use torch._grouped_mm
+model = YourMoEModel().cuda().to(torch.bfloat16)
+
+# Create MXFP8 config from a recipe
+config = MXFP8TrainingOpConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL)
+
+# Use filter_fn to target only expert modules
+def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
+    return "experts" in cur_fqn
+
+# This swaps expert nn.Parameter data with MXFP8TrainingWeightWrapperTensor,
+# which overrides torch._grouped_mm to use MXFP8 quantized grouped GEMMs.
+quantize_(model, config=config, filter_fn=moe_module_filter_fn)
+```
+
## System requirements
- torchao 0.14+
- For MXFP8 MoE training, CUDA 12.8+ and SM100+ GPU arch are required.

@@ -104,7 +129,7 @@ loss.backward()
## Benchmarks

### Autograd function
-Forward + backward pass benchmarks for the [autograd function](https://github.com/pytorch/ao/blob/8bb433e989ad6f7ee0920f946d3a9be7f14be8c7/torchao/prototype/moe_training/scaled_grouped_mm.py#L284) powering MXFP8 MoE training.
+Forward + backward pass benchmarks for the [autograd function](https://github.com/pytorch/ao/blob/main/torchao/prototype/moe_training/mxfp8_grouped_mm.py) powering MXFP8 MoE training.
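The tables below are produced by the repo's benchmark scripts. As a rough illustration only, the following sketch shows how a forward + backward timing of a grouped GEMM call could be taken with CUDA events; the helper, shapes, and callable are assumptions, not the code that generated the published numbers.

```python
# Illustrative timing sketch only (not the repo's benchmark script).
import torch

def time_fwd_bwd_ms(step, warmup: int = 5, iters: int = 20) -> float:
    """Average milliseconds per forward + backward invocation of `step()`."""
    for _ in range(warmup):
        step().sum().backward()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        step().sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Example usage (hypothetical shapes): x is (tokens, dim) bf16 with
# requires_grad=True, w is (num_experts, dim, hidden) bf16 with
# requires_grad=True, and offs is an int32 tensor of cumulative group offsets.
# bf16_ms = time_fwd_bwd_ms(lambda: torch._grouped_mm(x, w, offs=offs))
```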

#### Llama4 Scout shapes