
Commit f6f29f6

[mxfp8 moe training] register sharding rules for cutedsl 2d quant kernel (#4178)
1 parent: a6be48f

1 file changed

Lines changed: 16 additions & 0 deletions


torchao/prototype/moe_training/kernels/mxfp8/quant.py

@@ -8,6 +8,8 @@
 import torch
 from torch import Tensor
+from torch.distributed.tensor import Replicate, Shard
+from torch.distributed.tensor.experimental import register_sharding
 from torch.utils._triton import has_triton
 
 from torchao.prototype.moe_training.kernels.mxfp8.cute_utils import (
@@ -975,6 +977,20 @@ def _fake_mxfp8_quantize_2d_cutedsl_custom_op(
     return q_data, scales
 
 
+if _mxfp8_cutedsl_kernels_available:
+
+    @register_sharding(torch.ops.torchao.mxfp8_quantize_2d_cutedsl.default)
+    def custom_sharding_for_cutedsl_mxfp8_dim0_kernel(
+        x, block_size=32, scaling_mode: str = "rceil", stage_count: int = 2
+    ):
+        # order is: ([outputs, ...], [inputs, ...])
+        replicate = ([Replicate(), Replicate()], [Replicate(), None, None, None])
+        shard_dim0 = ([Shard(0), Shard(0)], [Shard(0), None, None, None])
+        shard_dim1 = ([Shard(1), Shard(1)], [Shard(1), None, None, None])
+        acceptable_shardings = [replicate, shard_dim0, shard_dim1]
+        return acceptable_shardings
+
+
 if _mxfp8_cuda_kernels_available:
     # CUDA kernel for per group blocked layout transform with groups along M
     def mx_block_rearrange_2d_M_groups_cuda(
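
For context, a minimal sketch of how this registration is exercised. Once the rule is registered, a DTensor input can dispatch straight to the kernel: DTensor picks whichever listed strategy matches the input placement instead of redistributing. The mesh setup, tensor shapes, and dtype below are illustrative assumptions, not taken from the commit; only the op name and its four-parameter signature come from the diff above.

import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Hypothetical usage sketch (not from this commit). Assumes the script is
# launched with torchrun (WORLD_SIZE set) on CUDA devices, and that
# _mxfp8_cutedsl_kernels_available is True in this torchao build.
world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))

# Each rank holds its local rows of a row-sharded activation tensor.
x_local = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
x = DTensor.from_local(x_local, mesh, [Shard(0)])

# The shard_dim0 strategy above declares a Shard(0) input acceptable and
# promises Shard(0) outputs, so DTensor runs the kernel on each local
# shard with no collective communication.
q_data, scales = torch.ops.torchao.mxfp8_quantize_2d_cutedsl(
    x, block_size=32, scaling_mode="rceil", stage_count=2
)

The same dispatch works for Replicate or Shard(1) inputs, since those placements are also listed in the returned strategies; any other placement would force DTensor to redistribute first.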
