torchao/prototype/moe_training/README.md
## Overview
This prototype provides:

1. Quantized building block for low precision MoE training: [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/main/torchao/prototype/moe_training/mxfp8_grouped_mm.py). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#_to_mxfp8_then_scaled_grouped_mm-usage) of a forward and backward pass below.
   - Using MXFP8 on a B200 GPU, this provides:
     - **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes
     - **~1.19x - 1.6x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes

2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding the following flags to your training command: `--model.converters="quantize.grouped_mm.mx" --quantize.grouped_mm.mx.fqns="experts"`

3. Model conversion API to swap all `torch._grouped_mm` ops in your model definition to use torchao `_quantize_then_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example) below).


## Case study training at scale: 1.3x e2e speedup with equivalent convergence versus bf16

Training runs on a 64-node GB200 cluster with TorchTitan Llama4 Scout demonstrated a [1.3x e2e training speedup with equivalent convergence to the bfloat16 training baseline](https://pytorch.org/blog/mxfp8-training-for-moes-1-3x-training-speedup-vs-bf16-for-llama4-scout-on-gb200-cluster-using-torchao-and-torchtitan/). In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/).

<img alt="Image" src="../../../docs/static/mxfp8_with_loss.png" />



## Examples
#### _to_mxfp8_then_scaled_grouped_mm usage
```python
import torch
from torch.nn import functional as F
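
# Minimal usage sketch. Assumptions: the import path and exact call signature
# below are inferred from the module linked in the Overview, treating
# _to_mxfp8_then_scaled_grouped_mm as a drop-in for torch._grouped_mm
# (2D activations, 3D transposed expert weights, per-group end offsets via
# `offs`). Shapes are illustrative only.
from torchao.prototype.moe_training.mxfp8_grouped_mm import (
    _to_mxfp8_then_scaled_grouped_mm,
)

num_experts, tokens_per_expert, dim, hidden_dim = 8, 256, 5120, 8192
total_tokens = num_experts * tokens_per_expert

# Token activations for all experts, concatenated along dim 0.
x = torch.randn(total_tokens, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# Per-expert weights, transposed to (num_experts, dim, hidden_dim).
w_t = torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# End offset of each expert's token group along dim 0 of `x`.
offs = torch.arange(tokens_per_expert, total_tokens + 1, tokens_per_expert, dtype=torch.int32, device="cuda")

# Quantizes inputs to MXFP8, runs the scaled grouped GEMM, and returns the
# output in the original precision (bfloat16 here).
out = _to_mxfp8_then_scaled_grouped_mm(x, w_t, offs=offs)

labels = torch.ones_like(out)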
loss = F.mse_loss(out, labels)
loss.backward()
```

#### Model conversion API example
```python
import torch
from torch import nn
from torchao.quantization.quant_api import quantize_
from torchao.prototype.moe_training.config import (
    MXFP8TrainingOpConfig,
    MXFP8TrainingRecipe,
)

# Example: MoE model with experts that use torch._grouped_mm
model = YourMoEModel().cuda().to(torch.bfloat16)

# Create MXFP8 config from a recipe
config = MXFP8TrainingOpConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL)

# Use filter_fn to target only expert modules
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    return "experts" in cur_fqn

# This swaps expert nn.Parameter data with MXFP8TrainingWeightWrapperTensor,
# which overrides torch._grouped_mm to use MXFP8 quantized grouped GEMMs.
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
```
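
After conversion, the training loop itself is unchanged. A minimal sketch of a single training step with the converted model is shown below; `YourMoEModel`, the optimizer settings, and the input shape are illustrative placeholders rather than part of the torchao API.

```python
# Hypothetical training step with the converted model. torch._grouped_mm calls
# inside modules matched by the filter now dispatch to MXFP8 grouped GEMMs.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Illustrative input shape: (batch, seq_len, model_dim).
batch = torch.randn(4, 2048, 5120, dtype=torch.bfloat16, device="cuda")

out = model(batch)            # forward uses MXFP8 quantized grouped GEMMs
loss = out.float().mean()     # placeholder loss
loss.backward()               # backward also uses MXFP8 quantized grouped GEMMs
optimizer.step()
optimizer.zero_grad()
```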

## System requirements
- torchao 0.14+
- For MXFP8 MoE training, CUDA 12.8+ and SM100+ GPU arch are required.
## Benchmarks

### Autograd function
Forward + backward pass benchmarks for the [autograd function](https://github.com/pytorch/ao/blob/main/torchao/prototype/moe_training/mxfp8_grouped_mm.py) powering MXFP8 MoE training.
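
A minimal timing sketch using `torch.utils.benchmark` is shown below, under the same assumptions about the import path and call signature as the usage example above; the shapes and iteration settings are illustrative and are not the configurations used for the published numbers.

```python
import torch
from torch.utils.benchmark import Timer

# Assumed import path, as in the usage example above.
from torchao.prototype.moe_training.mxfp8_grouped_mm import (
    _to_mxfp8_then_scaled_grouped_mm,
)

# Illustrative shapes (not the exact Llama4 Scout / DeepSeekV3 configurations).
num_experts, tokens_per_expert, dim, hidden_dim = 8, 1024, 5120, 8192
total_tokens = num_experts * tokens_per_expert

x = torch.randn(total_tokens, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w_t = torch.randn(num_experts, dim, hidden_dim, dtype=torch.bfloat16, device="cuda", requires_grad=True)
offs = torch.arange(tokens_per_expert, total_tokens + 1, tokens_per_expert, dtype=torch.int32, device="cuda")

def fwd_bwd(fn):
    # One forward + backward pass through the (scaled) grouped GEMM.
    out = fn(x, w_t, offs=offs)
    out.sum().backward()

for name, fn in [("bf16", torch._grouped_mm), ("mxfp8", _to_mxfp8_then_scaled_grouped_mm)]:
    timer = Timer(stmt="fwd_bwd(fn)", globals={"fwd_bwd": fwd_bwd, "fn": fn})
    print(name, timer.blocked_autorange(min_run_time=2.0))
```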

#### Llama4 Scout shapes
