
Commit 8ef6d80

enablement of MoE fp8/mxfp8/nvfp4 tests

1 parent 721d69c
8 files changed: 178 additions & 97 deletions

test/prototype/moe_training/test_fp8_grouped_mm.py
Lines changed: 22 additions & 15 deletions
@@ -12,16 +12,18 @@
     is_MI350,
     is_sm_at_least_90,
     is_sm_version,
+    is_XPU,
     torch_version_at_least,
 )
 
-if not (
-    torch_version_at_least("2.7.0")
-    and torch.cuda.is_available()
-    and (is_sm_at_least_90() or is_MI300() or is_MI350())
-):
+_is_xpu_available = is_XPU()
+_is_compatible_cuda = torch.cuda.is_available() and (
+    is_sm_at_least_90() or is_MI300() or is_MI350()
+)
+
+if not ((_is_xpu_available or _is_compatible_cuda) and torch_version_at_least("2.7.0")):
     pytest.skip(
-        "Requires FP8-capable GPU (CUDA SM90+, MI300, or MI350)",
+        "Requires FP8-capable GPU (CUDA SM90+, MI300, MI350 or XPU)",
         allow_module_level=True,
     )
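Note on the hunk above: the rework splits the platform check out of the version check so an XPU host can pass without CUDA. A minimal sketch of the decision logic, with the torchao helpers reduced to plain booleans (the truth-table behavior is read off the diff, not verified against torchao):

    # Sketch only, not part of the commit: the module skip gate as booleans.
    def old_gate(torch_27: bool, cuda_ok: bool, xpu: bool) -> bool:
        # pre-commit: a compatible CUDA/ROCm GPU was mandatory
        return torch_27 and cuda_ok


    def new_gate(torch_27: bool, cuda_ok: bool, xpu: bool) -> bool:
        # post-commit: either backend qualifies; PyTorch 2.7+ still required
        return (xpu or cuda_ok) and torch_27


    # An XPU-only host was skipped before this commit and runs after it.
    assert old_gate(torch_27=True, cuda_ok=False, xpu=True) is False
    assert new_gate(torch_27=True, cuda_ok=False, xpu=True) is True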

@@ -41,11 +43,18 @@
 from torchao.prototype.moe_training.fp8_grouped_mm import (
     _to_fp8_rowwise_then_scaled_grouped_mm,
 )
-from torchao.utils import is_MI300, is_MI350, is_ROCM
+from torchao.utils import get_available_devices, is_MI300, is_MI350, is_ROCM
 
 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 1000
 
+_DEVICES = get_available_devices()[1:]
+
+
+@pytest.fixture(scope="module", params=_DEVICES)
+def device(request):
+    return request.param
+
 
 @pytest.mark.skipif(
     True,
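Note on the fixture above: the [1:] slice relies on get_available_devices() listing "cpu" first, followed by whichever accelerators are present (e.g. ["cpu", "cuda"] or ["cpu", "xpu"]); that ordering is an assumption here. A hedged sketch of how the module-scoped fixture fans tests out per accelerator:

    # Sketch only, not part of the commit; the literal list stands in for
    # get_available_devices() on a hypothetical host with both backends.
    import pytest

    _DEVICES = ["cpu", "cuda", "xpu"][1:]  # -> ["cuda", "xpu"]; "cpu" dropped


    @pytest.fixture(scope="module", params=_DEVICES)
    def device(request):
        # pytest collects each test requesting this fixture once per param,
        # producing ids like test_foo[cuda] and test_foo[xpu].
        return request.param


    def test_device_is_an_accelerator(device):
        assert device in ("cuda", "xpu")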
@@ -55,7 +64,7 @@
 @pytest.mark.parametrize("n", [8192])
 @pytest.mark.parametrize("k", [5120])
 @pytest.mark.parametrize("n_groups", [1, 2, 4, 8])
-def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
+def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups, device):
     if is_ROCM():
         if not (is_MI300() or is_MI350()):
             pytest.skip("FP8 rowwise test requires MI300 or MI350 on ROCm")
@@ -64,7 +73,6 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
         pytest.skip("FP8 rowwise test requires SM 9.0 on CUDA")
 
     out_dtype = torch.bfloat16
-    device = "cuda"
     a = torch.randn(
         m * n_groups,
         k,
@@ -79,7 +87,7 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
         device=device,
         dtype=torch.bfloat16,
     )
-    offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+    offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)
 
     # b must be transposed and in column major format.
     b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)
@@ -91,7 +99,7 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
         b_t,
         offs=offs,
         out_dtype=config.out_dtype,
-        float8_dtyep=config.float8_dtype,
+        float8_dtype=config.float8_dtype,
     )
 
     # Validate result.
@@ -111,7 +119,7 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
     ref_out.sum().backward()
 
     # Validate gradients.
-    if is_ROCM():
+    if is_ROCM() or _is_xpu_available:
         # ROCm: reference vs tested path use different backends:
         # - `torch._scaled_mm` uses hipBLASLt
         # - `_to_fp8_rowwise_then_scaled_grouped_mm` uses CK
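Note on the hunk above: the looser-tolerance branch now also covers XPU, where (as on ROCm) the reference matmul and the grouped path presumably run on different backend kernels, so bitwise-equal gradients cannot be expected. A sketch of what a relaxed comparison looks like; the tolerances are illustrative, not the test's actual numbers:

    # Sketch only, not part of the commit; rtol/atol values are illustrative.
    import torch

    grad_ref = torch.tensor([1.0000, 2.0000])
    grad_tested = torch.tensor([1.0010, 1.9985])  # small backend-dependent drift
    torch.testing.assert_close(grad_tested, grad_ref, rtol=1e-2, atol=1e-2)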
@@ -130,15 +138,14 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
 @pytest.mark.parametrize("m", [16, 17])
 @pytest.mark.parametrize("k", [16, 18])
 @pytest.mark.parametrize("n", [32, 33])
-def test_K_or_N_dim_not_multiple_of_16(m, n, k):
+def test_K_or_N_dim_not_multiple_of_16(m, n, k, device):
     # - Leading dim of A doesn't have to be divisible by 16, since it will be
     #   divided up into groups based on offset anyway.
     # - Trailing dim of A must be divisible by 16.
     # - Leading dim of B (n_groups) doesn't need to be divisible by 16.
     # - Last 2 dims of B must be divisible by 16.
     if n % 16 == 0 and k % 16 == 0:
         return
-    device = "cuda"
     n_groups = 4
     a = torch.randn(
         m * n_groups,
@@ -161,7 +168,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
     b_t = b_t.transpose(-2, -1).contiguous().transpose(-2, -1)
 
     config = Float8TrainingOpConfig.from_recipe(Float8TrainingRecipe.FP8_ROWWISE)
-    offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
+    offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)
 
     # Compute output.
     with pytest.raises(AssertionError):

test/prototype/moe_training/test_mxfp8_grouped_mm.py
Lines changed: 56 additions & 37 deletions
@@ -14,16 +14,17 @@
     is_MI350,
     is_sm_at_least_90,
     is_sm_version,
+    is_XPU,
     torch_version_at_least,
 )
 
-if not (
-    torch_version_at_least("2.7.0")
-    and torch.cuda.is_available()
-    and (is_sm_at_least_90() or is_MI300() or is_MI350())
-):
+_is_xpu = is_XPU()
+_is_compatible_cuda = torch.cuda.is_available() and (
+    is_sm_at_least_90() or is_MI300() or is_MI350()
+)
+if not ((_is_xpu or _is_compatible_cuda) and torch_version_at_least("2.7.0")):
     pytest.skip(
-        "Requires FP8-capable GPU (CUDA SM90+, MI300, or MI350)",
+        "Requires FP8-capable GPU (CUDA SM90+, MI300, MI350, or XPU) and PyTorch 2.7+",
         allow_module_level=True,
     )

@@ -47,15 +48,24 @@
 from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
 from torchao.quantization.quantize_.common import KernelPreference
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import skip_if_rocm, skip_if_xpu
+from torchao.utils import get_available_devices
+
+_DEVICES = get_available_devices()[1:]
+
+
+@pytest.fixture(scope="module", params=_DEVICES)
+def device(request):
+    return request.param
+
 
 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 1000
 
 
 @skip_if_rocm("ROCm not supported")
 @pytest.mark.skipif(
-    not is_sm_version(10, 0),
+    torch.cuda.is_available() and not is_sm_version(10, 0),
     reason="3D MXFP8 quantization requires SM100",
 )
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
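Note on the skipif above: is_sm_version(10, 0) reports a CUDA compute capability, which is meaningless on an XPU-only host, so the unguarded `not is_sm_version(10, 0)` would have skipped a test the XPU path should run (assuming the helper simply returns False off-CUDA). A small sketch of the guarded condition:

    # Sketch only, not part of the commit: the skipif condition in isolation.
    def should_skip(cuda_available: bool, is_sm100: bool) -> bool:
        # Short-circuits to False on non-CUDA hosts, so the SM100 requirement
        # is enforced only when a CUDA device is actually in use.
        return cuda_available and not is_sm100


    assert should_skip(cuda_available=False, is_sm100=False) is False  # XPU host: runs
    assert should_skip(cuda_available=True, is_sm100=False) is True    # e.g. SM90: skipped
    assert should_skip(cuda_available=True, is_sm100=True) is False    # SM100: runs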
@@ -65,11 +75,11 @@
     "scale_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
 )
 def test_emulate_mxfp8_grouped_gemm_2d_3d(
-    M, K, N, num_experts, scale_block_k, scale_mode
+    M, K, N, num_experts, scale_block_k, scale_mode, device
 ):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
-    offs = generate_jagged_offs(num_experts, M)
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device=device)
+    offs = generate_jagged_offs(num_experts, M, device=device)
     offs_ref = offs.clone()
 
     # Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
# Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
@@ -120,8 +130,9 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(
 
 
 @skip_if_rocm("ROCm not supported")
+@skip_if_xpu("XPU support not yet available")
 @pytest.mark.skipif(
-    not is_sm_version(10, 0),
+    torch.cuda.is_available() and not is_sm_version(10, 0),
     reason="3D MXFP8 quantization and MXFP8 grouped GEMM require SM100",
 )
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@@ -130,10 +141,12 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(
 @pytest.mark.parametrize(
     "scale_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
 )
-def test_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts, scale_block_k, scale_mode):
-    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
-    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
-    offs = generate_jagged_offs(num_experts, M)
+def test_mxfp8_grouped_gemm_2d_3d(
+    M, K, N, num_experts, scale_block_k, scale_mode, device
+):
+    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device=device)
+    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device=device)
+    offs = generate_jagged_offs(num_experts, M, device=device)
     offs_ref = offs.clone()
 
     # Real SM100 grouped MM: 1x32-scaled A @ 3D-quantized B -> grad_input.
# Real SM100 grouped MM: 1x32-scaled A @ 3D-quantized B -> grad_input.
@@ -193,13 +206,13 @@ def test_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts, scale_block_k, scale_mod
 @pytest.mark.parametrize("M", (1024, 4096))
 @pytest.mark.parametrize("N", (1024, 4096))
 @pytest.mark.parametrize("num_experts", (8, 16))
-def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
+def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts, device):
     # Simluate 2d-2d grouped gemm grad_weight = grad_output_t @ x
     block_size = 32
-    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
+    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device=device)
     grad_out_t = grad_out.t().contiguous()
-    x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
-    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
+    x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size, device=device)
     x_ref, grad_out_t_ref, offs_ref = x.clone(), grad_out_t.clone(), offs.clone()
 
     # bf16 reference grouped gemm
@@ -238,6 +251,7 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts, device):
 
 
 @skip_if_rocm("ROCm not supported")
+@skip_if_xpu("XPU support not yet available")
 @pytest.mark.parametrize("M,K,N", [(32768, 5120, 8192), (16640, 7168, 2048)])
 @pytest.mark.parametrize("num_experts", (1, 8))
 @pytest.mark.parametrize("wgrad_with_hp", (True, False))
@@ -257,9 +271,14 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
     use_compile,
     kernel_preference,
     scale_mode,
+    device,
 ):
     # MXFP8 hardware path requires SM100
-    if kernel_preference != KernelPreference.EMULATED and not is_sm_version(10, 0):
+    if (
+        torch.cuda.is_available()
+        and kernel_preference != KernelPreference.EMULATED
+        and not is_sm_version(10, 0)
+    ):
         pytest.skip(
             f"Skipping MXFP8 hardware mode tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
         )
@@ -268,17 +287,17 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
             "torch native dynamic per group pad/unpad functions do not work with torch.compile yet: https://github.com/pytorch/pytorch/issues/176770"
         )
 
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)
     w = torch.randn(
         num_experts,
         N,
         K,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
     w_t = w.transpose(-2, -1).requires_grad_(True)
 
-    offs = generate_jagged_offs(num_experts, M, multiple_of=128)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)
     x_ref, w_t_ref, offs_ref = (
         x.clone().detach().requires_grad_(True),
         w_t.clone().detach().requires_grad_(True),
@@ -328,19 +347,19 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
 
 
 @skip_if_rocm("ROCm not supported")
-def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
+def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic(device):
     block_size = 32
     M, K, N, num_experts = 4096, 1024, 2048, 8
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)
     w = torch.randn(
         num_experts,
         N,
         K,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
     w_t = w.transpose(-2, -1).requires_grad_(True)
-    offs = generate_jagged_offs(num_experts, M, multiple_of=128)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)
 
     x_ref = x.detach().clone().requires_grad_(True)
     w_t_ref = w_t.detach().clone().requires_grad_(True)
@@ -401,19 +420,19 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
 
 
 @skip_if_rocm("ROCm not supported")
-def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
+def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward(device):
     block_size = 32
     M, K, N, num_experts = 4096, 1024, 2048, 8
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
     w = torch.randn(
         num_experts,
         N,
         K,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
     w_t = w.transpose(-2, -1)
-    offs = generate_jagged_offs(num_experts, M, multiple_of=128)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)
 
     x_scale, x_qdata = to_mx(
         x.detach(),
x.detach(),
@@ -455,19 +474,19 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
 
 
 @skip_if_rocm("ROCm not supported")
-def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp():
+def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp(device):
     block_size = 32
     M, K, N, num_experts = 1024, 1024, 2048, 4
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
     w = torch.randn(
         num_experts,
         N,
         K,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
     w_t = w.transpose(-2, -1)
-    offs = generate_jagged_offs(num_experts, M, multiple_of=128)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)
 
     x_scale, x_qdata = to_mx(
         x,
