Skip to content

Commit b65b280

Browse files
committed
bitwise?
1 parent cbe9d96 commit b65b280

1 file changed

Lines changed: 17 additions & 60 deletions

File tree

torchao/testing/training/dtensor_utils.py

Lines changed: 17 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -192,64 +192,21 @@ def _test_lowp_mlp_tensor_parallelism_base(
192192
global_out = toy_model_fp8(x)
193193
global_out.backward(go)
194194

195-
if is_mxfp8:
196-
# MXFP8 emulated dim1 quantization transposes and re-contiguifies the
197-
# activation (a.t().contiguous()), which can change how elements land
198-
# in 32-element blocks depending on whether the input was sharded or
199-
# not. This produces small numerical differences, so use relaxed tols.
200-
atol, rtol = 0.15, 0.05
201-
202-
torch.testing.assert_close(tp_out, global_out, atol=atol, rtol=rtol)
203-
torch.testing.assert_close(
204-
sp_out.full_tensor(), global_out, atol=atol, rtol=rtol
205-
)
206-
torch.testing.assert_close(
207-
tp_model.ffn.w1.weight.grad,
208-
sp_model.ffn.w1.weight.grad,
209-
atol=atol,
210-
rtol=rtol,
211-
)
212-
torch.testing.assert_close(
213-
tp_model.ffn.out_proj.weight.grad,
214-
sp_model.ffn.out_proj.weight.grad,
215-
atol=atol,
216-
rtol=rtol,
217-
)
218-
219-
sp_out2 = sp_model2(x_sp_input)
220-
sp_out2.backward(go_sp)
221-
222-
torch.testing.assert_close(
223-
sp_out2.full_tensor(), global_out, atol=atol, rtol=rtol
224-
)
225-
torch.testing.assert_close(
226-
tp_model.ffn.w1.weight.grad,
227-
sp_model2.ffn.w1.weight.grad,
228-
atol=atol,
229-
rtol=rtol,
230-
)
231-
torch.testing.assert_close(
232-
tp_model.ffn.out_proj.weight.grad,
233-
sp_model2.ffn.out_proj.weight.grad,
234-
atol=atol,
235-
rtol=rtol,
236-
)
237-
else:
238-
torch.testing.assert_close(tp_out, global_out)
239-
torch.testing.assert_close(sp_out.full_tensor(), global_out)
240-
torch.testing.assert_close(
241-
tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad
242-
)
243-
torch.testing.assert_close(
244-
tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
245-
)
195+
torch.testing.assert_close(tp_out, global_out)
196+
torch.testing.assert_close(sp_out.full_tensor(), global_out)
197+
torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
198+
torch.testing.assert_close(
199+
tp_model.ffn.out_proj.weight.grad,
200+
sp_model.ffn.out_proj.weight.grad,
201+
)
246202

247-
sp_out2 = sp_model2(x_sp_input)
248-
sp_out2.backward(go_sp)
249-
torch.testing.assert_close(sp_out2.full_tensor(), global_out)
250-
torch.testing.assert_close(
251-
tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
252-
)
253-
torch.testing.assert_close(
254-
tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
255-
)
203+
sp_out2 = sp_model2(x_sp_input)
204+
sp_out2.backward(go_sp)
205+
torch.testing.assert_close(sp_out2.full_tensor(), global_out)
206+
torch.testing.assert_close(
207+
tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
208+
)
209+
torch.testing.assert_close(
210+
tp_model.ffn.out_proj.weight.grad,
211+
sp_model2.ffn.out_proj.weight.grad,
212+
)

0 commit comments

Comments (0)