Skip to content

Commit a9f24af

Browse files
Support 32x32 scaling for weights in MXFP8 weight quantization kernel (#4254)
1 parent 9058b58 commit a9f24af

7 files changed

Lines changed: 623 additions & 144 deletions

File tree

benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
class ExperimentConfig:
3232
input_shape: tuple[int]
3333
scaling_mode: ScaleCalculationMode
34+
scale_block_k: int
3435

3536

3637
@dataclass(frozen=True)
@@ -62,19 +63,24 @@ def get_configs() -> List[ExperimentConfig]:
6263
(32, 8192, 5120),
6364
]
6465
round_modes = [ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL]
66+
scale_block_ks = [1, 32]
6567
configs = []
66-
for shape, scaling_mode in itertools.product(input_shapes, round_modes):
68+
for shape, scaling_mode, scale_block_k in itertools.product(
69+
input_shapes, round_modes, scale_block_ks
70+
):
6771
configs.append(
6872
ExperimentConfig(
6973
input_shape=shape,
7074
scaling_mode=scaling_mode,
75+
scale_block_k=scale_block_k,
7176
)
7277
)
7378
return configs
7479

7580

7681
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
7782
block_size = 32
83+
scale_block_k = config.scale_block_k
7884
input_shape = config.input_shape
7985
input_tensor = torch.randn(
8086
*input_shape,
@@ -83,20 +89,37 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
8389
)
8490

8591
def using_to_mx(x: torch.Tensor) -> torch.Tensor:
86-
# Reference implementation
87-
s_d1_ref, y_d1_ref = to_mx(
88-
# Transpose (E,N,K) to (E,K,N) so N is final dim,
89-
# since to_mx scales along that dim
90-
x.transpose(-2, -1).contiguous(),
92+
if scale_block_k == 1:
93+
s_ref, y_ref = to_mx(
94+
x.transpose(-2, -1).contiguous(),
95+
elem_dtype=torch.float8_e4m3fn,
96+
block_size=block_size,
97+
)
98+
return y_ref.transpose(-2, -1), s_ref.transpose(-2, -1)
99+
100+
assert scale_block_k == 32
101+
E, N, K = x.shape
102+
x_tiles = (
103+
x.view(E, N // block_size, block_size, K // block_size, block_size)
104+
.permute(0, 1, 3, 2, 4)
105+
.contiguous()
106+
.view(E, N // block_size, K // block_size, block_size * block_size)
107+
)
108+
s_ref, y_tiles_ref = to_mx(
109+
x_tiles,
91110
elem_dtype=torch.float8_e4m3fn,
92-
block_size=block_size,
111+
block_size=block_size * block_size,
93112
)
94-
95-
# Transpose tensors and scales back so we have effectively
96-
# quantized input shape (E, N, K) along N
97-
y_d1_ref = y_d1_ref.transpose(-2, -1)
98-
s_d1_ref = s_d1_ref.transpose(-2, -1)
99-
return y_d1_ref, s_d1_ref
113+
y_ref = (
114+
y_tiles_ref.view(
115+
E, N // block_size, K // block_size, block_size, block_size
116+
)
117+
.permute(0, 1, 3, 2, 4)
118+
.contiguous()
119+
.view(E, N, K)
120+
)
121+
y_ref = y_ref.transpose(-2, -1).contiguous().transpose(-2, -1)
122+
return y_ref, s_ref
100123

101124
# bench to_mx
102125
using_to_mx_c = torch.compile(using_to_mx)
@@ -106,26 +129,33 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
106129
input_tensor,
107130
)
108131

109-
# bench 2d dim1 kernel then transforming to col major
110-
using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
111-
scales_cuda_2d, data_cuda_2d = using_cuda_2d_c(input_tensor)
112-
time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
113-
using_cuda_2d_c,
114-
input_tensor,
115-
block_size=block_size,
116-
scaling_mode=config.scaling_mode,
117-
)
132+
if scale_block_k == 1:
133+
# bench 2d dim1 kernel then transforming to col major
134+
using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
135+
using_cuda_2d_c(input_tensor)
136+
time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
137+
using_cuda_2d_c,
138+
input_tensor,
139+
block_size=block_size,
140+
scaling_mode=config.scaling_mode,
141+
)
142+
else:
143+
time_cuda_2d_us = float("nan")
118144

119145
# bench 3d CuTeDSL kernel
120146
data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(
121147
input_tensor,
122148
block_size=block_size,
149+
scale_block_n=block_size,
150+
scale_block_k=scale_block_k,
123151
scaling_mode=str(config.scaling_mode.value),
124152
)
125153
time_cutedsl_3d_us = benchmark_cuda_function_in_microseconds(
126154
mxfp8_quantize_cuda_3d,
127155
input_tensor,
128156
block_size=block_size,
157+
scale_block_n=block_size,
158+
scale_block_k=scale_block_k,
129159
scaling_mode=str(config.scaling_mode.value),
130160
)
131161

@@ -159,6 +189,7 @@ def print_results(experiments: List[Experiment]):
159189
headers = [
160190
"input_shape",
161191
"scaling_mode",
192+
"scale_block_k",
162193
"cuda_2d_us",
163194
"cutedsl_3d_us",
164195
"to_mx_us",
@@ -172,6 +203,7 @@ def print_results(experiments: List[Experiment]):
172203
[
173204
str(experiment.config.input_shape),
174205
str(experiment.config.scaling_mode),
206+
str(experiment.config.scale_block_k),
175207
experiment.result.cuda_2d_us,
176208
experiment.result.cutedsl_3d_us,
177209
experiment.result.to_mx_us,

benchmarks/prototype/moe_training/mxfp8/roofline_unified.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,11 @@ def benchmark_mxfp8_quantize_cuda_3d(tensor, block_size=32):
495495
"""Benchmark mxfp8_quantize_cuda_3d kernel"""
496496
return benchmark_cuda_function_in_microseconds(
497497
lambda: mxfp8_quantize_cuda_3d(
498-
tensor, block_size=block_size, scaling_mode="rceil"
498+
tensor,
499+
block_size=block_size,
500+
scale_block_n=block_size,
501+
scale_block_k=1,
502+
scaling_mode="rceil",
499503
)
500504
)
501505

test/prototype/moe_training/test_kernels.py

Lines changed: 89 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,12 @@ def test_triton_mx_block_rearrange_2d_K_groups(
392392
@pytest.mark.parametrize(
393393
"scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
394394
)
395-
def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
395+
@pytest.mark.parametrize(
396+
"scale_block_k",
397+
(1, 32),
398+
ids=("32x1", "32x32"),
399+
)
400+
def test_cuda_mx_3d_cutedsl_numerics(E, N, K, input_dtype, scaling_mode, scale_block_k):
396401
if not _mxfp8_cutedsl_kernels_available:
397402
pytest.skip("mxfp8_quantize_3d is unavailable")
398403

@@ -408,37 +413,97 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
408413
.contiguous()
409414
)
410415

411-
# Reference implementation
412-
s_d1_ref, y_d1_ref = to_mx(
413-
# Transpose so N is final dim, since to_mx scales along that dim
414-
x.transpose(-2, -1).contiguous(),
415-
elem_dtype=torch.float8_e4m3fn,
416-
block_size=block_size,
417-
scaling_mode=scaling_mode,
418-
)
416+
if scale_block_k == 1:
417+
s_ref, y_ref = to_mx(
418+
x.transpose(-2, -1).contiguous(),
419+
elem_dtype=torch.float8_e4m3fn,
420+
block_size=block_size,
421+
scaling_mode=scaling_mode,
422+
)
423+
y_ref = y_ref.transpose(-2, -1)
424+
s_ref = s_ref.transpose(-2, -1)
425+
s_rows, s_cols = K, N // block_size
426+
undo_scale = (
427+
lambda scale: from_blocked(scale, s_rows, s_cols)
428+
.transpose(-2, -1)
429+
.contiguous()
430+
)
431+
else:
432+
x_tiles = (
433+
x.view(E, N // block_size, block_size, K // block_size, block_size)
434+
.permute(0, 1, 3, 2, 4)
435+
.contiguous()
436+
.view(E, N // block_size, K // block_size, block_size * block_size)
437+
)
438+
s_ref, y_tiles_ref = to_mx(
439+
x_tiles,
440+
elem_dtype=torch.float8_e4m3fn,
441+
block_size=block_size * block_size,
442+
scaling_mode=scaling_mode,
443+
)
444+
s_ref = s_ref.squeeze(-1)
445+
y_ref = (
446+
y_tiles_ref.view(
447+
E, N // block_size, K // block_size, block_size, block_size
448+
)
449+
.permute(0, 1, 3, 2, 4)
450+
.contiguous()
451+
.view(E, N, K)
452+
)
453+
y_ref = y_ref.transpose(-2, -1).contiguous().transpose(-2, -1)
454+
s_rows, s_cols = K, N // block_size
455+
undo_scale = lambda scale: from_blocked(scale, s_rows, s_cols)[
456+
::block_size
457+
].transpose(-2, -1)
419458

420-
# Transpose tensors and scales back so we have effectively
421-
# quantized input shape (E, N, K) along N
422-
y_d1_ref = y_d1_ref.transpose(-2, -1)
423-
s_d1_ref = s_d1_ref.transpose(-2, -1)
424-
y_d1, s_d1 = mxfp8_quantize_cuda_3d(
459+
y, s = mxfp8_quantize_cuda_3d(
425460
x,
426461
block_size=block_size,
462+
scale_block_n=block_size,
463+
scale_block_k=scale_block_k,
427464
scaling_mode=scaling_mode_str,
465+
blocked_scale_output=True,
466+
)
467+
if scale_block_k == 32:
468+
s_blocked_full = (
469+
torch.stack(
470+
[
471+
from_blocked(s[e], s_rows, s_cols).view(torch.uint8)
472+
for e in range(E)
473+
],
474+
dim=0,
475+
)
476+
.view(torch.float8_e8m0fnu)
477+
.to(s_ref.dtype)
478+
)
479+
s_ref_replicated = s_ref.transpose(-2, -1).repeat_interleave(block_size, dim=1)
480+
torch.testing.assert_close(s_blocked_full, s_ref_replicated, rtol=0, atol=0)
481+
s = (
482+
torch.stack([undo_scale(s[e]).view(torch.uint8) for e in range(E)], dim=0)
483+
.view(torch.float8_e8m0fnu)
484+
.to(s_ref.dtype)
428485
)
429-
s_d1 = torch.stack(
430-
[
431-
from_blocked(s_d1[e], K, N // block_size).transpose(-2, -1).contiguous()
432-
for e in range(E)
433-
],
434-
dim=0,
435-
).to(s_d1_ref.dtype)
436486
# Check scales
437-
torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
487+
torch.testing.assert_close(s, s_ref, rtol=0, atol=0)
438488

439489
# Check quantized values
440-
torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
441-
assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
490+
torch.testing.assert_close(y, y_ref, rtol=0, atol=0)
491+
assert y.stride() == y_ref.stride(), "quantized tensor strides do not match"
492+
493+
y_unblocked, s_unblocked = mxfp8_quantize_cuda_3d(
494+
x,
495+
block_size=block_size,
496+
scale_block_n=block_size,
497+
scale_block_k=scale_block_k,
498+
scaling_mode=scaling_mode_str,
499+
blocked_scale_output=False,
500+
)
501+
s_unblocked = s_unblocked.to(s_ref.dtype)
502+
torch.testing.assert_close(s_unblocked, s_ref, rtol=0, atol=0)
503+
torch.testing.assert_close(y_unblocked, y_ref, rtol=0, atol=0)
504+
assert y_unblocked.stride() == y_ref.stride(), (
505+
"unblocked quantized tensor strides do not match"
506+
)
442507

443508

444509
@pytest.mark.skipif(

0 commit comments

Comments
 (0)