
Commit d0da82c

Fix rocm CI
Summary: The fix skips the PyTorch version check when kernel_preference is EMULATED, since emulated mode doesn't need scaled_mm_single_dim_strategy or is_pinned_handler from torch.distributed.tensor. The hardware MXFP8 paths (AUTO kernel preference) still get the version check.

Test Plan: CI

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
1 parent ac0b820 commit d0da82c

1 file changed: 10 additions & 7 deletions

torchao/prototype/moe_training/conversion_utils.py
@@ -30,15 +30,18 @@ def _get_tensor_cls_for_config(
         )
 
     if isinstance(config, MXFP8TrainingOpConfig):
-        from torch.distributed.tensor import _dispatch, _ops
+        from torchao.quantization.quantize_.common import KernelPreference
 
-        pytorch_version_supported = hasattr(
-            _ops, "scaled_mm_single_dim_strategy"
-        ) and hasattr(_dispatch, "is_pinned_handler")
+        if config.kernel_preference != KernelPreference.EMULATED:
+            from torch.distributed.tensor import _dispatch, _ops
 
-        assert pytorch_version_supported, (
-            "Please install the latest torch nightly build to use MXFP8 training"
-        )
+            pytorch_version_supported = hasattr(
+                _ops, "scaled_mm_single_dim_strategy"
+            ) and hasattr(_dispatch, "is_pinned_handler")
+
+            assert pytorch_version_supported, (
+                "Please install the latest torch nightly build to use MXFP8 training"
+            )
 
         return MXFP8TrainingWeightWrapperTensor
     elif isinstance(config, Float8TrainingOpConfig):
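For reference, a minimal call-side sketch of the new behavior. Only KernelPreference.EMULATED, MXFP8TrainingOpConfig, _get_tensor_cls_for_config, and the gating logic come from the diff above; the import paths and the keyword-argument constructor are assumptions, not confirmed by this commit.

# Hypothetical usage sketch; import locations and the constructor
# keyword are assumptions, not part of this commit.
from torchao.prototype.moe_training.conversion_utils import (
    MXFP8TrainingOpConfig,
    _get_tensor_cls_for_config,
)
from torchao.quantization.quantize_.common import KernelPreference

# EMULATED skips the torch.distributed.tensor feature probe entirely,
# so this call should succeed on builds (e.g. the ROCm CI image) whose
# torch lacks scaled_mm_single_dim_strategy and is_pinned_handler.
config = MXFP8TrainingOpConfig(kernel_preference=KernelPreference.EMULATED)
tensor_cls = _get_tensor_cls_for_config(config)

# AUTO (hardware MXFP8) still runs the hasattr checks and asserts on
# older torch, exactly as before this change.

Note that the gate is feature detection (hasattr) rather than a torch.__version__ comparison, so only the EMULATED branch, which never touches those private APIs, can safely bypass it.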
