Skip to content

Commit 481e6bb

Browse files
[mxfp8 moe training] add custom sharding for triton dim0 quant kernel (#3812)
1 parent 3d45dfe commit 481e6bb

2 files changed

Lines changed: 122 additions & 28 deletions

File tree

test/prototype/moe_training/test_tp.py

Lines changed: 114 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import pytest
1818
import torch
1919

20+
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference
21+
2022
if torch.version.hip is not None:
2123
pytest.skip(
2224
"ROCm support for MoE quantization is under development",
@@ -32,6 +34,7 @@
3234

3335
try:
3436
from torch.distributed.tensor.parallel import (
37+
ParallelStyle,
3538
PrepareModuleInputOutput,
3639
parallelize_module,
3740
)
@@ -54,22 +57,95 @@
5457
)
5558
from torchao.quantization.quant_api import quantize_
5659

60+
from .reference_moe import MoE, MoEArgs, set_token_group_alignment_size_m
5761
from .testing_utils import _validate_model_conversion
5862

59-
# this test requires torchtitan
60-
try:
61-
from torchtitan.distributed import NoParallel
62-
from torchtitan.distributed.expert_parallel import (
63-
ExpertParallel,
64-
ExpertTensorParallel,
65-
TensorParallel,
66-
)
67-
from torchtitan.models.moe import MoE, MoEArgs
68-
from torchtitan.models.moe.utils import set_token_group_alignment_size_m
69-
except ImportError:
70-
pytest.skip(
71-
"torchtitan not installed, skipping MoE tests.", allow_module_level=True
72-
)
63+
64+
class NoParallel(ParallelStyle):
65+
"""Placeholder for no parallelization."""
66+
67+
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
68+
return module
69+
70+
71+
class TensorParallel(ParallelStyle):
72+
"""Tensor parallelism for MoE layers."""
73+
74+
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
75+
from torch.distributed.tensor import distribute_module, distribute_tensor
76+
77+
def _partition_fn(name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None:
78+
# Shard w1 and w3 on dim 1 (hidden_dim), w2 on dim 2 (hidden_dim)
79+
for param_name, param in mod.named_parameters(recurse=False):
80+
if param_name in ("w1", "w3"):
81+
dist_param = nn.Parameter(
82+
distribute_tensor(param, device_mesh, [Shard(1)])
83+
)
84+
elif param_name == "w2":
85+
dist_param = nn.Parameter(
86+
distribute_tensor(param, device_mesh, [Shard(2)])
87+
)
88+
else:
89+
dist_param = nn.Parameter(
90+
distribute_tensor(param, device_mesh, [Replicate()])
91+
)
92+
mod.register_parameter(param_name, dist_param)
93+
94+
return distribute_module(module, device_mesh, partition_fn=_partition_fn)
95+
96+
97+
class ExpertParallel(ParallelStyle):
98+
"""Expert parallelism for MoE layers."""
99+
100+
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
101+
from torch.distributed.tensor import distribute_tensor
102+
from torch.distributed.tensor.parallel import distribute_module
103+
104+
def _partition_fn(name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None:
105+
# Shard experts along the expert dimension (dim 0)
106+
for param_name, param in mod.named_parameters(recurse=False):
107+
dist_param = nn.Parameter(
108+
distribute_tensor(param, device_mesh, [Shard(0)])
109+
)
110+
mod.register_parameter(param_name, dist_param)
111+
112+
return distribute_module(module, device_mesh, partition_fn=_partition_fn)
113+
114+
115+
class ExpertTensorParallel(ParallelStyle):
116+
"""Combined expert and tensor parallelism for MoE layers."""
117+
118+
def __init__(self, tp_mesh: DeviceMesh, ep_mesh: DeviceMesh):
119+
super().__init__()
120+
self.tp_mesh = tp_mesh
121+
self.ep_mesh = ep_mesh
122+
123+
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
124+
from torch.distributed.tensor import distribute_tensor
125+
from torch.distributed.tensor.parallel import distribute_module
126+
127+
def _partition_fn(name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None:
128+
# Shard along expert dim (EP) and hidden dim (TP)
129+
for param_name, param in mod.named_parameters(recurse=False):
130+
if param_name in ("w1", "w3"):
131+
# Shard on expert dim and hidden_dim
132+
dist_param = nn.Parameter(
133+
distribute_tensor(param, device_mesh, [Shard(0), Shard(1)])
134+
)
135+
elif param_name == "w2":
136+
# Shard on expert dim and hidden_dim (dim 2 for w2)
137+
dist_param = nn.Parameter(
138+
distribute_tensor(param, device_mesh, [Shard(0), Shard(2)])
139+
)
140+
else:
141+
dist_param = nn.Parameter(
142+
distribute_tensor(
143+
param, device_mesh, [Replicate(), Replicate()]
144+
)
145+
)
146+
mod.register_parameter(param_name, dist_param)
147+
148+
return distribute_module(module, device_mesh, partition_fn=_partition_fn)
73149

74150

75151
@pytest.fixture(scope="module")
@@ -80,7 +156,7 @@ def device_mesh_1d() -> DeviceMesh:
80156
"""
81157
rank = int(os.environ["RANK"])
82158
world_size = int(os.environ["WORLD_SIZE"])
83-
device_mesh = init_device_mesh("cuda", (world_size,))
159+
device_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
84160
torch.manual_seed(1)
85161
torch.cuda.set_device(rank)
86162

@@ -96,6 +172,9 @@ def device_mesh_1d() -> DeviceMesh:
96172
["experts,shared_experts"],
97173
],
98174
)
175+
@pytest.mark.parametrize(
176+
"kernel_preference", [KernelPreference.AUTO, KernelPreference.EMULATED]
177+
)
99178
@pytest.mark.parametrize("compile", [False, True])
100179
@pytest.mark.parametrize(
101180
"recipe_config",
@@ -114,10 +193,18 @@ def device_mesh_1d() -> DeviceMesh:
114193
"min_input_grad_sqnr": 29.0,
115194
"min_param_grad_sqnr": 21.0,
116195
},
196+
{
197+
"recipe": MoEScalingType.MXFP8_WGRAD_WITH_HP,
198+
"group_alignment_size": 32,
199+
"min_out_sqnr": 28.0,
200+
"min_input_grad_sqnr": 29.0,
201+
"min_param_grad_sqnr": 25.0,
202+
},
117203
],
118204
)
119205
def test_moe_training_tp(
120206
target_fqns: list[str],
207+
kernel_preference: KernelPreference,
121208
compile: bool,
122209
recipe_config: dict,
123210
device_mesh_1d: DeviceMesh,
@@ -144,23 +231,22 @@ def test_moe_training_tp(
144231
f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
145232
)
146233

147-
elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != (
148-
10,
149-
0,
150-
):
151-
pytest.skip(
152-
f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
153-
)
234+
elif recipe in (MoEScalingType.MXFP8, MoEScalingType.MXFP8_WGRAD_WITH_HP):
235+
emulated = kernel_preference == KernelPreference.EMULATED
236+
if not emulated and torch.cuda.get_device_capability() != (
237+
10,
238+
0,
239+
):
240+
pytest.skip(
241+
f"Non-emulated mode only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
242+
)
243+
if emulated and compile:
244+
pytest.skip("MXFP8 emulated mode does not support torch.compile")
154245

155246
# set token group alignment size needed for GEMM (contraction dim stride must be 16 byte aligned)
156247
# or quantization ops (mxfp8 scaling groups are size 1x32)
157248
set_token_group_alignment_size_m(group_alignment_size)
158249

159-
# define model args
160-
model_args = MoEArgs(
161-
num_experts=8,
162-
)
163-
164250
# define model args
165251
model_args = MoEArgs(
166252
num_experts=8,
@@ -189,7 +275,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
189275
return False
190276

191277
# quantize test model
192-
config = MoETrainingConfig(recipe)
278+
config = MoETrainingConfig(recipe, kernel_preference=kernel_preference)
193279
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
194280

195281
# validate that only the experts were converted

torchao/prototype/mx_formats/kernels.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,14 @@ def triton_to_mxfp8_dim1(
11341134
col_scale.view(torch.float8_e8m0fnu).squeeze(-1),
11351135
)
11361136

1137+
@register_sharding(torch.ops.torchao.triton_to_mxfp8_dim0.default)
def custom_triton_to_mxfp8_dim0_sharding(x, inner_block_size=32):
    """Acceptable sharding strategies for the dim0 mxfp8 quantization op.

    Each strategy pairs output placements (quantized data, scales) with
    input placements (x, inner_block_size); both outputs follow the
    placement of ``x``, and the non-tensor ``inner_block_size`` arg takes
    no placement (None). Accepted placements: fully replicated, sharded
    on dim 0, or sharded on dim 1.
    """
    acceptable_shardings = []
    for make_placement in (Replicate, lambda: Shard(0), lambda: Shard(1)):
        strategy = (
            [make_placement(), make_placement()],  # output placements
            [make_placement(), None],  # input placements
        )
        acceptable_shardings.append(strategy)
    return acceptable_shardings
1144+
11371145
@register_sharding(torch.ops.torchao.triton_to_mxfp8_dim1.default)
11381146
def custom_triton_to_mxfp8_dim1_sharding(x, inner_block_size=32):
11391147
replicate = ([Replicate(), Replicate()], [Replicate(), None])

0 commit comments

Comments
 (0)