 import pytest
 import torch
 
+from torchao.quantization.quantize_.common.kernel_preference import KernelPreference
+
 if torch.version.hip is not None:
     pytest.skip(
         "ROCm support for MoE quantization is under development",
 
 try:
     from torch.distributed.tensor.parallel import (
+        ParallelStyle,
         PrepareModuleInputOutput,
         parallelize_module,
     )
 )
 from torchao.quantization.quant_api import quantize_
 
+from .reference_moe import MoE, MoEArgs, set_token_group_alignment_size_m
 from .testing_utils import _validate_model_conversion
 
-# this test requires torchtitan
-try:
-    from torchtitan.distributed import NoParallel
-    from torchtitan.distributed.expert_parallel import (
-        ExpertParallel,
-        ExpertTensorParallel,
-        TensorParallel,
-    )
-    from torchtitan.models.moe import MoE, MoEArgs
-    from torchtitan.models.moe.utils import set_token_group_alignment_size_m
-except ImportError:
-    pytest.skip(
-        "torchtitan not installed, skipping MoE tests.", allow_module_level=True
-    )
+
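+# Minimal stand-ins for the parallel styles this test previously imported from
+# torchtitan (NoParallel, TensorParallel, ExpertParallel, ExpertTensorParallel),
+# so the suite no longer requires torchtitan to be installed.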
+class NoParallel(ParallelStyle):
+    """Placeholder for no parallelization."""
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return module
+
+
+class TensorParallel(ParallelStyle):
+    """Tensor parallelism for MoE layers."""
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from torch.distributed.tensor import distribute_module, distribute_tensor
+
+        def _partition_fn(name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None:
+            # Shard w1 and w3 on dim 1 (hidden_dim), w2 on dim 2 (hidden_dim)
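+            # This mirrors Megatron-style TP on a SwiGLU MLP: w1/w3 are split
+            # column-wise and w2 row-wise along the hidden dim, so activations
+            # stay sharded through the expert and only w2's output is reduced.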
+            for param_name, param in mod.named_parameters(recurse=False):
+                if param_name in ("w1", "w3"):
+                    dist_param = nn.Parameter(
+                        distribute_tensor(param, device_mesh, [Shard(1)])
+                    )
+                elif param_name == "w2":
+                    dist_param = nn.Parameter(
+                        distribute_tensor(param, device_mesh, [Shard(2)])
+                    )
+                else:
+                    dist_param = nn.Parameter(
+                        distribute_tensor(param, device_mesh, [Replicate()])
+                    )
+                mod.register_parameter(param_name, dist_param)
+
+        return distribute_module(module, device_mesh, partition_fn=_partition_fn)
+
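+# Minimal usage sketch (hypothetical FQN; the test below builds the real
+# parallelize_plan per target_fqn):
+#
+#   parallelize_module(
+#       module=model.experts,
+#       device_mesh=device_mesh_1d,
+#       parallelize_plan=TensorParallel(),
+#   )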
+
+class ExpertParallel(ParallelStyle):
+    """Expert parallelism for MoE layers."""
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from torch.distributed.tensor import distribute_module, distribute_tensor
+
+        def _partition_fn(name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None:
+            # Shard experts along the expert dimension (dim 0)
+            for param_name, param in mod.named_parameters(recurse=False):
+                dist_param = nn.Parameter(
+                    distribute_tensor(param, device_mesh, [Shard(0)])
+                )
+                mod.register_parameter(param_name, dist_param)
+
+        return distribute_module(module, device_mesh, partition_fn=_partition_fn)
+
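+# Shard(0) over the expert dimension leaves each EP rank with
+# num_experts // ep_degree local experts; note this stand-in only shards
+# parameters and adds no token-dispatch hooks.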
+
+class ExpertTensorParallel(ParallelStyle):
+    """Combined expert and tensor parallelism for MoE layers."""
+
+    def __init__(self, tp_mesh: DeviceMesh, ep_mesh: DeviceMesh):
+        super().__init__()
+        self.tp_mesh = tp_mesh
+        self.ep_mesh = ep_mesh
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from torch.distributed.tensor import distribute_module, distribute_tensor
+
+        def _partition_fn(name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None:
+            # Shard along expert dim (EP) and hidden dim (TP)
+            for param_name, param in mod.named_parameters(recurse=False):
+                if param_name in ("w1", "w3"):
+                    # Shard on expert dim and hidden dim (dim 1 for w1/w3)
+                    dist_param = nn.Parameter(
+                        distribute_tensor(param, device_mesh, [Shard(0), Shard(1)])
+                    )
+                elif param_name == "w2":
+                    # Shard on expert dim and hidden dim (dim 2 for w2)
+                    dist_param = nn.Parameter(
+                        distribute_tensor(param, device_mesh, [Shard(0), Shard(2)])
+                    )
+                else:
+                    dist_param = nn.Parameter(
+                        distribute_tensor(
+                            param, device_mesh, [Replicate(), Replicate()]
+                        )
+                    )
+                mod.register_parameter(param_name, dist_param)
+
+        return distribute_module(module, device_mesh, partition_fn=_partition_fn)
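+    # Note: tp_mesh / ep_mesh are stored for signature parity with the
+    # torchtitan style this replaces; _apply shards over the 2D (ep, tp)
+    # device_mesh it receives, one placement per mesh dimension.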
 
 
 @pytest.fixture(scope="module")
@@ -80,7 +156,7 @@ def device_mesh_1d() -> DeviceMesh:
     """
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
-    device_mesh = init_device_mesh("cuda", (world_size,))
+    device_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))
     torch.manual_seed(1)
     torch.cuda.set_device(rank)
 
@@ -96,6 +172,9 @@ def device_mesh_1d() -> DeviceMesh:
96172 ["experts,shared_experts" ],
97173 ],
98174)
+@pytest.mark.parametrize(
+    "kernel_preference", [KernelPreference.AUTO, KernelPreference.EMULATED]
+)
 @pytest.mark.parametrize("compile", [False, True])
 @pytest.mark.parametrize(
     "recipe_config",
@@ -114,10 +193,18 @@ def device_mesh_1d() -> DeviceMesh:
114193 "min_input_grad_sqnr" : 29.0 ,
115194 "min_param_grad_sqnr" : 21.0 ,
116195 },
196+ {
197+ "recipe" : MoEScalingType .MXFP8_WGRAD_WITH_HP ,
198+ "group_alignment_size" : 32 ,
199+ "min_out_sqnr" : 28.0 ,
200+ "min_input_grad_sqnr" : 29.0 ,
201+ "min_param_grad_sqnr" : 25.0 ,
202+ },
117203 ],
118204)
119205def test_moe_training_tp (
120206 target_fqns : list [str ],
207+ kernel_preference : KernelPreference ,
121208 compile : bool ,
122209 recipe_config : dict ,
123210 device_mesh_1d : DeviceMesh ,
@@ -144,23 +231,22 @@ def test_moe_training_tp(
144231 f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found { torch .cuda .get_device_capability ()} "
145232 )
146233
147- elif recipe == MoEScalingType .MXFP8 and torch .cuda .get_device_capability () != (
148- 10 ,
149- 0 ,
150- ):
151- pytest .skip (
152- f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found { torch .cuda .get_device_capability ()} "
153- )
234+ elif recipe in (MoEScalingType .MXFP8 , MoEScalingType .MXFP8_WGRAD_WITH_HP ):
235+ emulated = kernel_preference == KernelPreference .EMULATED
236+ if not emulated and torch .cuda .get_device_capability () != (
237+ 10 ,
238+ 0 ,
239+ ):
240+ pytest .skip (
241+ f"Non-emulated mode only supported on compute capability 10.0 and found { torch .cuda .get_device_capability ()} "
242+ )
243+ if emulated and compile :
244+ pytest .skip ("MXFP8 emulated mode does not support torch.compile" )
154245
155246 # set token group alignment size needed for GEMM (contraction dim stride must be 16 byte aligned)
156247 # or quantization ops (mxfp8 scaling groups are size 1x32)
157248 set_token_group_alignment_size_m (group_alignment_size )
158249
159- # define model args
160- model_args = MoEArgs (
161- num_experts = 8 ,
162- )
163-
164250 # define model args
165251 model_args = MoEArgs (
166252 num_experts = 8 ,
@@ -189,7 +275,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         return False
 
     # quantize test model
-    config = MoETrainingConfig(recipe)
+    config = MoETrainingConfig(recipe, kernel_preference=kernel_preference)
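+    # kernel_preference selects real MXFP8 kernels (AUTO) or emulated
+    # quantization (EMULATED); the skips above require SM100 for AUTO and
+    # disable torch.compile for EMULATED.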
     quantize_(model, config=config, filter_fn=moe_module_filter_fn)
 
     # validate that only the experts were converted