Skip to content

Commit

Permalink
fix bug with saving unsharded checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Jan 11, 2024
1 parent 3e3df71 commit 905359e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions olmo/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
if get_fs_local_rank() == 0:
shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
+ checkpoint_dir_tmp.mkdir(exist_ok=True, parents=True)
+
barrier()

# Yield temporary directory for `.save_checkpoint()` to use.
Expand All @@ -502,10 +504,8 @@ def _temporary_wd(self, dir: PathOrStr) -> Generator[Path, None, None]:
barrier()

  # Finally if all went well replace the temporary directory with the actual
- # checkpoint directory. Note that for some checkpointers the local rank 0 might
- # not use this folder, so it may not exist; FullCheckpointer, for example, only creates
- # this for global rank 0.
- if get_fs_local_rank() == 0 and checkpoint_dir_tmp.exists():
+ # checkpoint directory.
+ if get_fs_local_rank() == 0:
# Replace temp directory with target checkpoint directory.
try:
checkpoint_dir_tmp.replace(checkpoint_dir)
Expand Down

0 comments on commit 905359e

Please sign in to comment.