Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions checkpoint_engine/device_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,19 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
if not devices:
raise RuntimeError("no rdma devices found")
try:
assert len(devices) <= gpu_count, (
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
)
assert gpu_count % len(devices) == 0, (
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
)
return devices[local_rank // (gpu_count // len(devices))]
if len(devices) <= gpu_count:
assert gpu_count % len(devices) == 0, (
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
)
return devices[local_rank // (gpu_count // len(devices))]
else:
assert len(devices) % gpu_count == 0, (
f"rdma devices count {len(devices)} should be divisible by gpu count {gpu_count}"
)
device_per_rank = len(devices) // gpu_count
return ",".join(
devices[local_rank * device_per_rank : (local_rank + 1) * device_per_rank]
)
except AssertionError:
logger.error(
"Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
Expand Down
2 changes: 1 addition & 1 deletion checkpoint_engine/p2p_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(self, device_manager: DeviceManager):
self.port = self.engine.get_rpc_port()
self.named_tensors: dict[str, torch.Tensor] = {}
logger.info(
f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
f"[rank{self.rank}] p2p store initialized, protocol {device_manager.transfer_engine_protocol}, addr {self.addr}, device {self.device}"
)

@property
Expand Down
Loading