Skip to content

Commit c0da952

Browse files
[mxfp8 moe training] _permute_bf16 -> permute_and_pad (pytorch#4083)
[mxfp8 moe training] _permute_bf16 -> permute_and_pad
1 parent eb64bfb commit c0da952

5 files changed

Lines changed: 11 additions & 11 deletions

File tree

benchmarks/prototype/moe_training/mxfp8/bench_ep_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
permute_mxfp8_fwd_hp_bwd,
3737
unpermute_hp_fwd_mxfp8_bwd,
3838
)
39-
from torchao.prototype.moe_training.ep.permute import _permute_bf16
39+
from torchao.prototype.moe_training.ep.permute import permute_and_pad
4040
from torchao.prototype.moe_training.ep.unpermute import _unpermute_bf16
4141
from torchao.prototype.moe_training.mxfp8_grouped_mm import (
4242
_to_mxfp8_then_scaled_grouped_mm,
@@ -144,7 +144,7 @@ def standard_pipeline(
144144

145145
# Step 2: Permute (BF16)
146146
input_shape, permuted, permuted_indices, num_tokens_per_expert_padded, offsets = (
147-
_permute_bf16(
147+
permute_and_pad(
148148
dispatched,
149149
num_tokens_per_expert_group,
150150
ep_degree,

test/prototype/moe_training/ep/test_compile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
permute_mxfp8_fwd_hp_bwd,
3838
unpermute_hp_fwd_mxfp8_bwd,
3939
)
40-
from torchao.prototype.moe_training.ep.permute import _permute_bf16
40+
from torchao.prototype.moe_training.ep.permute import permute_and_pad
4141
from torchao.prototype.moe_training.ep.unpermute import _unpermute_bf16
4242
from torchao.prototype.moe_training.mxfp8_grouped_mm import (
4343
_to_mxfp8_then_scaled_grouped_mm,
@@ -72,7 +72,7 @@ def standard_pipeline(
7272

7373
# Step 2: Permute (BF16)
7474
input_shape, permuted, permuted_indices, num_tokens_per_expert_padded, offsets = (
75-
_permute_bf16(
75+
permute_and_pad(
7676
dispatched,
7777
num_tokens_per_expert_group,
7878
ep_degree,

test/prototype/moe_training/ep/test_integration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
permute_mxfp8_fwd_hp_bwd,
3737
unpermute_hp_fwd_mxfp8_bwd,
3838
)
39-
from torchao.prototype.moe_training.ep.permute import _permute_bf16
39+
from torchao.prototype.moe_training.ep.permute import permute_and_pad
4040
from torchao.prototype.moe_training.ep.unpermute import _unpermute_bf16
4141
from torchao.prototype.moe_training.mxfp8_grouped_mm import (
4242
_to_mxfp8_then_scaled_grouped_mm,
@@ -181,7 +181,7 @@ def test_full_pipeline(self):
181181
bf16_permuted_indices,
182182
bf16_num_tokens_per_expert_padded,
183183
bf16_group_offsets,
184-
) = _permute_bf16(
184+
) = permute_and_pad(
185185
bf16_dispatched,
186186
num_tokens_per_expert_group,
187187
ep_degree,

test/prototype/moe_training/ep/test_permute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
pytest.skip("Test requires CUDA 12.8+ with SM >= 100", allow_module_level=True)
1212

1313
from torchao.prototype.moe_training.ep import permute_mxfp8_fwd_hp_bwd
14-
from torchao.prototype.moe_training.ep.permute import _permute_bf16
14+
from torchao.prototype.moe_training.ep.permute import permute_and_pad
1515
from torchao.prototype.mx_formats.mx_tensor import MXTensor
1616
from torchao.quantization.utils import compute_error
1717

@@ -57,7 +57,7 @@ def test_mxfp8_permute_forward():
5757
_,
5858
_,
5959
_,
60-
) = _permute_bf16(
60+
) = permute_and_pad(
6161
input_tensor,
6262
num_tokens_per_expert,
6363
ep_degree,

torchao/prototype/moe_training/ep/permute.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,16 +164,16 @@ def backward(
164164
return grad_input, None, None, None, None, None
165165

166166

167-
# Reference impl for testing
168-
def _permute_bf16(
167+
def permute_and_pad(
169168
x: torch.Tensor,
170169
num_tokens_per_expert: torch.Tensor,
171170
ep_degree: int,
172171
num_local_experts: int,
173172
alignment: int,
174173
):
175174
"""
176-
BF16 permute operation used for testing and benchmarking.
175+
Permute token groups from rank-major to expert-major order, and pad group sizes to alignment size,
176+
in preparation for grouped GEMM.
177177
178178
Args:
179179
x: BF16 input tensor

0 commit comments

Comments (0)