test/prototype/moe_training/ep/test_kernels.py (17 additions, 7 deletions)
@@ -7,17 +7,29 @@
 import pytest
 import torch
 
-from torchao.utils import is_cuda_version_at_least, is_sm_at_least_100
+from torchao.utils import is_cuda_version_at_least, is_sm_at_least_100, is_XPU
 
-if not (
+_is_xpu = is_XPU()
+_is_compatible_cuda = (
     torch.cuda.is_available()
     and is_sm_at_least_100()
     and is_cuda_version_at_least(12, 8)
-):
-    pytest.skip("Test requires CUDA 12.8+ with SM >= 100", allow_module_level=True)
+)
+if not (_is_xpu or _is_compatible_cuda):
+    pytest.skip(
+        "Test requires XPU or CUDA 12.8+ with SM >= 100", allow_module_level=True
+    )
 
 from torchao.prototype.moe_training.ep.kernels import generate_permute_indices
 from torchao.prototype.moe_training.ep.permute import _triton_permute_bwd
+from torchao.utils import get_available_devices
+
+_DEVICES = get_available_devices()[1:]
+
+
+@pytest.fixture(scope="module", params=_DEVICES)
+def device(request):
+    return request.param
 
 
 @pytest.mark.parametrize(
@@ -41,10 +53,8 @@
     ],
 )
 def test_triton_permute_bwd(
-    num_tokens, hidden_dim, num_local_experts, ep_degree, alignment
+    num_tokens, hidden_dim, num_local_experts, ep_degree, alignment, device
 ):
-    device = "cuda"
-
     # Generate realistic permutation indices using generate_permute_indices
     # Simulate token distribution across experts
     tokens_per_expert_group = torch.randint(
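Both files make the same two changes: the module-level skip guard now admits either an XPU build or a CUDA 12.8+ / SM >= 100 build, and a module-scoped `device` fixture parametrized over `get_available_devices()[1:]` replaces the hardcoded `device = "cuda"` inside each test. A minimal sketch of the fixture half of that pattern follows; it assumes `get_available_devices()` lists `"cpu"` first so the `[1:]` slice keeps only accelerator backends, and the final test is hypothetical, included only to show how the fixture is consumed:

import pytest
import torch

from torchao.utils import get_available_devices

# Assumed: get_available_devices() returns ["cpu", ...accelerators...],
# so [1:] keeps only accelerator backends, e.g. ["cuda"] or ["xpu"].
_DEVICES = get_available_devices()[1:]


@pytest.fixture(scope="module", params=_DEVICES)
def device(request):
    # pytest reruns every test that requests `device` once per entry in
    # _DEVICES; module scope keeps all tests for one backend grouped together.
    return request.param


def test_runs_on_each_accelerator(device: str):
    # Hypothetical test, only to show how the fixture is consumed.
    x = torch.randn(8, 8, device=device)
    assert str(x.device).startswith(device)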
test/prototype/moe_training/ep/test_permute.py (18 additions, 6 deletions)
@@ -1,23 +1,35 @@
 import pytest
 import torch
 
-from torchao.utils import is_cuda_version_at_least, is_sm_at_least_100
+from torchao.utils import is_cuda_version_at_least, is_sm_at_least_100, is_XPU
 
-if not (
+_is_xpu = is_XPU()
+_is_compatible_cuda = (
     torch.cuda.is_available()
     and is_sm_at_least_100()
     and is_cuda_version_at_least(12, 8)
-):
-    pytest.skip("Test requires CUDA 12.8+ with SM >= 100", allow_module_level=True)
+)
+
+if not (_is_xpu or _is_compatible_cuda):
+    pytest.skip(
+        "Test requires XPU or CUDA 12.8+ with SM >= 100", allow_module_level=True
+    )
 
 from torchao.prototype.moe_training.ep import permute_mxfp8_fwd_hp_bwd
 from torchao.prototype.moe_training.ep.permute import permute_and_pad
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.utils import compute_error
+from torchao.utils import get_available_devices
+
+_DEVICES = get_available_devices()[1:]
+
+
+@pytest.fixture(scope="module", params=_DEVICES)
+def device(request):
+    return request.param
 
 
-def test_mxfp8_permute_forward():
-    device = "cuda"
+def test_mxfp8_permute_forward(device: str):
     tokens = 64
     dim = 128
     num_experts = 8
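On the consuming side, each test simply drops its hardcoded `device = "cuda"` line and declares the fixture as a parameter, so a single test body now runs once per available backend and pytest appends the backend to the test id (e.g. `test_mxfp8_permute_forward[cuda]`, `[xpu]`). A hedged sketch of what a converted body looks like, reusing only the shapes visible in the diff above and plain torch indexing in place of the real mxfp8 permute kernels:

import torch


def test_mxfp8_permute_forward_sketch(device: str):
    # Hypothetical stand-in for the converted test: shapes mirror the diff,
    # but the real permute kernels are replaced with generic tensor ops.
    tokens, dim = 64, 128
    x = torch.randn(tokens, dim, device=device, dtype=torch.bfloat16)
    permute_indices = torch.randperm(tokens, device=device)
    permuted = x[permute_indices]
    assert permuted.device == x.device
    assert permuted.shape == x.shape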