From 5480146bffa37be0982163180dab9afafad9e402 Mon Sep 17 00:00:00 2001
From: Hubert Zhang
Date: Thu, 4 Jun 2026 19:45:08 +0000
Subject: [PATCH] feat: support aws efa

---
NOTE(review): line structure restored from a whitespace-collapsed copy and the
docstring of has_efa_pci translated to English (still one line, so hunk counts
are unchanged). The post-image blob id on the "index" line (35b140c) was
computed before the translation and no longer matches byte-for-byte.

 checkpoint_engine/device_utils.py | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/checkpoint_engine/device_utils.py b/checkpoint_engine/device_utils.py
index 254cf27..35b140c 100644
--- a/checkpoint_engine/device_utils.py
+++ b/checkpoint_engine/device_utils.py
@@ -4,6 +4,7 @@
 import socket
 import subprocess
 from functools import lru_cache
+from pathlib import Path
 
 import torch
 from loguru import logger
@@ -172,6 +173,22 @@ def _resolve_device_specs(
     return sorted(devices)
 
 
+def has_efa_pci() -> bool:
+    """Return True if an InfiniBand-class device reports Amazon's PCI vendor ID (EFA hardware)."""
+    pci_path = Path("/sys/class/infiniband/")
+    if not pci_path.exists():
+        return False
+    for device in pci_path.iterdir():
+        try:
+            vendor = (device / "device" / "vendor").read_text().strip()
+            # Amazon Vendor ID = 0x1d0f
+            if vendor == "0x1d0f":
+                return True
+        except (OSError, ValueError):  # noqa: PERF203
+            continue
+    return False
+
+
 class DeviceManager:
     def __init__(self):
         self.device_type = self._detect_device_type()
@@ -218,14 +235,17 @@ def transfer_engine_protocol(self) -> str:
         if self.device_type == "npu":
             return "ascend_direct"
         elif self.device_type == "cuda":
-            return "rdma"
+            if has_efa_pci():
+                return "efa"
+            else:
+                return "rdma"
         else:
             raise TypeError("The current device type is not supported")
 
     def rdma_device(self, rank: int) -> str:
         if self.transfer_engine_protocol == "ascend_direct":
             return ""
-        elif self.transfer_engine_protocol == "rdma":
+        elif self.transfer_engine_protocol in ["rdma", "efa"]:
             return _get_my_rdma_device(rank, self.device_module.device_count(), _get_rdma_devices())
         else:
             raise TypeError("The current transfer engine protocol is not supported")