diff --git a/d2go/checkpoint/fsdp_checkpoint.py b/d2go/checkpoint/fsdp_checkpoint.py index e6fce6f1..2014320a 100644 --- a/d2go/checkpoint/fsdp_checkpoint.py +++ b/d2go/checkpoint/fsdp_checkpoint.py @@ -1,6 +1,6 @@ # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. import os -from typing import cast, IO +from typing import Callable, cast, IO import detectron2.utils.comm as comm import torch @@ -15,12 +15,25 @@ ) +def get_max_checkpoint_concurrency() -> int: + return comm.get_world_size() + + # TODO: replace FSDPCheckpointer with central D2GoCheckpointer class FSDPCheckpointer(QATCheckpointer): """ Extend the Checkpointer to support saving/loading FSDP models """ + def __init__( + self, + *args, + concurrency_limit_fetcher: Callable[[], int] = get_max_checkpoint_concurrency, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._concurrency_limit_fetcher: Callable[[], int] = concurrency_limit_fetcher + def is_distributed(self) -> bool: return True @@ -135,8 +148,10 @@ def save(self, name: str, tag_last_ckpt=True, **kwargs) -> None: basename = "rank{}.pth".format(comm.get_rank()) save_file = os.path.join(new_save_dir, basename) assert os.path.basename(save_file) == basename, basename - # allow 8 GPUs to write to manifold at the same time - with interleave_by_rank(concurrency_limit=8): + # Limit the write concurrency to avoid QPS overload + with interleave_by_rank( + concurrency_limit=self._concurrency_limit_fetcher() + ): self._save_file(data, save_file) # Main process tags last checkpoint if no errors in all processes comm.synchronize() @@ -156,8 +171,8 @@ def _save_file(self, data, filename): torch.save(data, cast(IO[bytes], f)) def _load_file(self, f: str): - # allow 8 GPUs to read from manifold at the same time - with interleave_by_rank(concurrency_limit=8): + # Limit the read concurrency to avoid QPS overload + with interleave_by_rank(concurrency_limit=self._concurrency_limit_fetcher()): return super()._load_file(f)