Make olmo-core checkpointer more robust on weka (#624)
epwalsh authored Jun 17, 2024
1 parent ddc8847 commit 2417b11
Showing 2 changed files with 25 additions and 4 deletions.
24 changes: 21 additions & 3 deletions olmo/checkpoint.py
@@ -567,6 +567,12 @@ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
             shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
             checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)

+        # In the cases where we're using a shared NFS drive between ranks to save checkpoints,
+        # creating the temp directory from rank 0 might not be immediately
+        # realized in the file systems of the other ranks.
+        # So we wait here across all ranks until that tmp checkpoint directory is visible.
+        wait_for(lambda: checkpoint_dir_tmp.exists(), "Waiting for checkpoint directory", timeout=10.0)
+
         barrier()

         # Yield temporary directory for `.save_checkpoint()` to use.
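
The hunk above relies on a `wait_for` helper to poll the filesystem until the temporary directory created by rank 0 becomes visible to every rank. The helper itself is not part of this diff, so the following is only a minimal sketch of the polling pattern it implies, assuming a signature of (condition, description, timeout); the repository's actual implementation may differ.

import time
from typing import Callable


def wait_for(condition: Callable[[], bool], description: str, timeout: float = 10.0) -> None:
    """Poll `condition` until it returns True, raising if `timeout` seconds elapse first."""
    start = time.monotonic()
    while not condition():
        if time.monotonic() - start > timeout:
            raise TimeoutError(f"{description} timed out after {timeout} seconds")
        time.sleep(0.5)  # brief back-off between filesystem checks
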
@@ -1914,9 +1920,22 @@ def save_checkpoint(

         with self._temporary_wd(dir) as checkpoint_dir:
             log.info("Saving model and optim state...")
-            local_files_created = save_model_and_optim_state(
-                checkpoint_dir, dist_model, optim, save_overwrite=self.cfg.save_overwrite
-            )
+            if get_fs_local_rank() == 0:
+                (checkpoint_dir / "model").mkdir(exist_ok=True, parents=True)
+                (checkpoint_dir / "optim").mkdir(exist_ok=True, parents=True)
+                (checkpoint_dir / "train").mkdir(exist_ok=True, parents=True)
+
+            wait_for(
+                lambda: (checkpoint_dir / "model").exists(), "Waiting for checkpoint model directory", timeout=10.0
+            )
+            wait_for(
+                lambda: (checkpoint_dir / "optim").exists(), "Waiting for checkpoint optim directory", timeout=10.0
+            )
+            wait_for(
+                lambda: (checkpoint_dir / "train").exists(), "Waiting for checkpoint train directory", timeout=10.0
+            )
+
+            local_files_created = save_model_and_optim_state(checkpoint_dir, dist_model, optim)
             if upload_to is not None:
                 for path in local_files_created:
                     path = Path(path)
@@ -1929,7 +1948,6 @@ def save_checkpoint(
                 checkpoint_dir,
                 f"train/rank{get_global_rank()}.pt",
                 trainer_state,
-                save_overwrite=self.cfg.save_overwrite,
                 upload_to=upload_to,
             )

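Taken together, the save_checkpoint changes follow one pattern: the filesystem-local rank 0 process creates the model/optim/train directory layout, every rank then waits for those directories to actually appear (on shared filesystems such as NFS or Weka a mkdir done by another rank may not be visible immediately), and only then does the collective write begin. A stripped-down sketch of that pattern is below; it reuses the hypothetical wait_for helper sketched earlier and calls torch.distributed directly, so the names and structure are illustrative rather than the repository's API.

from pathlib import Path

import torch.distributed as dist


def prepare_checkpoint_dirs(checkpoint_dir: Path, fs_local_rank: int) -> None:
    """Create checkpoint sub-directories on one rank, then make every rank wait for them."""
    subdirs = [checkpoint_dir / name for name in ("model", "optim", "train")]
    if fs_local_rank == 0:
        for subdir in subdirs:
            subdir.mkdir(exist_ok=True, parents=True)
    # On a shared filesystem, the mkdir done by rank 0 may lag on other ranks.
    for subdir in subdirs:
        wait_for(lambda d=subdir: d.exists(), f"Waiting for {subdir.name} directory", timeout=10.0)
    # Synchronize before any rank starts writing its shard of the checkpoint.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
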
5 changes: 4 additions & 1 deletion olmo/torch_util.py
@@ -59,7 +59,10 @@ def get_fs_local_rank() -> int:
     if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`,
     but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`.
     """
-    return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())
+    if os.environ.get("OLMO_SHARED_FS"):
+        return int(os.environ.get("FS_LOCAL_RANK") or get_global_rank())
+    else:
+        return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank())


 def move_to_device(o: T, device: torch.device) -> T:
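
The torch_util.py change makes shared-filesystem behaviour opt-in: when OLMO_SHARED_FS is set, the filesystem-local rank falls back to the global rank, so exactly one process in the whole job is treated as "fs rank 0"; otherwise each node's local rank 0 plays that role. A small standalone sketch of the same logic is below, with the standard torchrun variables RANK and LOCAL_RANK standing in for the repository's get_global_rank() and get_local_rank() helpers (an assumption for illustration only).

import os


def fs_local_rank_sketch() -> int:
    """Mirror the new get_fs_local_rank() logic using raw environment variables."""
    global_rank = int(os.environ.get("RANK", 0))       # torchrun's job-wide rank
    local_rank = int(os.environ.get("LOCAL_RANK", 0))  # torchrun's per-node rank
    if os.environ.get("OLMO_SHARED_FS"):
        # Shared filesystem (e.g. NFS or Weka): one "fs rank 0" for the whole job.
        return int(os.environ.get("FS_LOCAL_RANK") or global_rank)
    # Node-local filesystems: each node's local rank 0 handles directory setup.
    return int(os.environ.get("FS_LOCAL_RANK") or local_rank)
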
