diff --git a/scripts/train_pytorch.py b/scripts/train_pytorch.py index c7ddd2b595..89226d9e9d 100644 --- a/scripts/train_pytorch.py +++ b/scripts/train_pytorch.py @@ -329,19 +329,23 @@ def train_loop(config: _config.TrainConfig): else: raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume") elif config.overwrite and config.checkpoint_dir.exists(): - shutil.rmtree(config.checkpoint_dir) - logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") + if is_main: + shutil.rmtree(config.checkpoint_dir) + logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}") # Create checkpoint directory with experiment name - if not resuming: + if not resuming and is_main: # For new runs, create experiment-specific checkpoint directory exp_checkpoint_dir = config.checkpoint_dir exp_checkpoint_dir.mkdir(parents=True, exist_ok=True) logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}") - else: + elif resuming: # For resume, checkpoint_dir is already set to the experiment directory logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}") + if use_ddp: + dist.barrier() + # Initialize wandb (only on main process) if is_main: init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)