1919 bench_fwd_microseconds ,
2020 profile_fwd_bwd ,
2121)
22- from torchao .prototype .moe_training import _quantize_then_scaled_grouped_mm
2322from torchao .prototype .moe_training .config import (
24- FP8GroupedMMConfig ,
25- FP8GroupedMMRecipe ,
26- MXFP8GroupedMMConfig ,
27- MXFP8GroupedMMRecipe ,
23+ Float8TrainingOpConfig ,
24+ Float8TrainingRecipe ,
25+ MXFP8TrainingOpConfig ,
26+ MXFP8TrainingRecipe ,
27+ )
28+ from torchao .prototype .moe_training .utils import (
29+ _quantize_then_scaled_grouped_mm ,
30+ generate_jagged_offs ,
2831)
29- from torchao .prototype .moe_training .utils import generate_jagged_offs
3032from torchao .utils import is_MI300 , is_MI350 , is_ROCM
3133
3234device = torch .device ("cuda" )
4244class ExperimentConfig :
4345 high_precision_dtype : torch .dtype
4446 MNKG : tuple [int ]
45- recipe : Union [FP8GroupedMMRecipe , MXFP8GroupedMMRecipe ]
47+ recipe : Union [Float8TrainingRecipe , MXFP8TrainingRecipe ]
4648
4749
4850@dataclass (frozen = True )
@@ -92,9 +94,8 @@ def get_configs() -> List[ExperimentConfig]:
9294 (128000 , 2048 , 7168 , 8 ),
9395 ]
9496 recipes = [
95- FP8GroupedMMRecipe .FP8_ROWWISE ,
96- MXFP8GroupedMMRecipe .MXFP8_RCEIL ,
97- MXFP8GroupedMMRecipe .MXFP8_RCEIL_WGRAD_WITH_HP ,
97+ MXFP8TrainingRecipe .MXFP8_RCEIL ,
98+ MXFP8TrainingRecipe .MXFP8_RCEIL_WGRAD_WITH_HP ,
9899 ]
99100 high_precision_dtypes = [torch .bfloat16 ]
100101 configs = []
@@ -138,7 +139,7 @@ def run_experiment(
138139 # - the transposed tensor in col-major format with groups along the row dimension,
139140 # which represents the right operand.
140141 token_group_alignment_size = (
141- 16 if config .recipe == FP8GroupedMMRecipe .FP8_ROWWISE else 32
142+ 16 if config .recipe == Float8TrainingRecipe .FP8_ROWWISE else 32
142143 )
143144
144145 offs = generate_jagged_offs (G , total_M , multiple_of = token_group_alignment_size )
@@ -170,10 +171,10 @@ def run_experiment(
170171 )
171172
172173 # Create config object from recipe
173- if isinstance (config .recipe , FP8GroupedMMRecipe ):
174- quant_config = FP8GroupedMMConfig .from_recipe (config .recipe )
174+ if isinstance (config .recipe , Float8TrainingRecipe ):
175+ quant_config = Float8TrainingOpConfig .from_recipe (config .recipe )
175176 else :
176- quant_config = MXFP8GroupedMMConfig .from_recipe (config .recipe )
177+ quant_config = MXFP8TrainingOpConfig .from_recipe (config .recipe )
177178
178179 # fwd_bwd scaled benchmark + profiling
179180 scaled_fwd_bwd_us = bench_fwd_bwd_microseconds (
@@ -261,7 +262,7 @@ def main(args: argparse.Namespace):
261262 configs = get_configs ()
262263 results = []
263264 for config in tqdm (configs ):
264- if config .recipe == FP8GroupedMMRecipe .FP8_ROWWISE :
265+ if config .recipe == Float8TrainingRecipe .FP8_ROWWISE :
265266 if is_ROCM ():
266267 if not (is_MI300 () or is_MI350 ()):
267268 logging .warning (
@@ -276,8 +277,8 @@ def main(args: argparse.Namespace):
276277 continue
277278
278279 elif config .recipe in (
279- MXFP8GroupedMMRecipe .MXFP8_RCEIL ,
280- MXFP8GroupedMMRecipe .MXFP8_RCEIL_WGRAD_WITH_HP ,
280+ MXFP8TrainingRecipe .MXFP8_RCEIL ,
281+ MXFP8TrainingRecipe .MXFP8_RCEIL_WGRAD_WITH_HP ,
281282 ) and torch .cuda .get_device_capability () != (10 , 0 ):
282283 logging .warning (
283284 f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found { torch .cuda .get_device_capability ()} "
0 commit comments