37 changes: 22 additions & 15 deletions test/prototype/moe_training/test_fp8_grouped_mm.py
@@ -12,16 +12,18 @@
is_MI350,
is_sm_at_least_90,
is_sm_version,
is_XPU,
torch_version_at_least,
)

if not (
torch_version_at_least("2.7.0")
and torch.cuda.is_available()
and (is_sm_at_least_90() or is_MI300() or is_MI350())
):
_is_xpu_available = is_XPU()
_is_compatible_cuda = torch.cuda.is_available() and (
is_sm_at_least_90() or is_MI300() or is_MI350()
)

if not ((_is_xpu_available or _is_compatible_cuda) and torch_version_at_least("2.7.0")):
pytest.skip(
"Requires FP8-capable GPU (CUDA SM90+, MI300, or MI350)",
"Requires FP8-capable GPU (CUDA SM90+, MI300, MI350 or XPU)",
allow_module_level=True,
)

@@ -41,11 +43,18 @@
from torchao.prototype.moe_training.fp8_grouped_mm import (
_to_fp8_rowwise_then_scaled_grouped_mm,
)
from torchao.utils import is_MI300, is_MI350, is_ROCM
from torchao.utils import get_available_devices, is_MI300, is_MI350, is_ROCM

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000

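# get_available_devices() lists "cpu" first; drop it so the tests run only on accelerators.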
_DEVICES = get_available_devices()[1:]


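# Module-scoped fixture that parametrizes each test over every available accelerator.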
@pytest.fixture(scope="module", params=_DEVICES)
def device(request):
return request.param


@pytest.mark.skipif(
True,
@@ -55,7 +64,7 @@
@pytest.mark.parametrize("n", [8192])
@pytest.mark.parametrize("k", [5120])
@pytest.mark.parametrize("n_groups", [1, 2, 4, 8])
def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups, device):
if is_ROCM():
if not (is_MI300() or is_MI350()):
pytest.skip("FP8 rowwise test requires MI300 or MI350 on ROCm")
@@ -64,7 +73,6 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
pytest.skip("FP8 rowwise test requires SM 9.0 on CUDA")

out_dtype = torch.bfloat16
device = "cuda"
a = torch.randn(
m * n_groups,
k,
@@ -79,7 +87,7 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
device=device,
dtype=torch.bfloat16,
)
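# Cumulative end offsets of the n_groups equal-sized groups along dim 0 of a.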
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)

# b must be transposed and in column major format.
b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)
@@ -91,7 +99,7 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
b_t,
offs=offs,
out_dtype=config.out_dtype,
float8_dtyep=config.float8_dtype,
float8_dtype=config.float8_dtype,
)

# Validate result.
@@ -111,7 +119,7 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
ref_out.sum().backward()

# Validate gradients.
if is_ROCM():
if is_ROCM() or _is_xpu_available:
# ROCm: reference and tested paths use different backends:
# - `torch._scaled_mm` uses hipBLASLt
# - `_to_fp8_rowwise_then_scaled_grouped_mm` uses CK
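# XPU: presumably the same backend mismatch applies, so the relaxed tolerances are reused.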
@@ -130,15 +138,14 @@ def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups):
@pytest.mark.parametrize("m", [16, 17])
@pytest.mark.parametrize("k", [16, 18])
@pytest.mark.parametrize("n", [32, 33])
def test_K_or_N_dim_not_multiple_of_16(m, n, k):
def test_K_or_N_dim_not_multiple_of_16(m, n, k, device):
# - Leading dim of A doesn't have to be divisible by 16, since it will be
# divided up into groups based on offset anyway.
# - Trailing dim of A must be divisible by 16.
# - Leading dim of B (n_groups) doesn't need to be divisible by 16.
# - Last 2 dims of B must be divisible by 16.
if n % 16 == 0 and k % 16 == 0:
return
device = "cuda"
n_groups = 4
a = torch.randn(
m * n_groups,
@@ -161,7 +168,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
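# transpose -> contiguous -> transpose back leaves b_t column-major, as the grouped GEMM requires.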
b_t = b_t.transpose(-2, -1).contiguous().transpose(-2, -1)

config = Float8TrainingOpConfig.from_recipe(Float8TrainingRecipe.FP8_ROWWISE)
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)

# Compute output.
with pytest.raises(AssertionError):
93 changes: 56 additions & 37 deletions test/prototype/moe_training/test_mxfp8_grouped_mm.py
@@ -14,16 +14,17 @@
is_MI350,
is_sm_at_least_90,
is_sm_version,
is_XPU,
torch_version_at_least,
)

if not (
torch_version_at_least("2.7.0")
and torch.cuda.is_available()
and (is_sm_at_least_90() or is_MI300() or is_MI350())
):
_is_xpu = is_XPU()
_is_compatible_cuda = torch.cuda.is_available() and (
is_sm_at_least_90() or is_MI300() or is_MI350()
)
if not ((_is_xpu or _is_compatible_cuda) and torch_version_at_least("2.7.0")):
pytest.skip(
"Requires FP8-capable GPU (CUDA SM90+, MI300, or MI350)",
"Requires FP8-capable GPU (CUDA SM90+, MI300, MI350, or XPU) and PyTorch 2.7+",
allow_module_level=True,
)

@@ -47,15 +48,24 @@
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
from torchao.quantization.quantize_.common import KernelPreference
from torchao.testing.utils import skip_if_rocm
from torchao.testing.utils import skip_if_rocm, skip_if_xpu
from torchao.utils import get_available_devices

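# get_available_devices() lists "cpu" first; drop it so the tests run only on accelerators.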
_DEVICES = get_available_devices()[1:]


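# Module-scoped fixture that parametrizes each test over every available accelerator.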
@pytest.fixture(scope="module", params=_DEVICES)
def device(request):
return request.param


# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@skip_if_rocm("ROCm not supported")
@pytest.mark.skipif(
not is_sm_version(10, 0),
torch.cuda.is_available() and not is_sm_version(10, 0),
reason="3D MXFP8 quantization requires SM100",
)
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@@ -65,11 +75,11 @@
"scale_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_emulate_mxfp8_grouped_gemm_2d_3d(
M, K, N, num_experts, scale_block_k, scale_mode
M, K, N, num_experts, scale_block_k, scale_mode, device
):
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M)
x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device=device)
offs = generate_jagged_offs(num_experts, M, device=device)
offs_ref = offs.clone()

# Quantize inputs to mxfp8 for emulated mxfp8 scaled grouped mm
@@ -120,8 +130,9 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(


@skip_if_rocm("ROCm not supported")
@skip_if_xpu("XPU support not yet available")
@pytest.mark.skipif(
not is_sm_version(10, 0),
torch.cuda.is_available() and not is_sm_version(10, 0),
reason="3D MXFP8 quantization and MXFP8 grouped GEMM require SM100",
)
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
@@ -130,10 +141,12 @@ def test_emulate_mxfp8_grouped_gemm_2d_3d(
@pytest.mark.parametrize(
"scale_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts, scale_block_k, scale_mode):
grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M)
def test_mxfp8_grouped_gemm_2d_3d(
M, K, N, num_experts, scale_block_k, scale_mode, device
):
grad_out = torch.randn(M, N, dtype=torch.bfloat16, device=device)
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device=device)
offs = generate_jagged_offs(num_experts, M, device=device)
offs_ref = offs.clone()

# Real SM100 grouped MM: 1x32-scaled A @ 3D-quantized B -> grad_input.
@@ -193,13 +206,13 @@
@pytest.mark.parametrize("M", (1024, 4096))
@pytest.mark.parametrize("N", (1024, 4096))
@pytest.mark.parametrize("num_experts", (8, 16))
def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts, device):
# Simulate 2d-2d grouped gemm: grad_weight = grad_output_t @ x
block_size = 32
grad_out = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
grad_out = torch.randn(M, N, dtype=torch.bfloat16, device=device)
grad_out_t = grad_out.t().contiguous()
x = torch.randn(M, N, dtype=torch.bfloat16, device="cuda")
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
x = torch.randn(M, N, dtype=torch.bfloat16, device=device)
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size, device=device)
x_ref, grad_out_t_ref, offs_ref = x.clone(), grad_out_t.clone(), offs.clone()

# bf16 reference grouped gemm
@@ -238,6 +251,7 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):


@skip_if_rocm("ROCm not supported")
@skip_if_xpu("XPU support not yet available")
@pytest.mark.parametrize("M,K,N", [(32768, 5120, 8192), (16640, 7168, 2048)])
@pytest.mark.parametrize("num_experts", (1, 8))
@pytest.mark.parametrize("wgrad_with_hp", (True, False))
@@ -257,9 +271,14 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
use_compile,
kernel_preference,
scale_mode,
device,
):
# MXFP8 hardware path requires SM100
if kernel_preference != KernelPreference.EMULATED and not is_sm_version(10, 0):
if (
torch.cuda.is_available()
and kernel_preference != KernelPreference.EMULATED
and not is_sm_version(10, 0)
):
pytest.skip(
f"Skipping MXFP8 hardware mode tests, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
)
@@ -268,17 +287,17 @@
"torch native dynamic per group pad/unpad functions do not work with torch.compile yet: https://github.com/pytorch/pytorch/issues/176770"
)

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)
w = torch.randn(
num_experts,
N,
K,
dtype=torch.bfloat16,
device="cuda",
device=device,
)
w_t = w.transpose(-2, -1).requires_grad_(True)

offs = generate_jagged_offs(num_experts, M, multiple_of=128)
offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)
x_ref, w_t_ref, offs_ref = (
x.clone().detach().requires_grad_(True),
w_t.clone().detach().requires_grad_(True),
@@ -328,19 +347,19 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(


@skip_if_rocm("ROCm not supported")
def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():
def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic(device):
block_size = 32
M, K, N, num_experts = 4096, 1024, 2048, 8
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
x = torch.randn(M, K, dtype=torch.bfloat16, device=device, requires_grad=True)
w = torch.randn(
num_experts,
N,
K,
dtype=torch.bfloat16,
device="cuda",
device=device,
)
w_t = w.transpose(-2, -1).requires_grad_(True)
offs = generate_jagged_offs(num_experts, M, multiple_of=128)
offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)

x_ref = x.detach().clone().requires_grad_(True)
w_t_ref = w_t.detach().clone().requires_grad_(True)
@@ -401,19 +420,19 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_matches_dynamic():


@skip_if_rocm("ROCm not supported")
def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():
def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward(device):
block_size = 32
M, K, N, num_experts = 4096, 1024, 2048, 8
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
w = torch.randn(
num_experts,
N,
K,
dtype=torch.bfloat16,
device="cuda",
device=device,
)
w_t = w.transpose(-2, -1)
offs = generate_jagged_offs(num_experts, M, multiple_of=128)
offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)

x_scale, x_qdata = to_mx(
x.detach(),
@@ -455,19 +474,19 @@ def test_mxfp8_grouped_gemm_from_qdata_and_scales_forward():


@skip_if_rocm("ROCm not supported")
def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp():
def test_mxfp8_grouped_gemm_mxtensor_requires_wgrad_with_hp(device):
block_size = 32
M, K, N, num_experts = 1024, 1024, 2048, 4
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
w = torch.randn(
num_experts,
N,
K,
dtype=torch.bfloat16,
device="cuda",
device=device,
)
w_t = w.transpose(-2, -1)
offs = generate_jagged_offs(num_experts, M, multiple_of=128)
offs = generate_jagged_offs(num_experts, M, multiple_of=128, device=device)

x_scale, x_qdata = to_mx(
x,