
Commit 10362fc

[moe training] default pad_token_groups_for_grouped_mm=False
1 parent 77f23d0 commit 10362fc

3 files changed, with 34 additions and 11 deletions


benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 17 additions & 7 deletions
@@ -116,7 +116,17 @@ def run_experiment(
         requires_grad=True,
     ).transpose(-2, -1)
 
-    offs = generate_jagged_offs(G, total_M, multiple_of=1)
+    # Create config object from recipe
+    if isinstance(config.recipe, Float8TrainingRecipe):
+        quant_config = Float8TrainingOpConfig.from_recipe(config.recipe)
+        alignment_size = 16 if args.aligned else 1
+        # TODO: support pad_token_groups_for_grouped_mm option in Float8TrainingOpConfig
+    else:
+        quant_config = MXFP8TrainingOpConfig.from_recipe(config.recipe)
+        quant_config.pad_token_groups_for_grouped_mm = not args.aligned
+        alignment_size = 32 if args.aligned else 1
+
+    offs = generate_jagged_offs(G, total_M, multiple_of=alignment_size)
 
     # fwd_bwd bf16 benchmark + profiling
     bf16_fwd_bwd_us = bench_fwd_bwd_microseconds(
@@ -138,12 +148,6 @@ def run_experiment(
         profile_name="bf16_profile",
     )
 
-    # Create config object from recipe
-    if isinstance(config.recipe, Float8TrainingRecipe):
-        quant_config = Float8TrainingOpConfig.from_recipe(config.recipe)
-    else:
-        quant_config = MXFP8TrainingOpConfig.from_recipe(config.recipe)
-
     # fwd_bwd scaled benchmark + profiling
     scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
         _quantize_then_scaled_grouped_mm,
@@ -262,5 +266,11 @@ def main(args: argparse.Namespace):
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("--compile", action="store_true")
     arg_parser.add_argument("--profile", action="store_true")
+    arg_parser.add_argument(
+        "--aligned",
+        action="store_true",
+        help="If true, token group sizes are pre-aligned, to simulate flow with HybridEP or similar",
+    )
+
     args = arg_parser.parse_args()
     main(args)
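
The benchmark now derives the offset alignment from the new --aligned flag before calling generate_jagged_offs. As a hypothetical, self-contained illustration of what generate_jagged_offs(G, total_M, multiple_of=alignment_size) is expected to produce (this is not torchao's implementation), the sketch below builds end-offsets for G token groups whose sizes are each a multiple of the alignment:

import torch

# Hypothetical stand-in for generate_jagged_offs: returns cumulative end-offsets
# of G token groups whose sizes are each a multiple of `multiple_of`.
def jagged_offs_sketch(G: int, total_M: int, multiple_of: int = 1) -> torch.Tensor:
    assert total_M % multiple_of == 0, "total tokens must be divisible by the alignment"
    num_blocks = total_M // multiple_of
    blocks = torch.ones(G, dtype=torch.int64)        # at least one block per group
    extra = torch.randint(0, G, (num_blocks - G,))   # scatter the remaining blocks randomly
    blocks += torch.bincount(extra, minlength=G)
    group_sizes = blocks * multiple_of               # every group size is a multiple
    return torch.cumsum(group_sizes, dim=0)          # offsets fed to the grouped GEMM

# With --aligned, the benchmark uses multiple_of=16 (FP8) or 32 (MXFP8), so the grouped
# GEMM needs no per-group padding; without it, multiple_of=1 and padding is left to the
# MXFP8 op config (pad_token_groups_for_grouped_mm = not args.aligned).
offs = jagged_offs_sketch(G=8, total_M=8192, multiple_of=32)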

test/prototype/moe_training/test_training.py

Lines changed: 14 additions & 1 deletion
@@ -37,6 +37,7 @@
 @pytest.mark.parametrize(
     "kernel_preference", [KernelPreference.AUTO, KernelPreference.EMULATED]
 )
+@pytest.mark.parametrize("token_groups_aligned", [False])
 @pytest.mark.parametrize(
     "recipe_config",
     [
@@ -74,6 +75,7 @@ def test_moe_training(
     target_fqns: list[str],
     compile: bool,
     kernel_preference: KernelPreference,
+    token_groups_aligned: bool,
     recipe_config: dict,
 ):
     (
@@ -105,6 +107,8 @@
         pytest.skip(
             f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
         )
+    if not token_groups_aligned:
+        pytest.skip("FP8 rowwise doesn't support per group token padding yet")
 
     # MXFP8 hardware path requires SM100
     if recipe in (
@@ -118,7 +122,11 @@
             f"Skipping MXFP8 hardware mode tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
         )
 
-    set_token_group_alignment_size_m(1)
+    alignment_size = 32 if isinstance(recipe, MXFP8TrainingRecipe) else 16
+    if not token_groups_aligned:
+        alignment_size = 1
+    set_token_group_alignment_size_m(alignment_size)
+
     model_args = MoEArgs(
         num_experts=8,
         num_shared_experts=1,
@@ -154,6 +162,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         else Float8TrainingOpConfig
     )
     config = config_cls.from_recipe(recipe)
+
+    # TODO: support pad_token_groups_for_grouped_mm in Float8TrainingOpConfig
+    if isinstance(recipe, MXFP8TrainingRecipe) and not token_groups_aligned:
+        config.pad_token_groups_for_grouped_mm = True
+
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
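
The updated test chooses between two ways to satisfy the grouped-GEMM alignment requirement. A minimal sketch of that choice, assuming set_token_group_alignment_size_m is importable from the torchao moe_training prototype (the exact module path is an assumption):

# Mirrors the decision made in test_moe_training above; import path is an assumption.
from torchao.prototype.moe_training.utils import set_token_group_alignment_size_m

token_groups_aligned = False  # the new parametrization currently only covers False
use_mxfp8 = True

if token_groups_aligned:
    # Pre-aligned groups: every group size is already a multiple of the quantization
    # block size (32 for MXFP8, 16 for FP8 rowwise), e.g. via HybridEP-style dispatch.
    set_token_group_alignment_size_m(32 if use_mxfp8 else 16)
else:
    # Unaligned groups: alignment stays 1 and the MXFP8 op config must pad token groups
    # itself, which is why the test sets config.pad_token_groups_for_grouped_mm = True.
    set_token_group_alignment_size_m(1)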

torchao/prototype/moe_training/config.py

Lines changed: 3 additions & 3 deletions
@@ -117,23 +117,23 @@ def from_recipe(
                 out_dtype=torch.bfloat16,
                 wgrad_with_hp=False,
                 scale_calculation_mode=ScaleCalculationMode.RCEIL,
-                pad_token_groups_for_grouped_mm=True,
+                pad_token_groups_for_grouped_mm=False,
             )
         elif recipe == MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP:
             return cls(
                 kernel_preference=KernelPreference.AUTO,
                 out_dtype=torch.bfloat16,
                 wgrad_with_hp=True,
                 scale_calculation_mode=ScaleCalculationMode.RCEIL,
-                pad_token_groups_for_grouped_mm=True,
+                pad_token_groups_for_grouped_mm=False,
             )
         elif recipe == MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL:
             return cls(
                 kernel_preference=KernelPreference.EMULATED,
                 out_dtype=torch.bfloat16,
                 wgrad_with_hp=False,
                 scale_calculation_mode=ScaleCalculationMode.RCEIL,
-                pad_token_groups_for_grouped_mm=True,
+                pad_token_groups_for_grouped_mm=False,
             )
         else:
             raise ValueError(f"Unsupported MXFP8 recipe: {recipe}")
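
With the recipe defaults above flipped to False, from_recipe no longer enables padding, so callers whose token groups are not pre-aligned have to opt back in explicitly. A minimal sketch of that opt-in, assuming the config and recipe classes are importable from torchao.prototype.moe_training.config (the file modified above):

# Import path is an assumption based on the file modified in this commit.
from torchao.prototype.moe_training.config import (
    MXFP8TrainingOpConfig,
    MXFP8TrainingRecipe,
)

config = MXFP8TrainingOpConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL)
assert config.pad_token_groups_for_grouped_mm is False  # new default as of this commit
config.pad_token_groups_for_grouped_mm = True           # explicit opt-in for unaligned groups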
