Commit b8708a2

TrainingWeightWrapperTensor base class; subclasses for FP8/MXFP8 with grouped_mm and linear overrides (#3968)
* [mxfp8 training] unified tensor subclass for training
* [mxfp8 training] remove mxfp8 from MXLinear and MXLinearConfig
* [moe training] unified tensor subclass for training
* delete MXLinear and MXLinearConfig entirely
1 parent: 7bb7f06 · commit: b8708a2

33 files changed: 973 additions & 1640 deletions
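
Taken together, the diffs below amount to an API migration: MXLinear and MXLinearConfig are deleted, and training flows instead build a config from a recipe enum via the classes in torchao.prototype.moe_training.config. A minimal sketch of the new flow, assembled only from names that appear in this commit (not authoritative beyond what the diffs show):

```python
import torch

from torchao.prototype.moe_training.config import (
    MXFP8TrainingOpConfig,
    MXFP8TrainingRecipe,
)
from torchao.quantization import quantize_

# Build an op config from a recipe enum (replaces the removed
# MXLinearConfig.from_recipe_name(...) entry point).
config = MXFP8TrainingOpConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL)

# quantize_ swaps eligible weights for the training wrapper tensor subclass.
m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
quantize_(m, config)
m = torch.compile(m, fullgraph=True)
```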

.github/workflows/1xH100_tests.yml

Lines changed: 2 additions & 1 deletion
```diff
@@ -57,5 +57,6 @@ jobs:
     python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
     ./test/float8/test_everything_single_gpu.sh
     pytest test/prototype/mx_formats/ --verbose -s
-    pytest test/prototype/moe_training/test_scaled_grouped_mm.py --verbose -s
+    pytest test/prototype/moe_training/test_fp8_grouped_mm.py --verbose -s
+    pytest test/prototype/moe_training/test_mxfp8_grouped_mm.py --verbose -s
     pytest test/prototype/moe_training/test_training.py --verbose -s
```

benchmarks/float8/bench_matmul.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -42,7 +42,6 @@ def run(
     assert recipe in (
         "tensorwise",
         "rowwise",
-        "mxfp8_cublas",
         "mxfp4_cutlass",
         "nvfp4",
     ), "unsupported"
```

benchmarks/float8/float8_roofline.py

Lines changed: 14 additions & 6 deletions
```diff
@@ -61,7 +61,10 @@
     Float8LinearConfig,
     convert_to_float8_training,
 )
-from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.moe_training.config import (
+    MXFP8TrainingOpConfig,
+    MXFP8TrainingRecipe,
+)
 from torchao.quantization import quantize_
 from torchao.testing.training.roofline_utils import (
     get_float8_mem_sympy,
@@ -253,10 +256,7 @@ def run(
     print(f"enable_fusion_modeling: {enable_fusion_modeling}")

     assert mx_recipe_name in (
-        # real mxfp8_cublas recipe
-        "mxfp8_cublas",
-        # real mxfp8_cublas_rceil recipe
-        "mxfp8_cublas_rceil",
+        None,
         # modeling of what mxfp8 with 32x32 block size and without gemm
         # operand layout restrictions would look like
         "mxfp8_32x32_flexible_gemm_layout",
@@ -429,7 +429,15 @@ def run(
         )
     else:
         assert mx_recipe_name is not None
-        config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+        try:
+            config = MXFP8TrainingOpConfig.from_recipe(
+                MXFP8TrainingRecipe(mx_recipe_name)
+            )
+        except ValueError:
+            raise ValueError(
+                f"Unsupported mx_recipe_name: {mx_recipe_name}. "
+                f"Supported values: {[r.value for r in MXFP8TrainingRecipe]}"
+            )
     m_fp8_dyn = copy.deepcopy(m_orig)
     quantize_(m_fp8_dyn, config=config)
     m_fp8_dyn = torch.compile(m_fp8_dyn)
```

benchmarks/float8/profile_lowp_training.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -45,7 +45,10 @@
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
 )
-from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.moe_training.config import (
+    MXFP8TrainingOpConfig,
+    MXFP8TrainingRecipe,
+)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.quantization import quantize_
@@ -320,7 +323,15 @@ def main(
     elif float8_recipe_name is not None:
         config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
     elif mx_recipe_name is not None:
-        config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+        try:
+            config = MXFP8TrainingOpConfig.from_recipe(
+                MXFP8TrainingRecipe(mx_recipe_name)
+            )
+        except ValueError:
+            raise ValueError(
+                f"Unsupported mx_recipe_name: {mx_recipe_name}. "
+                f"Supported values: {[r.value for r in MXFP8TrainingRecipe]}"
+            )

     print(f"Compile is set to | {compile}")
     print(f"model_type is set to | {model_type}")
```

benchmarks/prototype/moe_training/bench_moe_layer.py

Lines changed: 13 additions & 11 deletions
```diff
@@ -16,9 +16,9 @@

 from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
 from torchao.prototype.moe_training.config import (
-    FP8GroupedMMRecipe,
-    MXFP8GroupedMMConfig,
-    MXFP8GroupedMMRecipe,
+    Float8TrainingRecipe,
+    MXFP8TrainingOpConfig,
+    MXFP8TrainingRecipe,
 )
 from torchao.quantization.quant_api import quantize_

@@ -58,15 +58,17 @@ def bench_moe_training_fsdp(args: argparse.Namespace):

     # Map recipe name to enum
     if recipe_name == "fp8_rowwise":
-        recipe = FP8GroupedMMRecipe.FP8_ROWWISE
+        recipe = Float8TrainingRecipe.FP8_ROWWISE
     elif recipe_name == "mxfp8_rceil":
-        recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL
+        recipe = MXFP8TrainingRecipe.MXFP8_RCEIL
     elif recipe_name == "mxfp8_rceil_wgrad_with_hp":
-        recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP
+        recipe = MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP
     else:
         raise ValueError(f"Unknown recipe: {recipe_name}")
+
+    # Check hardware requirements
     if (
-        recipe == FP8GroupedMMRecipe.FP8_ROWWISE
+        recipe == Float8TrainingRecipe.FP8_ROWWISE
         and torch.cuda.get_device_capability()
         != (
             9,
@@ -78,8 +80,8 @@ def bench_moe_training_fsdp(args: argparse.Namespace):
         )
         return

-    elif (
-        recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL
+    if (
+        recipe == MXFP8TrainingRecipe.MXFP8_RCEIL
         and torch.cuda.get_device_capability()
         != (
             10,
@@ -110,7 +112,7 @@ def bench_moe_training_fsdp(args: argparse.Namespace):
     model = copy.deepcopy(ref_model)

     # Token group alignment size must be 16 for fp8 rowwise training
-    alignment_size = 32 if recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL else 16
+    alignment_size = 32 if recipe == MXFP8TrainingRecipe.MXFP8_RCEIL else 16
     set_token_group_alignment_size_m(alignment_size)

     # assert starting params are identical for both models
@@ -125,7 +127,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False

     # quantize test model
-    config = MXFP8GroupedMMConfig.from_recipe(recipe)
+    config = MXFP8TrainingOpConfig.from_recipe(recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)

     # inputs
```

benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -25,9 +25,9 @@

 from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
 from torchao.prototype.moe_training.config import (
-    FP8GroupedMMRecipe,
-    MXFP8GroupedMMConfig,
-    MXFP8GroupedMMRecipe,
+    Float8TrainingRecipe,
+    MXFP8TrainingOpConfig,
+    MXFP8TrainingRecipe,
 )
 from torchao.quantization.quant_api import quantize_

@@ -48,15 +48,15 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile:
     assert recipe_name in ["fp8_rowwise", "mxfp8_rceil", "mxfp8_rceil_wgrad_with_hp"]
     # Map recipe names to enums
     if recipe_name.upper() == "fp8_rowwise":
-        recipe = FP8GroupedMMRecipe.FP8_ROWWISE
+        recipe = Float8TrainingRecipe.FP8_ROWWISE
     elif recipe_name.upper() == "mxfp8_rceil":
-        recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL
+        recipe = MXFP8TrainingRecipe.MXFP8_RCEIL
     elif recipe_name.upper() == "mxfp8_rceil_wgrad_with_hp":
-        recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP
+        recipe = MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP
     else:
         raise ValueError(f"Unknown recipe: {recipe_name}")
     if (
-        recipe == FP8GroupedMMRecipe.FP8_ROWWISE
+        recipe == Float8TrainingRecipe.FP8_ROWWISE
         and torch.cuda.get_device_capability()
         != (
             9,
@@ -69,7 +69,7 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile:
         return

     elif (
-        recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL
+        recipe == MXFP8TrainingRecipe.MXFP8_RCEIL
         and torch.cuda.get_device_capability()
         != (
             10,
@@ -104,7 +104,7 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile:
     model = copy.deepcopy(ref_model)

     # Token group alignment size must be 16 for fp8 rowwise training
-    alignment_size = 32 if recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL else 16
+    alignment_size = 32 if recipe == MXFP8TrainingRecipe.MXFP8_RCEIL else 16
     set_token_group_alignment_size_m(alignment_size)

     # assert starting params are identical for both models
@@ -119,7 +119,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False

     # quantize test model
-    config = MXFP8GroupedMMConfig.from_recipe(recipe)
+    config = MXFP8TrainingOpConfig.from_recipe(recipe)
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)

     # FSDP2
```
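
For reference, the hardware and alignment requirements that the two MoE benchmarks above encode, collected in one place. The helper names here are hypothetical; the capability and alignment values come straight from the diffs:

```python
import torch

from torchao.prototype.moe_training.config import (
    Float8TrainingRecipe,
    MXFP8TrainingRecipe,
)


def recipe_supported_on_this_gpu(recipe) -> bool:
    # Hypothetical helper mirroring the guards in the benchmarks above.
    cc = torch.cuda.get_device_capability()
    if recipe == Float8TrainingRecipe.FP8_ROWWISE:
        # fp8 rowwise grouped mm requires compute capability 9.0
        # (the dq benchmark below additionally allows ROCm MI300/MI350).
        return cc == (9, 0)
    # mxfp8 recipes require compute capability 10.0.
    return cc == (10, 0)


def token_group_alignment_m(recipe) -> int:
    # Token groups are padded to a multiple of 16 for fp8 rowwise
    # and 32 for mxfp8 before the grouped mm is called.
    return 16 if recipe == Float8TrainingRecipe.FP8_ROWWISE else 32
```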

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 18 additions & 17 deletions
```diff
@@ -19,14 +19,16 @@
     bench_fwd_microseconds,
     profile_fwd_bwd,
 )
-from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
 from torchao.prototype.moe_training.config import (
-    FP8GroupedMMConfig,
-    FP8GroupedMMRecipe,
-    MXFP8GroupedMMConfig,
-    MXFP8GroupedMMRecipe,
+    Float8TrainingOpConfig,
+    Float8TrainingRecipe,
+    MXFP8TrainingOpConfig,
+    MXFP8TrainingRecipe,
+)
+from torchao.prototype.moe_training.utils import (
+    _quantize_then_scaled_grouped_mm,
+    generate_jagged_offs,
 )
-from torchao.prototype.moe_training.utils import generate_jagged_offs
 from torchao.utils import is_MI300, is_MI350, is_ROCM

 device = torch.device("cuda")
@@ -42,7 +44,7 @@
 class ExperimentConfig:
     high_precision_dtype: torch.dtype
     MNKG: tuple[int]
-    recipe: Union[FP8GroupedMMRecipe, MXFP8GroupedMMRecipe]
+    recipe: Union[Float8TrainingRecipe, MXFP8TrainingRecipe]


 @dataclass(frozen=True)
@@ -92,9 +94,8 @@ def get_configs() -> List[ExperimentConfig]:
         (128000, 2048, 7168, 8),
     ]
     recipes = [
-        FP8GroupedMMRecipe.FP8_ROWWISE,
-        MXFP8GroupedMMRecipe.MXFP8_RCEIL,
-        MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
+        MXFP8TrainingRecipe.MXFP8_RCEIL,
+        MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
     ]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
@@ -138,7 +139,7 @@ def run_experiment(
     # - the transposed tensor in col-major format with groups along the row dimension,
     #   which represents the right operand.
     token_group_alignment_size = (
-        16 if config.recipe == FP8GroupedMMRecipe.FP8_ROWWISE else 32
+        16 if config.recipe == Float8TrainingRecipe.FP8_ROWWISE else 32
     )

     offs = generate_jagged_offs(G, total_M, multiple_of=token_group_alignment_size)
@@ -170,10 +171,10 @@ def run_experiment(
     )

     # Create config object from recipe
-    if isinstance(config.recipe, FP8GroupedMMRecipe):
-        quant_config = FP8GroupedMMConfig.from_recipe(config.recipe)
+    if isinstance(config.recipe, Float8TrainingRecipe):
+        quant_config = Float8TrainingOpConfig.from_recipe(config.recipe)
     else:
-        quant_config = MXFP8GroupedMMConfig.from_recipe(config.recipe)
+        quant_config = MXFP8TrainingOpConfig.from_recipe(config.recipe)

     # fwd_bwd scaled benchmark + profiling
     scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
@@ -261,7 +262,7 @@ def main(args: argparse.Namespace):
     configs = get_configs()
     results = []
     for config in tqdm(configs):
-        if config.recipe == FP8GroupedMMRecipe.FP8_ROWWISE:
+        if config.recipe == Float8TrainingRecipe.FP8_ROWWISE:
             if is_ROCM():
                 if not (is_MI300() or is_MI350()):
                     logging.warning(
@@ -276,8 +277,8 @@ def main(args: argparse.Namespace):
             continue

         elif config.recipe in (
-            MXFP8GroupedMMRecipe.MXFP8_RCEIL,
-            MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
+            MXFP8TrainingRecipe.MXFP8_RCEIL,
+            MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
         ) and torch.cuda.get_device_capability() != (10, 0):
             logging.warning(
                 f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
```

docs/source/workflows/training.md

Lines changed: 12 additions & 18 deletions
````diff
@@ -266,7 +266,7 @@ with torch.inference_mode():
 ## mxfp8

 e2e training with mxfp8 from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
-in native PyTorch. 
+in native PyTorch.

 > :warning: We are currently in prototype. Use nightly versions of PyTorch and torchao (or build from source) for best results.

@@ -336,25 +336,19 @@ Below is a toy training loop. For an example real training loop, see our torchti
 ```python
 import torch
 from torchao.quantization import quantize_
-import torchao.prototype.mx_formats
-from torchao.prototype.mx_formats import MXLinearConfig, ScaleCalculationMode
-from torchao.quantization.quantize_.common import KernelPreference
-
-# low precision gemm, requires CUDA capability 10.0+
-kernel_preference = KernelPreference.AUTO
-# or, emulated gemm
-# kernel_preference = KernelPreference.EMULATED
-
-scale_calculation_mode = ScaleCalculationMode.FLOOR
-# other supported modes: RCEIL, CEIL, EVEN
+from torchao.prototype.moe_training.config import MXFP8TrainingOpConfig, MXFP8TrainingRecipe
+from torchao.prototype.mx_formats import ScaleCalculationMode
+
+# create config from a recipe
+config = MXFP8TrainingOpConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL)
+# or manually configure
+# config = MXFP8TrainingOpConfig(
+#     kernel_preference=KernelPreference.AUTO,  # or KernelPreference.EMULATED
+#     scale_calculation_mode=ScaleCalculationMode.RCEIL,  # or FLOOR, CEIL, EVEN
+#     wgrad_with_hp=False,  # True to compute grad_weight in high precision
+# )

 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-config = MXLinearConfig(
-    elem_dtype=torch.float8_e4m3fn,
-    block_size=32,
-    kernel_preference=kernel_preference,
-    scale_calculation_mode=scale_calculation_mode,
-)
 quantize_(m, config)
 m = torch.compile(m, fullgraph=True)
````
