|
8 | 8 |
|
9 | 9 | import torch |
10 | 10 | from torch import Tensor |
| 11 | +from torch.distributed.tensor import Replicate, Shard |
| 12 | +from torch.distributed.tensor.experimental import register_sharding |
11 | 13 | from torch.utils._triton import has_triton |
12 | 14 |
|
13 | 15 | from torchao.prototype.moe_training.kernels.mxfp8.cute_utils import ( |
@@ -975,6 +977,20 @@ def _fake_mxfp8_quantize_2d_cutedsl_custom_op( |
975 | 977 | return q_data, scales |
976 | 978 |
|
977 | 979 |
|
if _mxfp8_cutedsl_kernels_available:

    @register_sharding(torch.ops.torchao.mxfp8_quantize_2d_cutedsl.default)
    def custom_sharding_for_cutedsl_mxfp8_dim0_kernel(
        x, block_size=32, scaling_mode: str = "rceil", stage_count: int = 2
    ):
        """Register acceptable DTensor sharding strategies for the CuteDSL
        mxfp8 dim0 quantization op.

        Each strategy is a tuple of ([output placements], [input placements]),
        in that order. The op has two tensor outputs (quantized data and
        scales) that follow the placement of the single tensor input ``x``;
        the three non-tensor arguments (block_size, scaling_mode, stage_count)
        carry no placement and are marked ``None``.
        """

        def _strategy(make_placement):
            # Fresh placement instances for outputs and input, mirroring x's
            # placement onto both outputs.
            outputs = [make_placement(), make_placement()]
            inputs = [make_placement(), None, None, None]
            return (outputs, inputs)

        # Fully replicated, sharded along dim 0, or sharded along dim 1.
        return [
            _strategy(Replicate),
            _strategy(lambda: Shard(0)),
            _strategy(lambda: Shard(1)),
        ]
978 | 994 | if _mxfp8_cuda_kernels_available: |
979 | 995 | # CUDA kernel for per group blocked layout transform with groups along M |
980 | 996 | def mx_block_rearrange_2d_M_groups_cuda( |
|
0 commit comments