Commit 4e18d87

rocm: scaled_grouped_mm support gfx942 fp8 data type (#3955)
1 parent 8d65522 commit 4e18d87

8 files changed: 82 additions & 44 deletions
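
Background for this change (a property of the hardware, not stated in the commit message): gfx942 (the MI300 series) implements the "fnuz" float8 variants rather than the OCP e4m3fn/e5m2 formats used on NVIDIA SM90. The fnuz formats have no inf, a single NaN encoding, no negative zero, and a shifted exponent bias, which halves the maximum representable value. A minimal sketch comparing the two e4m3 variants:

import torch

# e4m3fn (OCP, used on H100/SM90) vs e4m3fnuz (used on ROCm gfx942).
for dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
    info = torch.finfo(dtype)
    print(f"{dtype}: max={info.max}")
# torch.float8_e4m3fn   -> max=448.0
# torch.float8_e4m3fnuz -> max=240.0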

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 14 additions & 8 deletions
@@ -27,6 +27,7 @@
     MXFP8GroupedMMRecipe,
 )
 from torchao.prototype.moe_training.utils import generate_jagged_offs
+from torchao.utils import is_MI300, is_MI350, is_ROCM
 
 device = torch.device("cuda")
 
@@ -260,14 +261,19 @@ def main(args: argparse.Namespace):
     configs = get_configs()
     results = []
     for config in tqdm(configs):
-        if (
-            config.recipe == FP8GroupedMMRecipe.FP8_ROWWISE
-            and torch.cuda.get_device_capability() != (9, 0)
-        ):
-            logging.warning(
-                f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
-            )
-            continue
+        if config.recipe == FP8GroupedMMRecipe.FP8_ROWWISE:
+            if is_ROCM():
+                if not (is_MI300() or is_MI350()):
+                    logging.warning(
+                        "Skipping FP8 rowwise benchmarks, requires MI300 or MI350 on ROCm"
+                    )
+                    continue
+            else:
+                if torch.cuda.get_device_capability() != (9, 0):
+                    logging.warning(
+                        f"Skipping FP8 rowwise benchmarks, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
+                    )
+                    continue
 
         elif config.recipe in (
             MXFP8GroupedMMRecipe.MXFP8_RCEIL,

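The same ROCm-vs-CUDA gate recurs in the two test files below. A consolidated check could look like the following sketch; fp8_rowwise_supported is a hypothetical helper, not part of this PR:

import torch
from torchao.utils import is_MI300, is_MI350, is_ROCM

def fp8_rowwise_supported() -> bool:
    # FP8 rowwise grouped GEMM: MI300/MI350 on ROCm, compute capability 9.0 (H100) on CUDA.
    if is_ROCM():
        return is_MI300() or is_MI350()
    return torch.cuda.get_device_capability() == (9, 0)
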
test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 33 additions & 11 deletions
@@ -30,6 +30,8 @@
 from torchao.float8.float8_training_tensor import LinearMMConfig
 from torchao.float8.float8_utils import compute_error, tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.config import (
+    FP8GroupedMMConfig,
+    FP8GroupedMMRecipe,
     MXFP8GroupedMMConfig,
     MXFP8GroupedMMRecipe,
 )
@@ -47,6 +49,7 @@
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.testing.utils import skip_if_rocm
+from torchao.utils import is_MI300, is_MI350, is_ROCM
 
 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 1000
@@ -56,14 +59,18 @@
     True,
     reason="Skipping FP8 rowwise test pending fix for https://github.com/pytorch/ao/issues/3788",
 )
-@skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize("m", [4096])
 @pytest.mark.parametrize("n", [8192])
 @pytest.mark.parametrize("k", [5120])
 @pytest.mark.parametrize("n_groups", [1, 2, 4, 8])
 def test_valid_scaled_grouped_mm_2d_3d(m, n, k, n_groups):
-    if not is_sm_version(9, 0):
-        pytest.skip("Skipping FP8 rowwise test, requires sm90")
+    if is_ROCM():
+        if not (is_MI300() or is_MI350()):
+            pytest.skip("FP8 rowwise test requires MI300 or MI350 on ROCm")
+    else:
+        if not is_sm_version(9, 0):
+            pytest.skip("FP8 rowwise test requires SM 9.0 on CUDA")
+
     out_dtype = torch.bfloat16
     device = "cuda"
     a = torch.randn(
@@ -86,7 +93,7 @@ def test_valid_scaled_grouped_mm_2d_3d(m, n, k, n_groups):
     b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)
 
     # Compute output.
-    config = MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL)
+    config = FP8GroupedMMConfig.from_recipe(FP8GroupedMMRecipe.FP8_ROWWISE)
     out = _quantize_then_scaled_grouped_mm(
         a,
         b_t,
@@ -105,15 +112,26 @@ def test_valid_scaled_grouped_mm_2d_3d(m, n, k, n_groups):
         out_dtype,
         offs,
     )
-    assert torch.equal(out, ref_out)
 
     # Run backward pass.
     out.sum().backward()
     ref_out.sum().backward()
 
     # Validate gradients.
-    assert torch.equal(a.grad, ref_a.grad)
-    assert torch.equal(b_t.grad, ref_b_t.grad)
+    if is_ROCM():
+        # ROCm: reference vs tested path use different backends:
+        # - `torch._scaled_mm` uses hipBLASLt
+        # - `_quantize_then_scaled_grouped_mm` uses CK
+        # Different backends can use different kernel implementations / accumulation order, so the
+        # outputs can differ slightly and we need tolerance.
+        # On MI300/MI325 we need rtol=atol=1e-2 for this FP8 test to pass.
+        assert torch.allclose(out, ref_out, rtol=1e-2, atol=1e-2)
+        assert torch.allclose(a.grad, ref_a.grad, rtol=1e-2, atol=1e-2)
+        assert torch.allclose(b_t.grad, ref_b_t.grad, rtol=1e-2, atol=1e-2)
+    else:
+        assert torch.equal(out, ref_out)
+        assert torch.equal(a.grad, ref_a.grad)
+        assert torch.equal(b_t.grad, ref_b_t.grad)
 
 
 @skip_if_rocm("ROCm not supported")
@@ -180,7 +198,7 @@ def compute_reference_forward(
         round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
     A_scaled = A.to(torch.float32) * A_scales
-    A_fp8 = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+    A_fp8 = to_fp8_saturated(A_scaled, float8_config.cast_config_input.target_dtype)
 
     # Convert B^t to fp8.
     B_t_scales = tensor_to_scale(
@@ -193,7 +211,7 @@ def compute_reference_forward(
     B_t_scaled = B_t.to(torch.float32) * B_t_scales
     B_t_fp8 = to_fp8_saturated(
         B_t_scaled,
-        torch.float8_e4m3fn,
+        float8_config.cast_config_input.target_dtype,
     )
 
     # Split A and result into chunks, one for each group.
@@ -231,8 +249,12 @@ def compute_reference_forward(
             LinearMMConfig(),
             float8_config,
         )
-        assert torch.equal(result1, ref_group_result1)
-        assert torch.equal(result2, ref_group_result2)
+        if is_ROCM():
+            assert torch.allclose(result1, ref_group_result1, rtol=1e-2, atol=1e-2)
+            assert torch.allclose(result2, ref_group_result2, rtol=1e-2, atol=1e-2)
+        else:
+            assert torch.equal(result1, ref_group_result1)
+            assert torch.equal(result2, ref_group_result2)
         outputs.append(ref_group_result2)
 
     # Concatenate the outputs and verify the full result is correct.

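The tolerance switch above is the behavioral heart of this file's change: on CUDA the test keeps bitwise torch.equal, while on ROCm the reference path (torch._scaled_mm via hipBLASLt) and the tested path (CK) may accumulate in different orders. For reference, torch.allclose(a, b, rtol, atol) checks elementwise that |a - b| <= atol + rtol * |b|, so rtol=atol=1e-2 allows roughly percent-level relative error plus a small absolute floor:

import torch

a = torch.tensor([100.0, 0.001])
b = torch.tensor([100.9, 0.009])
# |a - b| <= atol + rtol * |b|: 0.9 <= 0.01 + 1.009 and 0.008 <= 0.01 + 0.00009
print(torch.allclose(a, b, rtol=1e-2, atol=1e-2))  # True
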
test/prototype/moe_training/test_training.py

Lines changed: 11 additions & 12 deletions
@@ -20,6 +20,7 @@
 )
 from torchao.quantization.quant_api import quantize_
 from torchao.quantization.quantize_.common import KernelPreference
+from torchao.utils import is_MI300, is_MI350, is_ROCM
 
 # Reference MoE implementation (copied from torchtitan to avoid external dependency)
 from .reference_moe import MoE, MoEArgs, set_token_group_alignment_size_m
@@ -97,18 +98,16 @@ def test_moe_training(
             "Skipping compile=True with kernel_preference=EMULATED, not currently supported"
         )
 
-    # FP8_ROWWISE hardware path requires SM90
-    if (
-        recipe == FP8GroupedMMRecipe.FP8_ROWWISE
-        and torch.cuda.get_device_capability()
-        != (
-            9,
-            0,
-        )
-    ):
-        pytest.skip(
-            f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
-        )
+    # FP8_ROWWISE hardware path requires SM90 (CUDA) or MI300/MI350 (ROCm)
+    if recipe == FP8GroupedMMRecipe.FP8_ROWWISE:
+        if is_ROCM():
+            if not (is_MI300() or is_MI350()):
+                pytest.skip("FP8 rowwise test requires MI300 or MI350 on ROCm")
+        else:
+            if torch.cuda.get_device_capability() != (9, 0):
+                pytest.skip(
+                    f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
+                )
 
     # MXFP8 hardware path requires SM100
     if recipe in (

torchao/prototype/moe_training/config.py

Lines changed: 5 additions & 1 deletion
@@ -15,7 +15,7 @@
 from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.transform_module import register_quantize_module_handler
-from torchao.utils import register_as_pytree_constant
+from torchao.utils import is_MI300, register_as_pytree_constant
 
 
 class FP8GroupedMMRecipe(Enum):
@@ -45,6 +45,10 @@ class FP8GroupedMMConfig(GroupedMMConfig):
     Configuration for FP8 grouped matrix multiplication.
     """
 
+    # Float8 dtype for the FP8 grouped GEMMs.
+    float8_dtype: torch.dtype = (
+        torch.float8_e4m3fnuz if is_MI300() else torch.float8_e4m3fn
+    )
     # Output dtype for the FP8 grouped GEMMs.
     out_dtype: Optional[torch.dtype] = torch.bfloat16

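With the new field, a config built on an MI300 machine selects the fnuz dtype with no caller changes; from_recipe is the same constructor the tests above use. Note the default is evaluated once at class-definition time (it is a field default), which is fine since the GPU architecture cannot change within a process. A sketch; the printed value depends on the machine:

from torchao.prototype.moe_training.config import (
    FP8GroupedMMConfig,
    FP8GroupedMMRecipe,
)

config = FP8GroupedMMConfig.from_recipe(FP8GroupedMMRecipe.FP8_ROWWISE)
# torch.float8_e4m3fnuz on MI300 (gfx942), torch.float8_e4m3fn elsewhere.
print(config.float8_dtype)
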
torchao/prototype/moe_training/fp8_grouped_mm.py

Lines changed: 14 additions & 12 deletions
@@ -22,6 +22,7 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
     B_t: torch.Tensor,
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
+    float8_dtype: torch.dtype = torch.float8_e4m3fn,
 ) -> torch.Tensor:
     """
     Differentiable FP8 grouped matrix multiplication with dynamic FP8 rowwise quantization.
@@ -48,7 +49,7 @@
     - Scales are computed per-row and rounded to powers of 2 for efficiency
     - This function is fully differentiable via custom autograd implementation
     """
-    return _Float8GroupedMM.apply(A, B_t, offs, out_dtype)
+    return _Float8GroupedMM.apply(A, B_t, offs, out_dtype, float8_dtype)
 
 
 class _Float8GroupedMM(torch.autograd.Function):
@@ -61,6 +62,7 @@ def forward(
         B_t: torch.Tensor,
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
+        float8_dtype: torch.dtype = torch.float8_e4m3fn,
     ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D|3D and B=3D.
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
@@ -100,31 +102,32 @@ def forward(
         # A_scales shape: (M,1) or (B, M, 1)
         A_scales = tensor_to_scale(
             A,
-            torch.float8_e4m3fn,
+            float8_dtype,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=-1,
             round_scales_to_power_of_2=True,
         )
         A_scaled = A.to(torch.float32) * A_scales
-        A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+        A_data_row_major = to_fp8_saturated(A_scaled, float8_dtype)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
         # B_t shape: (E, K, N)
         # B_t scales must be computed rowwise keeping the outer/final dim, so:
         # B_t_scales shape: (E, 1, N)
         B_t_scales = tensor_to_scale(
             B_t,
-            torch.float8_e4m3fn,
+            float8_dtype,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=-2,
             round_scales_to_power_of_2=True,
         )
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
-        B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
+        B_t_data_col_major = to_fp8_saturated(B_t_scaled, float8_dtype)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
         ctx.out_dtype = out_dtype
+        ctx.float8_dtype = float8_dtype
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
@@ -154,6 +157,7 @@
     def backward(ctx, grad_output: torch.Tensor):
         A, B_t, offs = ctx.saved_tensors
         out_dtype = ctx.out_dtype
+        float8_dtype = ctx.float8_dtype
 
         # Convert grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_A: grad_output @ B
@@ -162,21 +166,19 @@ def backward(ctx, grad_output: torch.Tensor):
         # grad_output_scale shape: (Mg, 1)
         grad_output_scales = tensor_to_scale(
             grad_output,
-            torch.float8_e4m3fn,
+            float8_dtype,
             scaling_granularity=ScalingGranularity.AXISWISE,
             axiswise_dim=-1,
             round_scales_to_power_of_2=True,
         )
         grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
-        grad_output_data_row_major = to_fp8_saturated(
-            grad_output_scaled, torch.float8_e4m3fn
-        )
+        grad_output_data_row_major = to_fp8_saturated(grad_output_scaled, float8_dtype)
 
         # Compute B fp8 column-major for right operand of grouped GEMM:
         # grad_A = grad_output @ B.
         B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
             B_t._data if hasattr(B_t, "_data") else B_t,
-            output_dtype=torch.float8_e4m3fn,
+            output_dtype=float8_dtype,
             round_scales_to_power_of_2=True,
         )
 
@@ -216,7 +218,7 @@ def backward(ctx, grad_output: torch.Tensor):
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
-            torch.float8_e4m3fn,
+            float8_dtype,
             round_scales_to_power_of_2=True,
         )
         grad_output_t_data_row_major = grad_out_data_colwise.t()
@@ -227,7 +229,7 @@ def backward(ctx, grad_output: torch.Tensor):
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
             offs,
-            torch.float8_e4m3fn,
+            float8_dtype,
             round_scales_to_power_of_2=True,
         )

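End-to-end, the new parameter threads from the public wrapper into both forward and backward quantization sites. A sketch with illustrative shapes (sizes, group count, and offsets are made up for the example; on gfx942, config.float8_dtype now supplies the fnuz dtype shown here automatically):

import torch
from torchao.prototype.moe_training.fp8_grouped_mm import (
    _to_fp8_rowwise_then_scaled_grouped_mm,
)

# A: (M, K) token activations; B_t: (E, K, N) per-expert weights, transposed.
A = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
B_t = torch.randn(2, 256, 512, device="cuda", dtype=torch.bfloat16, requires_grad=True)
offs = torch.tensor([64, 128], device="cuda", dtype=torch.int32)  # row offsets per group

out = _to_fp8_rowwise_then_scaled_grouped_mm(
    A,
    B_t,
    offs,
    out_dtype=torch.bfloat16,
    float8_dtype=torch.float8_e4m3fnuz,  # fnuz on gfx942; torch.float8_e4m3fn on SM90
)
out.sum().backward()  # differentiable via the custom autograd.Function above
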
torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 2 additions & 0 deletions
@@ -24,7 +24,9 @@
     torch.int32: tl.int32,
     torch.int64: tl.int64,
     torch.float8_e4m3fn: tl.float8e4nv,
+    torch.float8_e4m3fnuz: tl.float8e4b8,
     torch.float8_e5m2: tl.float8e5,
+    torch.float8_e5m2fnuz: tl.float8e5b16,
     torch.float16: tl.float16,
     torch.bfloat16: tl.bfloat16,
     torch.float32: tl.float32,

torchao/prototype/moe_training/kernels/jagged_float8_scales.py

Lines changed: 2 additions & 0 deletions
@@ -29,7 +29,9 @@
     torch.int32: tl.int32,
     torch.int64: tl.int64,
     torch.float8_e4m3fn: tl.float8e4nv,
+    torch.float8_e4m3fnuz: tl.float8e4b8,
     torch.float8_e5m2: tl.float8e5,
+    torch.float8_e5m2fnuz: tl.float8e5b16,
     torch.float16: tl.float16,
     torch.bfloat16: tl.bfloat16,
     torch.float32: tl.float32,

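Both kernel files extend the same torch-to-Triton dtype table. Triton's names encode the exponent bias: tl.float8e4b8 is e4m3 with bias 8 and tl.float8e5b16 is e5m2 with bias 16 (the fnuz variants used on gfx942), versus tl.float8e4nv for the OCP/NVIDIA e4m3. A sketch of the lookup; the dict name is hypothetical since the real table's name sits outside the hunks shown:

import torch
import triton.language as tl

TORCH_TO_TRITON_DTYPE = {  # hypothetical name for the table extended above
    torch.float8_e4m3fn: tl.float8e4nv,     # OCP e4m3 (NVIDIA)
    torch.float8_e4m3fnuz: tl.float8e4b8,   # fnuz e4m3, exponent bias 8 (gfx942)
    torch.float8_e5m2: tl.float8e5,
    torch.float8_e5m2fnuz: tl.float8e5b16,  # fnuz e5m2, exponent bias 16 (gfx942)
}

tl_dtype = TORCH_TO_TRITON_DTYPE[torch.float8_e4m3fnuz]
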
torchao/prototype/moe_training/tensor.py

Lines changed: 1 addition & 0 deletions
@@ -264,6 +264,7 @@ def _quantize_then_scaled_grouped_mm(
             B_t,
             offs,
             config.out_dtype,
+            config.float8_dtype,
         )
     elif isinstance(config, MXFP8GroupedMMConfig):
         return _to_mxfp8_then_scaled_grouped_mm(

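Finally, _quantize_then_scaled_grouped_mm forwards config.float8_dtype into the rowwise path, so the module-swap flow picks the platform-appropriate format. A hedged sketch of that flow, assuming the standard torchao quantize_ filter_fn convention; the model and name filter are illustrative:

import torch.nn as nn
from torchao.quantization.quant_api import quantize_
from torchao.prototype.moe_training.config import (
    FP8GroupedMMConfig,
    FP8GroupedMMRecipe,
)

model = nn.Sequential(nn.Linear(8, 8))  # stand-in; real usage targets MoE expert weights
config = FP8GroupedMMConfig.from_recipe(FP8GroupedMMRecipe.FP8_ROWWISE)
# Convert only modules whose fully-qualified name marks them as experts.
quantize_(model, config, filter_fn=lambda mod, fqn: "experts" in fqn)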