Skip to content
This repository was archived by the owner on Jan 22, 2025. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions d2go/checkpoint/fsdp_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import os
from typing import cast, IO
from typing import Callable, cast, IO

import detectron2.utils.comm as comm
import torch
Expand All @@ -15,12 +15,25 @@
)


def get_max_checkpoint_concurrency() -> int:
    """
    Return the default concurrency limit used for checkpoint reads and writes.

    The limit is the distributed world size reported by
    ``comm.get_world_size()``.
    """
    # NOTE(review): presumably a limit equal to the world size lets every rank
    # proceed at once -- confirm against `interleave_by_rank`.
    world_size = comm.get_world_size()
    return world_size


# TODO: replace FSDPCheckpointer with central D2GoCheckpointer
class FSDPCheckpointer(QATCheckpointer):
"""
Extend the Checkpointer to support saving/loading FSDP models
"""

def __init__(
    self,
    *args,
    concurrency_limit_fetcher: Callable[[], int] = get_max_checkpoint_concurrency,
    **kwargs,
):
    """
    Initialize the checkpointer and record how to obtain the I/O concurrency limit.

    Args:
        *args: positional arguments forwarded unchanged to the parent
            initializer.
        concurrency_limit_fetcher: zero-argument callable returning the value
            passed as ``concurrency_limit`` to ``interleave_by_rank`` when
            checkpoint files are saved or loaded. Defaults to
            ``get_max_checkpoint_concurrency``.
        **kwargs: keyword arguments forwarded unchanged to the parent
            initializer.
    """
    super().__init__(*args, **kwargs)
    # Keep the callable itself (not its result) so the limit is looked up
    # each time a checkpoint file is read or written.
    self._concurrency_limit_fetcher: Callable[[], int] = concurrency_limit_fetcher

def is_distributed(self) -> bool:
    """
    Report whether this checkpointer is distributed; the answer is always True.
    """
    distributed = True
    return distributed

Expand Down Expand Up @@ -135,8 +148,10 @@ def save(self, name: str, tag_last_ckpt=True, **kwargs) -> None:
basename = "rank{}.pth".format(comm.get_rank())
save_file = os.path.join(new_save_dir, basename)
assert os.path.basename(save_file) == basename, basename
# allow 8 GPUs to write to manifold at the same time
with interleave_by_rank(concurrency_limit=8):
# Limit the write concurrency to avoid QPS overload
with interleave_by_rank(
concurrency_limit=self._concurrency_limit_fetcher()
):
self._save_file(data, save_file)
# Main process tags last checkpoint if no errors in all processes
comm.synchronize()
Expand All @@ -156,8 +171,8 @@ def _save_file(self, data, filename):
torch.save(data, cast(IO[bytes], f))

def _load_file(self, f: str):
# allow 8 GPUs to read from manifold at the same time
with interleave_by_rank(concurrency_limit=8):
# Limit the read concurrency to avoid QPS overload
with interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher()):
return super()._load_file(f)


Expand Down