
Commit b6bf1a5

[mxfp8 moe training] default pad_token_groups_for_grouped_mm to False
1 parent: eb64bfb · commit: b6bf1a5

4 files changed

Lines changed: 36 additions & 12 deletions
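
In short: torchao's MXFP8 MoE training path can pad each token group so its size is a multiple of the MX block size (32) before calling the scaled grouped GEMM. This commit flips pad_token_groups_for_grouped_mm from True to False by default, so callers whose token groups are already aligned (e.g. a HybridEP-style flow, per the new --aligned benchmark flag below) don't pay for runtime padding, and unaligned callers opt in explicitly. As a hedged illustration of what such padding does at the offsets level, here is a minimal sketch; pad_group_offsets is a hypothetical name, not torchao's actual pad_token_groups helper:

# Minimal sketch: round each token group size up to a multiple of the MX block
# size (32). Illustrative only; torchao's pad_token_groups/unpad_token_groups
# helpers (imported in the last diff below) are the real implementation.
import torch

def pad_group_offsets(offs: torch.Tensor, multiple_of: int = 32) -> torch.Tensor:
    # offs holds cumulative end indices of each token group, e.g. [40, 100, 128]
    zero = torch.zeros(1, dtype=offs.dtype, device=offs.device)
    sizes = torch.diff(offs, prepend=zero)
    padded_sizes = ((sizes + multiple_of - 1) // multiple_of) * multiple_of
    return torch.cumsum(padded_sizes, dim=0)

print(pad_group_offsets(torch.tensor([40, 100, 128])))  # tensor([ 64, 128, 160])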


benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 17 additions & 7 deletions
@@ -116,7 +116,17 @@ def run_experiment(
         requires_grad=True,
     ).transpose(-2, -1)
 
-    offs = generate_jagged_offs(G, total_M, multiple_of=1)
+    # Create config object from recipe
+    if isinstance(config.recipe, Float8TrainingRecipe):
+        quant_config = Float8TrainingOpConfig.from_recipe(config.recipe)
+        alignment_size = 16 if args.aligned else 1
+        # TODO: support pad_token_groups_for_grouped_mm option in Float8TrainingOpConfig
+    else:
+        quant_config = MXFP8TrainingOpConfig.from_recipe(config.recipe)
+        quant_config.pad_token_groups_for_grouped_mm = not args.aligned
+        alignment_size = 32 if args.aligned else 1
+
+    offs = generate_jagged_offs(G, total_M, multiple_of=alignment_size)
 
     # fwd_bwd bf16 benchmark + profiling
     bf16_fwd_bwd_us = bench_fwd_bwd_microseconds(
@@ -138,12 +148,6 @@ def run_experiment(
         profile_name="bf16_profile",
     )
 
-    # Create config object from recipe
-    if isinstance(config.recipe, Float8TrainingRecipe):
-        quant_config = Float8TrainingOpConfig.from_recipe(config.recipe)
-    else:
-        quant_config = MXFP8TrainingOpConfig.from_recipe(config.recipe)
-
     # fwd_bwd scaled benchmark + profiling
     scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
         _quantize_then_scaled_grouped_mm,
@@ -262,5 +266,11 @@ def main(args: argparse.Namespace):
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument("--compile", action="store_true")
     arg_parser.add_argument("--profile", action="store_true")
+    arg_parser.add_argument(
+        "--aligned",
+        action="store_true",
+        help="If true, token group sizes are pre-aligned, to simulate flow with HybridEP or similar",
+    )
+
     args = arg_parser.parse_args()
     main(args)
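
The new --aligned flag drives both knobs from one place: when set, offsets are generated with multiple_of equal to the recipe's alignment size (16 for FP8 rowwise, 32 for MXFP8) and runtime padding is disabled; when unset, group boundaries are arbitrary and the MXFP8 config pads at runtime. For example: python benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py --aligned --compile. generate_jagged_offs itself isn't shown in this commit; below is a hedged sketch of the behavior its signature implies, not the torchao implementation:

# Hypothetical sketch of generate_jagged_offs(G, total_M, multiple_of=k):
# produce G cumulative group offsets summing to total_M, with every group
# size a positive multiple of k. Illustrative only.
import torch

def generate_jagged_offs_sketch(G: int, total_M: int, multiple_of: int = 1) -> torch.Tensor:
    assert total_M % multiple_of == 0 and total_M // multiple_of >= G
    units = total_M // multiple_of
    # choose G-1 distinct cut points in [1, units-1], so every group gets >= 1 unit
    cuts = torch.sort(torch.randperm(units - 1)[: G - 1] + 1).values
    zero = torch.zeros(1, dtype=cuts.dtype)
    end = torch.tensor([units], dtype=cuts.dtype)
    sizes = torch.diff(cuts, prepend=zero, append=end)
    return torch.cumsum(sizes * multiple_of, dim=0)

torch.manual_seed(0)
print(generate_jagged_offs_sketch(G=4, total_M=256, multiple_of=32))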

test/prototype/moe_training/test_training.py

Lines changed: 14 additions & 1 deletion
@@ -37,6 +37,7 @@
 @pytest.mark.parametrize(
     "kernel_preference", [KernelPreference.AUTO, KernelPreference.EMULATED]
 )
+@pytest.mark.parametrize("token_groups_aligned", [False])
 @pytest.mark.parametrize(
     "recipe_config",
     [
@@ -74,6 +75,7 @@ def test_moe_training(
     target_fqns: list[str],
     compile: bool,
     kernel_preference: KernelPreference,
+    token_groups_aligned: bool,
     recipe_config: dict,
 ):
     (
@@ -110,6 +112,8 @@ def test_moe_training(
         pytest.skip(
             f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
         )
+        if not token_groups_aligned:
+            pytest.skip("FP8 rowwise doesn't support per group token padding yet")
 
     # MXFP8 hardware path requires SM100
     if recipe in (
@@ -123,7 +127,11 @@ def test_moe_training(
             f"Skipping MXFP8 hardware mode tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
         )
 
-    set_token_group_alignment_size_m(1)
+    alignment_size = 32 if isinstance(recipe, MXFP8TrainingRecipe) else 16
+    if not token_groups_aligned:
+        alignment_size = 1
+    set_token_group_alignment_size_m(alignment_size)
+
     model_args = MoEArgs(
         num_experts=8,
         num_shared_experts=1,
@@ -159,6 +167,11 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         else Float8TrainingOpConfig
     )
     config = config_cls.from_recipe(recipe)
+
+    # TODO: support pad_token_groups_for_grouped_mm in Float8TrainingOpConfig
+    if isinstance(recipe, MXFP8TrainingRecipe) and not token_groups_aligned:
+        config.pad_token_groups_for_grouped_mm = True
+
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted
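
The alignment sizes the test chooses mirror the kernel requirements: MXFP8 scales along 1x32 blocks, so pre-aligned token groups must be multiples of 32, while the FP8 rowwise path requires multiples of 16 (a kernel alignment requirement). A trivial restatement of the rule encoded above, with an illustrative function name:

# Restates the alignment rule from the test diff above; illustrative only.
def required_alignment_m(is_mxfp8: bool, token_groups_aligned: bool) -> int:
    if not token_groups_aligned:
        return 1  # unaligned groups; MXFP8 relies on runtime padding instead
    return 32 if is_mxfp8 else 16

assert required_alignment_m(True, True) == 32
assert required_alignment_m(False, True) == 16
assert required_alignment_m(True, False) == 1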

torchao/prototype/moe_training/config.py

Lines changed: 3 additions & 3 deletions
@@ -119,23 +119,23 @@ def from_recipe(
                 out_dtype=torch.bfloat16,
                 wgrad_with_hp=False,
                 scale_calculation_mode=ScaleCalculationMode.RCEIL,
-                pad_token_groups_for_grouped_mm=True,
+                pad_token_groups_for_grouped_mm=False,
             )
         elif recipe == MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP:
             return cls(
                 kernel_preference=KernelPreference.AUTO,
                 out_dtype=torch.bfloat16,
                 wgrad_with_hp=True,
                 scale_calculation_mode=ScaleCalculationMode.RCEIL,
-                pad_token_groups_for_grouped_mm=True,
+                pad_token_groups_for_grouped_mm=False,
             )
         elif recipe == MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL:
             return cls(
                 kernel_preference=KernelPreference.EMULATED,
                 out_dtype=torch.bfloat16,
                 wgrad_with_hp=False,
                 scale_calculation_mode=ScaleCalculationMode.RCEIL,
-                pad_token_groups_for_grouped_mm=True,
+                pad_token_groups_for_grouped_mm=False,
             )
         else:
             raise ValueError(f"Unsupported MXFP8 recipe: {recipe}")
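
All three MXFP8 recipes now construct their op config with pad_token_groups_for_grouped_mm=False, so runtime padding becomes opt-in. A short sketch of opting back in, assuming MXFP8TrainingOpConfig and MXFP8TrainingRecipe are both importable from this module (only the config class is confirmed to live here by this diff):

# Sketch: re-enable runtime token-group padding on top of the new default.
from torchao.prototype.moe_training.config import (
    MXFP8TrainingOpConfig,
    MXFP8TrainingRecipe,
)

config = MXFP8TrainingOpConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL)
assert config.pad_token_groups_for_grouped_mm is False  # new default
config.pad_token_groups_for_grouped_mm = True  # opt in when groups are unaligned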

torchao/prototype/moe_training/mxfp8_grouped_mm.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
     triton_mx_block_rearrange_per_group_3d,
 )
 from torchao.prototype.moe_training.utils import (
+    conditional_nostrict_trace,
     pad_token_groups,
     unpad_token_groups,
 )
@@ -77,7 +78,7 @@ def _validate_grouped_mm_input_act(
 
 
 # Aliases for convenience/clarity
-# @conditional_nostrict_trace
+@conditional_nostrict_trace
 def _to_mxfp8_then_scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
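
This hunk also re-enables the conditional_nostrict_trace decorator on _to_mxfp8_then_scaled_grouped_mm (it was previously commented out) and adds the matching import. The helper's body isn't part of this diff; below is a hypothetical sketch of what a decorator with this name typically does, NOT the torchao implementation:

# Hypothetical sketch: wrap the function with torch._dynamo.nonstrict_trace
# when that API is available (PyTorch 2.7+), so torch.compile traces it
# non-strictly; otherwise return the function unchanged.
import torch
import torch._dynamo

def conditional_nostrict_trace_sketch(fn):
    nonstrict_trace = getattr(torch._dynamo, "nonstrict_trace", None)
    if nonstrict_trace is None:
        return fn  # older PyTorch: no-op
    return nonstrict_trace(fn)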
