Skip to content

Commit 9bdc0ca

Browse files
Fix backward return count mismatch in _Float8GroupedMM (#3956)
1 parent a4ae9cc commit 9bdc0ca

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

torchao/prototype/moe_training/fp8_grouped_mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,4 +254,4 @@ def backward(ctx, grad_output: torch.Tensor):
254254
out_dtype=out_dtype,
255255
use_fast_accum=True,
256256
)
257-
return grad_A, grad_B.transpose(-2, -1), None, None, None, None
257+
return grad_A, grad_B.transpose(-2, -1), None, None, None

0 commit comments

Comments
 (0)