@@ -192,64 +192,21 @@ def _test_lowp_mlp_tensor_parallelism_base(
192192 global_out = toy_model_fp8 (x )
193193 global_out .backward (go )
194194
195- if is_mxfp8 :
196- # MXFP8 emulated dim1 quantization transposes and re-contiguifies the
197- # activation (a.t().contiguous()), which can change how elements land
198- # in 32-element blocks depending on whether the input was sharded or
199- # not. This produces small numerical differences, so use relaxed tols.
200- atol , rtol = 0.15 , 0.05
201-
202- torch .testing .assert_close (tp_out , global_out , atol = atol , rtol = rtol )
203- torch .testing .assert_close (
204- sp_out .full_tensor (), global_out , atol = atol , rtol = rtol
205- )
206- torch .testing .assert_close (
207- tp_model .ffn .w1 .weight .grad ,
208- sp_model .ffn .w1 .weight .grad ,
209- atol = atol ,
210- rtol = rtol ,
211- )
212- torch .testing .assert_close (
213- tp_model .ffn .out_proj .weight .grad ,
214- sp_model .ffn .out_proj .weight .grad ,
215- atol = atol ,
216- rtol = rtol ,
217- )
218-
219- sp_out2 = sp_model2 (x_sp_input )
220- sp_out2 .backward (go_sp )
221-
222- torch .testing .assert_close (
223- sp_out2 .full_tensor (), global_out , atol = atol , rtol = rtol
224- )
225- torch .testing .assert_close (
226- tp_model .ffn .w1 .weight .grad ,
227- sp_model2 .ffn .w1 .weight .grad ,
228- atol = atol ,
229- rtol = rtol ,
230- )
231- torch .testing .assert_close (
232- tp_model .ffn .out_proj .weight .grad ,
233- sp_model2 .ffn .out_proj .weight .grad ,
234- atol = atol ,
235- rtol = rtol ,
236- )
237- else :
238- torch .testing .assert_close (tp_out , global_out )
239- torch .testing .assert_close (sp_out .full_tensor (), global_out )
240- torch .testing .assert_close (
241- tp_model .ffn .w1 .weight .grad , sp_model .ffn .w1 .weight .grad
242- )
243- torch .testing .assert_close (
244- tp_model .ffn .out_proj .weight .grad , sp_model .ffn .out_proj .weight .grad
245- )
195+ torch .testing .assert_close (tp_out , global_out )
196+ torch .testing .assert_close (sp_out .full_tensor (), global_out )
197+ torch .testing .assert_close (tp_model .ffn .w1 .weight .grad , sp_model .ffn .w1 .weight .grad )
198+ torch .testing .assert_close (
199+ tp_model .ffn .out_proj .weight .grad ,
200+ sp_model .ffn .out_proj .weight .grad ,
201+ )
246202
247- sp_out2 = sp_model2 (x_sp_input )
248- sp_out2 .backward (go_sp )
249- torch .testing .assert_close (sp_out2 .full_tensor (), global_out )
250- torch .testing .assert_close (
251- tp_model .ffn .w1 .weight .grad , sp_model2 .ffn .w1 .weight .grad
252- )
253- torch .testing .assert_close (
254- tp_model .ffn .out_proj .weight .grad , sp_model2 .ffn .out_proj .weight .grad
255- )
203+ sp_out2 = sp_model2 (x_sp_input )
204+ sp_out2 .backward (go_sp )
205+ torch .testing .assert_close (sp_out2 .full_tensor (), global_out )
206+ torch .testing .assert_close (
207+ tp_model .ffn .w1 .weight .grad , sp_model2 .ffn .w1 .weight .grad
208+ )
209+ torch .testing .assert_close (
210+ tp_model .ffn .out_proj .weight .grad ,
211+ sp_model2 .ffn .out_proj .weight .grad ,
212+ )
0 commit comments