From 0a6620c3cbb93062e4d48fb544b6f6408f2630de Mon Sep 17 00:00:00 2001 From: Hubert Zhang Date: Thu, 4 Jun 2026 19:45:54 +0000 Subject: [PATCH] feat: support multiple interfaces for a single device --- checkpoint_engine/device_utils.py | 20 +++++++++++++------- checkpoint_engine/p2p_store.py | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/checkpoint_engine/device_utils.py b/checkpoint_engine/device_utils.py index 254cf27..428b1d3 100644 --- a/checkpoint_engine/device_utils.py +++ b/checkpoint_engine/device_utils.py @@ -87,13 +87,19 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> if not devices: raise RuntimeError("no rdma devices found") try: - assert len(devices) <= gpu_count, ( - f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}" - ) - assert gpu_count % len(devices) == 0, ( - f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}" - ) - return devices[local_rank // (gpu_count // len(devices))] + if len(devices) <= gpu_count: + assert gpu_count % len(devices) == 0, ( + f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}" + ) + return devices[local_rank // (gpu_count // len(devices))] + else: + assert len(devices) % gpu_count == 0, ( + f"rdma devices count {len(devices)} should be divisible by gpu count {gpu_count}" + ) + device_per_rank = len(devices) // gpu_count + return ",".join( + devices[local_rank * device_per_rank : (local_rank + 1) * device_per_rank] + ) except AssertionError: logger.error( "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices." 
diff --git a/checkpoint_engine/p2p_store.py b/checkpoint_engine/p2p_store.py index d217a72..218e1b6 100644 --- a/checkpoint_engine/p2p_store.py +++ b/checkpoint_engine/p2p_store.py @@ -41,7 +41,7 @@ def __init__(self, device_manager: DeviceManager): self.port = self.engine.get_rpc_port() self.named_tensors: dict[str, torch.Tensor] = {} logger.info( - f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}" + f"[rank{self.rank}] p2p store initialized, protocol {device_manager.transfer_engine_protocol}, addr {self.addr}, device {self.device}" ) @property