Skip to content

Commit 04b49da

Browse files
committed
Update
[ghstack-poisoned]
1 parent 4cc0095 commit 04b49da

2 files changed

Lines changed: 5 additions & 1 deletion

File tree

test/prototype/moe_training/test_kernels.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,10 @@ def test_cuda_mx_3d_cutedsl_numerics(E, N, K, input_dtype, scaling_mode, variant
449449
s_rows, s_cols = x_t.shape[-1], x_t.shape[-2] // block_size
450450
s_logical = (
451451
torch.stack(
452-
[from_blocked(s[e], s_rows, s_cols).view(torch.uint8) for e in range(E)],
452+
[
453+
from_blocked(s[e], s_rows, s_cols).view(torch.uint8)
454+
for e in range(E)
455+
],
453456
dim=0,
454457
)
455458
.view(torch.float8_e8m0fnu)

torchao/prototype/moe_training/mxfp8_grouped_mm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
mxfp8_quantize_2d_1x32_cutedsl,
1818
mxfp8_quantize_cuda_3d,
1919
triton_mx_block_rearrange_2d_K_groups,
20+
triton_mx_block_rearrange_per_group_3d,
2021
)
2122
from torchao.prototype.moe_training.utils import (
2223
conditional_nostrict_trace,

0 commit comments

Comments
 (0)