Skip to content

Commit b23fff8

Browse files
Skip blockwise FP8 GEMM tests on ROCm due to numerical issues
Per reviewer feedback, skip the two GEMM tests on ROCm rather than using a heavily relaxed SQNR threshold (0.5 vs 28.0). The blockwise quantization kernel tests remain enabled on ROCm.
1 parent: b764ffb · commit: b23fff8

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

test/prototype/blockwise_fp8_training/test_blockwise_kernels.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@
4343
not (is_sm_at_least_90() or is_MI300() or is_MI350()),
4444
reason="Requires FP8-capable GPU (CUDA SM90+, MI300, or MI350)",
4545
)
46+
@pytest.mark.skipif(is_ROCM(), reason="Blockwise FP8 GEMM has numerical issues on ROCm")
4647
@pytest.mark.skipif(
4748
version.parse(triton.__version__) < version.parse("3.3.0"),
4849
reason="Triton version < 3.3.0, test skipped",
@@ -61,10 +62,7 @@ def test_triton_fp8_gemm_1x128_128x128(M, N, K, dtype):
6162
assert not C_q.isnan().any(), "C_q must not contain NaNs"
6263

6364
sqnr = compute_error(C, C_q)
64-
# e4m3fnuz (ROCm) has lower dynamic range (±240) than e4m3fn (CUDA, ±448),
65-
# causing worse quantization error for small-M shapes where errors don't
66-
# average out. Use a relaxed threshold on ROCm.
67-
min_sqnr = 0.5 if is_ROCM() else 28.0
65+
min_sqnr = 28.0
6866
assert sqnr >= min_sqnr, f"SQNR {sqnr:.2f} must be >= {min_sqnr}"
6967

7068

@@ -73,6 +71,7 @@ def test_triton_fp8_gemm_1x128_128x128(M, N, K, dtype):
7371
not (is_sm_at_least_90() or is_MI300() or is_MI350()),
7472
reason="Requires FP8-capable GPU (CUDA SM90+, MI300, or MI350)",
7573
)
74+
@pytest.mark.skipif(is_ROCM(), reason="Blockwise FP8 GEMM has numerical issues on ROCm")
7675
@pytest.mark.skipif(
7776
version.parse(triton.__version__) < version.parse("3.3.0"),
7877
reason="Triton version < 3.3.0, test skipped",

0 commit comments

Comments
 (0)