Skip to content

Commit 0306ca4

Browse files
committed
Update
[ghstack-poisoned]
2 parents 5daac37 + 60b5235 commit 0306ca4

17 files changed

Lines changed: 1547 additions & 234 deletions

File tree

benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
class ExperimentConfig:
3232
input_shape: tuple[int]
3333
scaling_mode: ScaleCalculationMode
34+
scale_block_k: int
3435

3536

3637
@dataclass(frozen=True)
@@ -62,19 +63,24 @@ def get_configs() -> List[ExperimentConfig]:
6263
(32, 8192, 5120),
6364
]
6465
round_modes = [ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL]
66+
scale_block_ks = [1, 32]
6567
configs = []
66-
for shape, scaling_mode in itertools.product(input_shapes, round_modes):
68+
for shape, scaling_mode, scale_block_k in itertools.product(
69+
input_shapes, round_modes, scale_block_ks
70+
):
6771
configs.append(
6872
ExperimentConfig(
6973
input_shape=shape,
7074
scaling_mode=scaling_mode,
75+
scale_block_k=scale_block_k,
7176
)
7277
)
7378
return configs
7479

7580

7681
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
7782
block_size = 32
83+
scale_block_k = config.scale_block_k
7884
input_shape = config.input_shape
7985
input_tensor = torch.randn(
8086
*input_shape,
@@ -83,20 +89,37 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
8389
)
8490

8591
def using_to_mx(x: torch.Tensor) -> torch.Tensor:
86-
# Reference implementation
87-
s_d1_ref, y_d1_ref = to_mx(
88-
# Transpose (E,N,K) to (E,K,N) so N is final dim,
89-
# since to_mx scales along that dim
90-
x.transpose(-2, -1).contiguous(),
92+
if scale_block_k == 1:
93+
s_ref, y_ref = to_mx(
94+
x.transpose(-2, -1).contiguous(),
95+
elem_dtype=torch.float8_e4m3fn,
96+
block_size=block_size,
97+
)
98+
return y_ref.transpose(-2, -1), s_ref.transpose(-2, -1)
99+
100+
assert scale_block_k == 32
101+
E, N, K = x.shape
102+
x_tiles = (
103+
x.view(E, N // block_size, block_size, K // block_size, block_size)
104+
.permute(0, 1, 3, 2, 4)
105+
.contiguous()
106+
.view(E, N // block_size, K // block_size, block_size * block_size)
107+
)
108+
s_ref, y_tiles_ref = to_mx(
109+
x_tiles,
91110
elem_dtype=torch.float8_e4m3fn,
92-
block_size=block_size,
111+
block_size=block_size * block_size,
93112
)
94-
95-
# Transpose tensors and scales back so we have effectively
96-
# quantized input shape (E, N, K) along N
97-
y_d1_ref = y_d1_ref.transpose(-2, -1)
98-
s_d1_ref = s_d1_ref.transpose(-2, -1)
99-
return y_d1_ref, s_d1_ref
113+
y_ref = (
114+
y_tiles_ref.view(
115+
E, N // block_size, K // block_size, block_size, block_size
116+
)
117+
.permute(0, 1, 3, 2, 4)
118+
.contiguous()
119+
.view(E, N, K)
120+
)
121+
y_ref = y_ref.transpose(-2, -1).contiguous().transpose(-2, -1)
122+
return y_ref, s_ref
100123

101124
# bench to_mx
102125
using_to_mx_c = torch.compile(using_to_mx)
@@ -106,26 +129,33 @@ def using_to_mx(x: torch.Tensor) -> torch.Tensor:
106129
input_tensor,
107130
)
108131

109-
# bench 2d dim1 kernel then transforming to col major
110-
using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
111-
scales_cuda_2d, data_cuda_2d = using_cuda_2d_c(input_tensor)
112-
time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
113-
using_cuda_2d_c,
114-
input_tensor,
115-
block_size=block_size,
116-
scaling_mode=config.scaling_mode,
117-
)
132+
if scale_block_k == 1:
133+
# bench 2d dim1 kernel then transforming to col major
134+
using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
135+
using_cuda_2d_c(input_tensor)
136+
time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
137+
using_cuda_2d_c,
138+
input_tensor,
139+
block_size=block_size,
140+
scaling_mode=config.scaling_mode,
141+
)
142+
else:
143+
time_cuda_2d_us = float("nan")
118144

119145
# bench 3d CuTeDSL kernel
120146
data_cuda_3d, scales_cuda_3d = mxfp8_quantize_cuda_3d(
121147
input_tensor,
122148
block_size=block_size,
149+
scale_block_n=block_size,
150+
scale_block_k=scale_block_k,
123151
scaling_mode=str(config.scaling_mode.value),
124152
)
125153
time_cutedsl_3d_us = benchmark_cuda_function_in_microseconds(
126154
mxfp8_quantize_cuda_3d,
127155
input_tensor,
128156
block_size=block_size,
157+
scale_block_n=block_size,
158+
scale_block_k=scale_block_k,
129159
scaling_mode=str(config.scaling_mode.value),
130160
)
131161

@@ -159,6 +189,7 @@ def print_results(experiments: List[Experiment]):
159189
headers = [
160190
"input_shape",
161191
"scaling_mode",
192+
"scale_block_k",
162193
"cuda_2d_us",
163194
"cutedsl_3d_us",
164195
"to_mx_us",
@@ -172,6 +203,7 @@ def print_results(experiments: List[Experiment]):
172203
[
173204
str(experiment.config.input_shape),
174205
str(experiment.config.scaling_mode),
206+
str(experiment.config.scale_block_k),
175207
experiment.result.cuda_2d_us,
176208
experiment.result.cutedsl_3d_us,
177209
experiment.result.to_mx_us,

benchmarks/prototype/moe_training/mxfp8/roofline_unified.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,11 @@ def benchmark_mxfp8_quantize_cuda_3d(tensor, block_size=32):
495495
"""Benchmark mxfp8_quantize_cuda_3d kernel"""
496496
return benchmark_cuda_function_in_microseconds(
497497
lambda: mxfp8_quantize_cuda_3d(
498-
tensor, block_size=block_size, scaling_mode="rceil"
498+
tensor,
499+
block_size=block_size,
500+
scale_block_n=block_size,
501+
scale_block_k=1,
502+
scaling_mode="rceil",
499503
)
500504
)
501505

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import copy
8+
from collections.abc import Iterable
9+
10+
import torch
11+
import torch.distributed as dist
12+
from torch.distributed._tensor import DTensor
13+
14+
from packaging import version
15+
from torchao.prototype.blockwise_fp8_training.linear import (
16+
Float8BlockwiseLinear,
17+
Float8BlockwiseLinearConfig,
18+
)
19+
from torchao.quantization import quantize_
20+
from torchao.testing.training.dtensor_utils import ToyModel
21+
from torchao.utils import is_sm_at_least_90
22+
23+
24+
def get_blockwise_linear_skip_reason(
25+
*,
26+
triton_module,
27+
min_cuda_devices: int,
28+
) -> str | None:
29+
"""Shared module-level gating for Float8BlockwiseLinear distributed tests.
30+
31+
This is intentionally separate from the lower-level kernel test gating because
32+
the module swap currently requires SM90+ and the newer scaled_mm/Triton path.
33+
"""
34+
if not torch.cuda.is_available():
35+
return "CUDA not available"
36+
if torch.cuda.device_count() < min_cuda_devices:
37+
return f"Need at least {min_cuda_devices} CUDA devices"
38+
if not is_sm_at_least_90():
39+
return "Float8BlockwiseLinear currently requires CUDA SM90+"
40+
if version.parse(triton_module.__version__) < version.parse("3.3.0"):
41+
return "Triton version < 3.3.0"
42+
return None
43+
44+
45+
def full_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return a tensor usable in parity checks.

    A DTensor is materialized through its ``full_tensor()`` method; any other
    tensor is handed back unchanged (the very same object).
    """
    if isinstance(tensor, DTensor):
        return tensor.full_tensor()
    return tensor
48+
49+
50+
def assert_close(
    actual: torch.Tensor,
    expected: torch.Tensor,
    *,
    atol: float = 2e-2,
    rtol: float = 2e-2,
) -> None:
    """Check that two eager tensors or DTensors agree within tolerance.

    Both operands go through ``full_tensor`` and are converted to float32
    before being handed to ``torch.testing.assert_close``.

    Args:
        actual: Value under test.
        expected: Reference value.
        atol: Absolute tolerance forwarded to ``torch.testing.assert_close``.
        rtol: Relative tolerance forwarded to ``torch.testing.assert_close``.
    """
    actual_fp32 = full_tensor(actual).float()
    expected_fp32 = full_tensor(expected).float()
    torch.testing.assert_close(
        actual_fp32,
        expected_fp32,
        atol=atol,
        rtol=rtol,
    )
64+
65+
66+
def set_blockwise_linear_use_triton(
    model: torch.nn.Module,
    use_triton: bool,
) -> None:
    """Set ``use_triton`` on every Float8BlockwiseLinear found in ``model``.

    Args:
        model: Module tree searched via ``model.modules()``.
        use_triton: Value assigned to each matching module's ``use_triton``.

    Raises:
        AssertionError: If ``model`` contains no Float8BlockwiseLinear module.
    """
    blockwise_layers = [
        layer for layer in model.modules() if isinstance(layer, Float8BlockwiseLinear)
    ]
    for layer in blockwise_layers:
        layer.use_triton = use_triton
    if not blockwise_layers:
        raise AssertionError("Expected at least one Float8BlockwiseLinear module")
77+
78+
79+
def broadcast_module(module: torch.nn.Module) -> None:
    """Call ``dist.broadcast`` with ``src=0`` on each parameter of ``module``.

    Only ``module.parameters()`` is visited; buffers are not touched.
    """
    for weight in module.parameters():
        dist.broadcast(weight, src=0)
82+
83+
84+
def init_toy_model(
    *,
    size: int = 128,
    seed: int = 42,
    device: str | torch.device = "cuda",
    broadcast_weights: bool = False,
) -> torch.nn.Module:
    """Build a seeded bfloat16 ``ToyModel`` on ``device``.

    Args:
        size: Argument passed to the ``ToyModel`` constructor.
        seed: Value given to ``torch.manual_seed`` before construction.
        device: Device the model is moved to.
        broadcast_weights: When true, the parameters are broadcast with
            ``broadcast_module`` after construction.

    Returns:
        The constructed model.
    """
    torch.manual_seed(seed)
    toy = ToyModel(size)
    toy = toy.to(device=device, dtype=torch.bfloat16)
    if not broadcast_weights:
        return toy
    broadcast_module(toy)
    return toy
96+
97+
98+
def make_quantized_toy_model_pair(
    *,
    size: int = 128,
    seed: int = 42,
    device: str | torch.device = "cuda",
    use_triton: bool,
    broadcast_weights: bool = False,
) -> tuple[torch.nn.Module, torch.nn.Module]:
    """Create two identically initialized, quantized toy models.

    The first model comes from ``init_toy_model``; the second is a deep copy
    taken before quantization. Each is then converted with
    ``Float8BlockwiseLinearConfig`` and has ``use_triton`` applied.

    Args:
        size: Forwarded to ``init_toy_model``.
        seed: Forwarded to ``init_toy_model``.
        device: Forwarded to ``init_toy_model``.
        use_triton: Flag applied to every converted blockwise linear module.
        broadcast_weights: Forwarded to ``init_toy_model``.

    Returns:
        ``(ref_model, dist_model)``.
    """
    reference = init_toy_model(
        size=size,
        seed=seed,
        device=device,
        broadcast_weights=broadcast_weights,
    )
    # Copy before quantizing so both models start from the same weights.
    candidate = copy.deepcopy(reference)
    pair = (reference, candidate)
    for member in pair:
        quantize_(member, Float8BlockwiseLinearConfig())
        set_blockwise_linear_use_triton(member, use_triton)
    return pair
117+
118+
119+
def get_replicated_local_batch(
120+
*,
121+
replica_count: int,
122+
replica_index: int,
123+
iter_idx: int,
124+
size: int = 128,
125+
device: str | torch.device = "cuda",
126+
) -> tuple[torch.Tensor, torch.Tensor]:
127+
"""Build one global batch and hand each replica its deterministic local slice.
128+
129+
TP peers should see the same sample, while different data-parallel replicas
130+
should see different samples. Broadcasting from rank 0 keeps the reference
131+
and distributed models aligned across all ranks.
132+
"""
133+
torch.manual_seed(100 + iter_idx)
134+
global_input = torch.randn(
135+
replica_count,
136+
1,
137+
size,
138+
size,
139+
device=device,
140+
dtype=torch.bfloat16,
141+
)
142+
global_target = torch.randn_like(global_input)
143+
dist.broadcast(global_input, src=0)
144+
dist.broadcast(global_target, src=0)
145+
return (
146+
global_input[replica_index].contiguous(),
147+
global_target[replica_index].contiguous(),
148+
)
149+
150+
151+
def assert_parameters_are_dtensors(parameters: Iterable[torch.Tensor]) -> None:
    """Assert that every item yielded by ``parameters`` is a DTensor.

    Iteration stops with an ``AssertionError`` at the first item that is not.
    """
    for candidate in parameters:
        is_dtensor = isinstance(candidate, DTensor)
        assert is_dtensor
154+
155+
156+
def allreduce_reference_grads(
    model: torch.nn.Module,
    *,
    world_size: int,
    group=None,
) -> None:
    """All-reduce each parameter gradient of ``model`` and divide by ``world_size``.

    Every parameter must already hold a gradient. The gradient is passed to
    ``dist.all_reduce`` and then divided in place.

    Args:
        model: Module whose parameter gradients are reduced.
        world_size: Divisor applied to each gradient after the all-reduce.
        group: Forwarded as ``group`` to ``dist.all_reduce``.
    """
    for weight in model.parameters():
        grad = weight.grad
        assert grad is not None
        dist.all_reduce(grad, group=group)
        grad.div_(world_size)
166+
167+
168+
def assert_dtensor_parameter_grads_match(
169+
ref_parameters: Iterable[torch.nn.Parameter],
170+
dist_parameters: Iterable[torch.nn.Parameter],
171+
) -> None:
172+
for ref_param, dist_param in zip(ref_parameters, dist_parameters, strict=True):
173+
assert ref_param.grad is not None
174+
assert dist_param.grad is not None
175+
assert isinstance(dist_param, DTensor)
176+
assert isinstance(dist_param.grad, DTensor)
177+
assert_close(dist_param.grad, ref_param.grad)
178+
179+
180+
def assert_dtensor_parameter_values_match(
181+
ref_parameters: Iterable[torch.nn.Parameter],
182+
dist_parameters: Iterable[torch.nn.Parameter],
183+
) -> None:
184+
for ref_param, dist_param in zip(ref_parameters, dist_parameters, strict=True):
185+
assert isinstance(dist_param, DTensor)
186+
assert_close(dist_param, ref_param)

0 commit comments

Comments
 (0)