Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions checkpoint_engine/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import socket
import subprocess
from functools import lru_cache
from pathlib import Path

import torch
from loguru import logger
Expand Down Expand Up @@ -172,6 +173,22 @@ def _resolve_device_specs(
return sorted(devices)


def has_efa_pci() -> bool:
"""通过 PCI 设备 ID 精确检查是否存在 EFA 硬件"""
pci_path = Path("/sys/class/infiniband/")
if not pci_path.exists():
return False
for device in pci_path.iterdir():
try:
vendor = (device / "device" / "vendor").read_text().strip()
# Amazon Vendor ID = 0x1d0f
if vendor == "0x1d0f":
return True
except (OSError, ValueError): # noqa: PERF203
continue
return False


class DeviceManager:
def __init__(self):
self.device_type = self._detect_device_type()
Expand Down Expand Up @@ -218,14 +235,17 @@ def transfer_engine_protocol(self) -> str:
if self.device_type == "npu":
return "ascend_direct"
elif self.device_type == "cuda":
return "rdma"
if has_efa_pci():
return "efa"
else:
return "rdma"
else:
raise TypeError("The current device type is not supported")

def rdma_device(self, rank: int) -> str:
if self.transfer_engine_protocol == "ascend_direct":
return ""
elif self.transfer_engine_protocol == "rdma":
elif self.transfer_engine_protocol in ["rdma", "efa"]:
return _get_my_rdma_device(rank, self.device_module.device_count(), _get_rdma_devices())
else:
raise TypeError("The current transfer engine protocol is not supported")
Loading