     is_MI350,
     is_sm_at_least_90,
     is_sm_version,
+    is_XPU,
     torch_version_at_least,
 )
 
-if not (
-    torch_version_at_least("2.7.0")
-    and torch.cuda.is_available()
-    and (is_sm_at_least_90() or is_MI300() or is_MI350())
-):
+_is_xpu = is_XPU()
+_is_compatible_cuda = torch.cuda.is_available() and (
+    is_sm_at_least_90() or is_MI300() or is_MI350()
+)
+if not ((_is_xpu or _is_compatible_cuda) and torch_version_at_least("2.7.0")):
     pytest.skip(
-        "Requires FP8-capable GPU (CUDA SM90+, MI300, or MI350)",
+        "Requires FP8-capable GPU (CUDA SM90+, MI300, MI350, or XPU) and PyTorch 2.7+",
         allow_module_level=True,
     )
 
 from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
 from torchao.quantization.quantize_.common import KernelPreference
-from torchao.testing.utils import skip_if_rocm
+from torchao.testing.utils import skip_if_rocm, skip_if_xpu
+from torchao.utils import get_available_devices
+
+_DEVICES = get_available_devices()[1:]
+
+
+@pytest.fixture(scope="module", params=_DEVICES)
+def device(request):
+    return request.param
+
 
 # Needed since changing args to function causes recompiles
 torch._dynamo.config.cache_size_limit = 1000
 
 
 @skip_if_rocm("ROCm not supported")
 @pytest.mark.skipif(
-    not is_sm_version(10, 0),
+    torch.cuda.is_available() and not is_sm_version(10, 0),
     reason="3D MXFP8 quantization requires SM100",
 )
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
     "scale_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
 )
 def test_emulate_mxfp8_grouped_gemm_2d_3d(
-    M, K, N, num_experts, scale_block_k, scale_mode
+    M, K, N, num_experts, scale_block_k, scale_mode, device
 ):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
-    offs = generate_jagged_offs(num_experts, M)
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
+    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device=device)
+    offs = generate_jagged_offs(num_experts, M, device=device)
     offs_ref = offs.clone()
 
     # Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
@@ -120,8 +130,9 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(
 
 
 @skip_if_rocm("ROCm not supported")
+@skip_if_xpu("XPU support not yet available")
 @pytest.mark.skipif(
-    not is_sm_version(10, 0),
+    torch.cuda.is_available() and not is_sm_version(10, 0),
     reason="3D MXFP8 quantization and MXFP8 grouped GEMM require SM100",
 )
 @pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@@ -130,10 +141,12 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(
 @pytest.mark.parametrize(
     "scale_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
 )
-def test_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts, scale_block_k, scale_mode):
-    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
-    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
-    offs = generate_jagged_offs(num_experts, M)
+def test_mxfp8_grouped_gemm_2d_3d(
+    M, K, N, num_experts, scale_block_k, scale_mode, device
+):
+    grad_out = torch.randn(M, N, dtype=torch.bfloat16, device=device)
+    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device=device)
+    offs = generate_jagged_offs(num_experts, M, device=device)
     offs_ref = offs.clone()
 
     # Real SM100 grouped MM: 1x32-scaled A @ 3D-quantized B -> grad_input.
@@ -193,13 +206,13 @@ def test_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts, scale_block_k, scale_mod
193206@pytest .mark .parametrize ("M" , (1024 , 4096 ))
194207@pytest .mark .parametrize ("N" , (1024 , 4096 ))
195208@pytest .mark .parametrize ("num_experts" , (8 , 16 ))
196- def test_emulate_mxfp8_grouped_gemm_2d_2d (M , N , num_experts ):
209+ def test_emulate_mxfp8_grouped_gemm_2d_2d (M , N , num_experts , device ):
197210 # Simluate 2d-2d grouped gemm grad_weight = grad_output_t @ x
198211 block_size = 32
199- grad_out = torch .randn (M , N , dtype = torch .bfloat16 , device = "cuda" )
212+ grad_out = torch .randn (M , N , dtype = torch .bfloat16 , device = device )
200213 grad_out_t = grad_out .t ().contiguous ()
201- x = torch .randn (M , N , dtype = torch .bfloat16 , device = "cuda" )
202- offs = generate_jagged_offs (num_experts , M , multiple_of = block_size )
214+ x = torch .randn (M , N , dtype = torch .bfloat16 , device = device )
215+ offs = generate_jagged_offs (num_experts , M , multiple_of = block_size , device = device )
203216 x_ref , grad_out_t_ref , offs_ref = x .clone (), grad_out_t .clone (), offs .clone ()
204217
205218 # bf16 reference grouped gemm
@@ -238,6 +251,7 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
 
 
 @skip_if_rocm("ROCm not supported")
+@skip_if_xpu("XPU support not yet available")
 @pytest.mark.parametrize("M,K,N", [(32768, 5120, 8192), (16640, 7168, 2048)])
 @pytest.mark.parametrize("num_experts", (1, 8))
 @pytest.mark.parametrize("wgrad_with_hp", (True, False))
@@ -257,9 +271,14 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
     use_compile,
     kernel_preference,
     scale_mode,
+    device,
 ):
     # MXFP8 hardware path requires SM100
-    if kernel_preference != KernelPreference.EMULATED and not is_sm_version(10, 0):
+    if (
+        torch.cuda.is_available()
+        and kernel_preference != KernelPreference.EMULATED
+        and not is_sm_version(10, 0)
+    ):
         pytest.skip(
             f"Skipping MXFP8 hardware mode tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
         )
@@ -268,17 +287,17 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
268287 "torch native dynamic per group pad/unpad functions do not work with torch.compile yet: https://github.com/pytorch/pytorch/issues/176770"
269288 )
270289
271- x = torch .randn (M , K , dtype = torch .bfloat16 , device = "cuda" , requires_grad = True )
290+ x = torch .randn (M , K , dtype = torch .bfloat16 , device = device , requires_grad = True )
272291 w = torch .randn (
273292 num_experts ,
274293 N ,
275294 K ,
276295 dtype = torch .bfloat16 ,
277- device = "cuda" ,
296+ device = device ,
278297 )
279298 w_t = w .transpose (- 2 , - 1 ).requires_grad_ (True )
280299
281- offs = generate_jagged_offs (num_experts , M , multiple_of = 128 )
300+ offs = generate_jagged_offs (num_experts , M , multiple_of = 128 , device = device )
282301 x_ref , w_t_ref , offs_ref = (
283302 x .clone ().detach ().requires_grad_ (True ),
284303 w_t .clone ().detach ().requires_grad_ (True ),
@@ -328,19 +347,19 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
 
 
 @skip_if_rocm("ROCm not supported")
-def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
+def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic(device):
     block_size = 32
     M, K, N, num_experts = 4096, 1024, 2048, 8
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)
     w = torch.randn(
         num_experts,
         N,
         K,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
     w_t = w.transpose(-2, -1).requires_grad_(True)
-    offs = generate_jagged_offs(num_experts, M, multiple_of=128)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)
 
     x_ref = x.detach().clone().requires_grad_(True)
     w_t_ref = w_t.detach().clone().requires_grad_(True)
@@ -401,19 +420,19 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
 
 
 @skip_if_rocm("ROCm not supported")
-def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
+def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward(device):
     block_size = 32
     M, K, N, num_experts = 4096, 1024, 2048, 8
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
     w = torch.randn(
         num_experts,
         N,
         K,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
     w_t = w.transpose(-2, -1)
-    offs = generate_jagged_offs(num_experts, M, multiple_of=128)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)
 
     x_scale, x_qdata = to_mx(
         x.detach(),
@@ -455,19 +474,19 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
 
 
 @skip_if_rocm("ROCm not supported")
-def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp():
+def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp(device):
     block_size = 32
     M, K, N, num_experts = 1024, 1024, 2048, 4
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
     w = torch.randn(
         num_experts,
         N,
         K,
         dtype=torch.bfloat16,
-        device="cuda",
+        device=device,
     )
     w_t = w.transpose(-2, -1)
-    offs = generate_jagged_offs(num_experts, M, multiple_of=128)
+    offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)
 
     x_scale, x_qdata = to_mx(
         x,