Skip to content

Commit 4c6dbea

Browse files
[mxfp8 training] on-device validation of group sizes in cutedsl quant kernels (#4253)
1 parent 10af862 commit 4c6dbea

7 files changed

Lines changed: 305 additions & 72 deletions

File tree

benchmarks/prototype/moe_training/mxfp8/bench_cutedsl_quantize_2d_1x32.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
9494

9595
# Generate jagged offsets with multiples of 128
9696
# TODO: we use multiple of 128 here to avoid per-group padding requirement in blocked scales layout, which cutedsl doesn't support yet.
97-
group_end_offsets = generate_jagged_offs(
98-
num_groups, M, multiple_of=128, device=device
99-
)
97+
offs = generate_jagged_offs(num_groups, M, multiple_of=128, device=device)
10098

10199
# Benchmark 1: CuTeDSL kernel with blocked scale output
102100
data_cutedsl, scales_cutedsl = mxfp8_quantize_cutedsl_2d_1x32(
@@ -127,11 +125,11 @@ def triton_plus_rearrange(x, group_offs):
127125
)
128126
return data, scales_blocked
129127

130-
data_triton, scales_triton = triton_plus_rearrange(input_tensor, group_end_offsets)
128+
data_triton, scales_triton = triton_plus_rearrange(input_tensor, offs)
131129
triton_plus_rearrange_time_us = benchmark_cuda_function_in_microseconds(
132130
triton_plus_rearrange,
133131
input_tensor,
134-
group_end_offsets,
132+
offs,
135133
)
136134

137135
# Memory bandwidth calculations

benchmarks/prototype/moe_training/mxfp8/bench_cutedsl_quantize_2d_32x1.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
9494

9595
# Generate jagged offsets with multiples of 128 for K dimension
9696
# TODO: we use multiple of 128 here to avoid per-group padding requirement in blocked scales layout, which cutedsl doesn't support yet.
97-
group_end_offsets = generate_jagged_offs(
98-
num_groups, K, multiple_of=128, device=device
99-
)
97+
offs = generate_jagged_offs(num_groups, K, multiple_of=128, device=device)
10098

10199
# Benchmark 1: CuTeDSL kernel with blocked scale output
102100
data_cutedsl, scales_cutedsl = mxfp8_quantize_cutedsl_2d_32x1(
@@ -128,11 +126,11 @@ def cuda_plus_rearrange(x, group_offs):
128126
)
129127
return output_colwise, scales_blocked
130128

131-
data_cuda, scales_cuda = cuda_plus_rearrange(input_tensor, group_end_offsets)
129+
data_cuda, scales_cuda = cuda_plus_rearrange(input_tensor, offs)
132130
cuda_plus_rearrange_time_us = benchmark_cuda_function_in_microseconds(
133131
cuda_plus_rearrange,
134132
input_tensor,
135-
group_end_offsets,
133+
offs,
136134
)
137135

138136
# Memory bandwidth calculations

test/prototype/moe_training/test_kernels.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def _is_sm_10x() -> bool:
6565
from torchao.prototype.mx_formats.utils import from_blocked
6666
from torchao.testing.utils import skip_if_rocm
6767

68+
from .testing_utils import generate_split_sizes
69+
6870

6971
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
7072
def test_row_major_with_jagged_rowwise_scales(round_scales_to_power_of_2: bool):
@@ -636,8 +638,8 @@ def test_cuda_fused_unpad_token_groups(
636638
)
637639

638640
# First pad the tokens to create padded inputs
639-
padded_tokens, padded_group_start_offsets, padded_group_end_offsets = (
640-
torch_pad_token_groups(inputs, group_offsets, alignment_size)
641+
padded_tokens, padded_group_start_offsets, padded_offsets = torch_pad_token_groups(
642+
inputs, group_offsets, alignment_size
641643
)
642644

643645
# Get reference output using torch implementation
@@ -704,3 +706,115 @@ def test_triton_fp8_rowwise_2d_scale_and_cast(
704706
assert ref_scales.shape == triton_scales.shape, "scale shapes not equal"
705707
assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"
706708
assert torch.allclose(ref_scales, triton_scales, rtol=0, atol=0), "scales not equal"
709+
710+
711+
@pytest.mark.skipif(
712+
not _is_sm_10x(),
713+
reason="MXFP8 requires CUDA SM 10.x",
714+
)
715+
@pytest.mark.skipif(
716+
not _mxfp8_cutedsl_kernels_available,
717+
reason="MXFP8 cutedsl kernels not available",
718+
)
719+
@skip_if_rocm("ROCm enablement in progress")
720+
def test_cutedsl_1x32_group_validation_error():
721+
"""Test that 1x32 CuTeDSL kernel raises error for non-128-multiple group sizes."""
722+
device = "cuda"
723+
M, K = 512, 1024
724+
x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
725+
num_groups = 4
726+
727+
# Generate group sizes and force at least one to be invalid
728+
group_sizes = generate_split_sizes(num_groups, M, device)
729+
if group_sizes[0] % 128 == 0:
730+
group_sizes[0] = group_sizes[0] - 1 # Make it not a multiple of 128
731+
group_sizes[1] = group_sizes[1] + 1 # Compensate to maintain total sum
732+
733+
invalid_offsets = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)
734+
735+
# Test should raise RuntimeError due to device assertion failure with specific message
736+
with pytest.raises(
737+
RuntimeError,
738+
match=r"unspecified launch failure",
739+
):
740+
_ = mxfp8_quantize_2d_1x32_cutedsl(
741+
x, block_size=32, scaling_mode="rceil", offs=invalid_offsets
742+
)
743+
# Force synchronization to ensure device error propagates
744+
torch.cuda.synchronize()
745+
746+
747+
@pytest.mark.skipif(
748+
not _is_sm_10x(),
749+
reason="MXFP8 requires CUDA SM 10.x",
750+
)
751+
@pytest.mark.skipif(
752+
not _mxfp8_cutedsl_kernels_available,
753+
reason="MXFP8 cutedsl kernels not available",
754+
)
755+
@skip_if_rocm("ROCm enablement in progress")
756+
def test_cutedsl_32x1_group_validation_error():
757+
"""Test that 32x1 CuTeDSL kernel raises error for non-128-multiple group sizes."""
758+
device = "cuda"
759+
M, K = 512, 1024
760+
x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
761+
num_groups = 4
762+
763+
# Generate group sizes and force at least one to be invalid
764+
group_sizes = generate_split_sizes(num_groups, M, device)
765+
if group_sizes[0] % 128 == 0:
766+
group_sizes[0] = group_sizes[0] - 1 # Make it not a multiple of 128
767+
group_sizes[1] = group_sizes[1] + 1 # Compensate to maintain total sum
768+
769+
invalid_offsets = torch.cumsum(group_sizes, dim=0, dtype=torch.int32)
770+
771+
# Test should raise RuntimeError due to device assertion failure with specific message
772+
with pytest.raises(RuntimeError, match=r"unspecified launch failure"):
773+
_ = mxfp8_quantize_2d_32x1_cutedsl(
774+
x, block_size=32, scaling_mode="rceil", offs=invalid_offsets
775+
)
776+
# Force synchronization to ensure device error propagates
777+
torch.cuda.synchronize()
778+
779+
780+
@pytest.mark.skipif(
781+
not _is_sm_10x(),
782+
reason="MXFP8 requires CUDA SM 10.x",
783+
)
784+
@pytest.mark.skipif(
785+
not _mxfp8_cutedsl_kernels_available,
786+
reason="MXFP8 cutedsl kernels not available",
787+
)
788+
@skip_if_rocm("ROCm enablement in progress")
789+
def test_cutedsl_kernels_work_with_valid_128_multiple_groups():
790+
"""Test that both CuTeDSL kernels work correctly with valid 128-multiple group sizes."""
791+
device = "cuda"
792+
M, K = 512, 1024
793+
x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
794+
795+
# Create valid group offsets (all group sizes are multiples of 128)
796+
valid_group_sizes = [128, 256, 128] # All multiples of 128
797+
valid_offsets = torch.cumsum(
798+
torch.tensor(valid_group_sizes, dtype=torch.int32), dim=0
799+
).to(device)
800+
801+
# Verify all group sizes are multiples of 128
802+
group_sizes = torch.diff(
803+
torch.cat([torch.zeros(1, device=device, dtype=torch.int32), valid_offsets])
804+
)
805+
assert torch.all(group_sizes % 128 == 0), (
806+
"Test setup failed: not all group sizes are multiples of 128"
807+
)
808+
809+
# Both kernels should work without error
810+
y_1x32, s_1x32 = mxfp8_quantize_2d_1x32_cutedsl(
811+
x, block_size=32, scaling_mode="rceil", offs=valid_offsets
812+
)
813+
814+
y_32x1, s_32x1 = mxfp8_quantize_2d_32x1_cutedsl(
815+
x, block_size=32, scaling_mode="rceil", offs=valid_offsets
816+
)
817+
818+
# Basic output validation
819+
assert y_1x32.shape == (M, K)
820+
assert y_32x1.shape == (M, K)

torchao/prototype/moe_training/kernels/mxfp8/cute_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,3 +225,26 @@ def load_vals_chunk_tail(
225225
else:
226226
vals_chunk[j] = cutlass.Float32(0.0)
227227
return vals_chunk
228+
229+
@cute.jit
230+
def validate_group_sizes(offs: cute.Tensor):
231+
# Only first thread validates to avoid redundant work
232+
num_groups = offs.shape[0]
233+
234+
# Validate first group (from 0 to offs[0])
235+
if num_groups > 0:
236+
first_group_size = offs[0]
237+
cute.testing.assert_(
238+
first_group_size % 128 == 0,
239+
"Group sizes must be multiples of 128",
240+
)
241+
242+
# Validate subsequent groups
243+
for i in range(1, num_groups):
244+
prev_end = offs[i - 1]
245+
curr_end = offs[i]
246+
group_size = curr_end - prev_end
247+
cute.testing.assert_(
248+
group_size % 128 == 0,
249+
"Group sizes must be multiples of 128",
250+
)

torchao/prototype/moe_training/kernels/mxfp8/cutedsl_quantize_2d_1x32.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import functools
8-
from typing import Tuple
8+
from typing import Optional, Tuple
99

1010
import torch
1111

@@ -17,6 +17,7 @@
1717
compute_scale_from_amax,
1818
load_vals_chunk_full,
1919
load_vals_chunk_tail,
20+
validate_group_sizes,
2021
)
2122

2223

@@ -83,6 +84,7 @@ def _compile_mxfp8_quantize_2d_cutedsl(
8384
k_tiles_per_cta: int,
8485
is_full_k_tiles: bool,
8586
blocked_scale_output: bool,
87+
offs: Optional[torch.Tensor] = None,
8688
):
8789
"""Compile the 2D MXFP8 quantization kernel using CuTeDSL.
8890
@@ -531,6 +533,7 @@ def kernel(
531533
m_cta_tiles: cutlass.Int64,
532534
k_cta_tiles: cutlass.Int64,
533535
blocked_scale_layout: cute.Layout,
536+
offs: Optional[cute.Tensor],
534537
SCALE_DIM_K: cutlass.Constexpr[int],
535538
USE_RCEIL: cutlass.Constexpr[bool],
536539
IS_FULL_K_TILES: cutlass.Constexpr[bool],
@@ -560,6 +563,7 @@ def kernel(
560563
m_cta_tiles: Number of tiles in M dimension
561564
k_cta_tiles: Number of tile groups in K dimension
562565
blocked_scale_layout: Layout for blocked scale output
566+
offs: Tensor of group end offsets for validation
563567
SCALE_DIM_K: Block size (32)
564568
USE_RCEIL: Whether using RCEIL mode
565569
IS_FULL_K_TILES: Whether K is perfectly tiled
@@ -575,6 +579,11 @@ def kernel(
575579
warp_idx = cute.arch.make_warp_uniform(warp_idx)
576580
bidx, bidy, _ = cute.arch.block_idx()
577581

582+
# Validate group sizes are multiples of 128 if offs is provided
583+
if cutlass.const_expr(offs is not None):
584+
if tidx == 0:
585+
validate_group_sizes(offs)
586+
578587
smem_allocator = utils.SmemAllocator()
579588
storage = smem_allocator.allocate(SharedStorage)
580589
# The tuned contract keeps STAGE_COUNT <= 2.
@@ -812,6 +821,7 @@ def __call__(
812821
m_cta_tiles: cutlass.Int64,
813822
k_cta_tiles: cutlass.Int64,
814823
stream: cuda.CUstream,
824+
offs: Optional[cute.Tensor],
815825
):
816826
"""Kernel launcher that sets up TMA descriptors and blocked scale layout.
817827
@@ -825,6 +835,7 @@ def __call__(
825835
m_cta_tiles: Number of tiles in M dimension
826836
k_cta_tiles: Number of tile groups in K dimension
827837
stream: CUDA stream
838+
offs: Tensor of group end offsets for validation (group sizes must be multiples of 128)
828839
829840
Storage locations:
830841
All tensors in global memory
@@ -874,6 +885,7 @@ def __call__(
874885
m_cta_tiles,
875886
k_cta_tiles,
876887
blocked_scale_layout,
888+
offs,
877889
SCALE_DIM_K=SCALE_DIM_K_VALUE,
878890
USE_RCEIL=(scaling_mode == "rceil"),
879891
IS_FULL_K_TILES=IS_FULL_K_TILES_VALUE,
@@ -923,6 +935,21 @@ def __call__(
923935
)
924936
fake_stream = make_fake_stream()
925937

938+
if offs is not None:
939+
offs_stride = cute.sym_int()
940+
fake_offs = make_fake_tensor(
941+
cutlass.Int32,
942+
(cute.sym_int(),),
943+
stride=(offs_stride,),
944+
)
945+
else:
946+
fake_offs = None
947+
948+
compile_options = (
949+
"--enable-tvm-ffi"
950+
if fake_offs is None
951+
else "--enable-tvm-ffi --enable-assertions"
952+
)
926953
return cute.compile(
927954
kernel,
928955
inp_mk=fake_inp,
@@ -934,7 +961,8 @@ def __call__(
934961
m_cta_tiles=1,
935962
k_cta_tiles=1,
936963
stream=fake_stream,
937-
options="--enable-tvm-ffi",
964+
offs=fake_offs,
965+
options=compile_options,
938966
)
939967

940968

@@ -944,6 +972,7 @@ def mxfp8_quantize_cutedsl_2d_1x32(
944972
scaling_mode: str = "rceil",
945973
stage_count: int = 2,
946974
blocked_scale_output: bool = False,
975+
offs: Optional[torch.Tensor] = None,
947976
) -> Tuple[torch.Tensor, torch.Tensor]:
948977
"""
949978
Quantize a 2D tensor to MXFP8 format using CuTe DSL kernel.
@@ -956,6 +985,7 @@ def mxfp8_quantize_cutedsl_2d_1x32(
956985
scaling_mode: Scaling mode ("floor" or "rceil")
957986
stage_count: Number of pipeline stages (1 or 2)
958987
blocked_scale_output: Whether to output scales in blocked layout
988+
offs: Optional tensor of group end offsets for validation (must have group sizes as multiples of 128)
959989
960990
Returns:
961991
q_data: Quantized data in row-major layout with shape (M, K)
@@ -970,6 +1000,11 @@ def mxfp8_quantize_cutedsl_2d_1x32(
9701000
M, K = x.shape
9711001
assert K % block_size == 0, "K must be divisible by block_size"
9721002

1003+
if offs is not None:
1004+
assert offs.is_cuda, "offs tensor must be CUDA"
1005+
assert offs.dtype == torch.int32, "offs must be int32 tensor"
1006+
assert offs.dim() == 1, "offs must be 1D tensor"
1007+
9731008
_, config = _select_cutedsl_config(x.dtype, scaling_mode)
9741009
compute_warps, tile_m, tile_k, k_tiles_per_cta = config
9751010
# B200 sweeps over representative shapes showed no
@@ -1019,6 +1054,7 @@ def mxfp8_quantize_cutedsl_2d_1x32(
10191054
k_tiles_per_cta,
10201055
is_full_k_tiles,
10211056
blocked_scale_output,
1057+
offs,
10221058
)
10231059

10241060
import cuda.bindings.driver as cuda
@@ -1037,6 +1073,7 @@ def mxfp8_quantize_cutedsl_2d_1x32(
10371073
int(m_cta_tiles),
10381074
int(k_cta_tiles),
10391075
stream,
1076+
offs,
10401077
)
10411078

10421079
return q_data, scales_u8.view(torch.float8_e8m0fnu)

0 commit comments

Comments
 (0)