Unshard without passing checkpointer type #603

Merged · 5 commits · May 31, 2024
olmo/checkpoint.py (5 changes: 2 additions & 3 deletions)
@@ -542,7 +542,6 @@ def unshard_checkpoint(
 
         Note this is not marked abstract because child classes are not required to implement this.
         """
-        del load_path, local_cache, load_optimizer_state, load_trainer_state, device
         raise NotImplementedError
 
     @contextmanager
@@ -1914,13 +1913,13 @@ def unshard_checkpoint(
 
 
 def build_sharded_checkpointer(
-    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
+    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None, use_shared_mem_impl: bool = False
 ) -> Checkpointer:
     name = name or cfg.sharded_checkpointer
     if name == ShardedCheckpointerType.torch_new:
        return TorchNewStyleShardedCheckpointer(cfg)
     elif name == ShardedCheckpointerType.torch_legacy:
-        return TorchLegacyShardedCheckpointer(cfg)
+        return TorchLegacyShardedCheckpointer(cfg, use_shared_mem_impl=use_shared_mem_impl)
     elif name == ShardedCheckpointerType.local:
         return LocalShardedCheckpointer(cfg)
     elif name == ShardedCheckpointerType.olmo_core:
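With this change, `build_sharded_checkpointer` accepts `use_shared_mem_impl` and, as before, falls back to the checkpointer type recorded in the training config when `name` is not given, so a caller only needs the config to get the right checkpointer. A minimal sketch of how the factory can now be used (the checkpoint path below is illustrative, not from this PR):

```python
from olmo.checkpoint import build_sharded_checkpointer
from olmo.config import TrainConfig

# Load the training config saved alongside the sharded checkpoint
# ("checkpoints/step1000/config.yaml" is a hypothetical path).
config = TrainConfig.load("checkpoints/step1000/config.yaml", validate_paths=False)

# No `name` passed: the factory falls back to config.sharded_checkpointer,
# so the caller does not need to know which sharded format was written.
# use_shared_mem_impl is only forwarded to TorchLegacyShardedCheckpointer.
checkpointer = build_sharded_checkpointer(config, use_shared_mem_impl=True)
```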
scripts/unshard.py (29 changes: 10 additions & 19 deletions)
@@ -1,16 +1,11 @@
 import logging
 import shutil
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import torch
 
-from olmo.checkpoint import (
-    Checkpointer,
-    LocalShardedCheckpointer,
-    OlmoCoreCheckpointer,
-    TorchLegacyShardedCheckpointer,
-)
+from olmo.checkpoint import build_sharded_checkpointer
 from olmo.config import ShardedCheckpointerType, TrainConfig
 from olmo.safetensors_util import state_dict_to_safetensors_file
 
@@ -20,7 +15,7 @@
 def main(
     input_dir: Union[str, Path],
     output_dir: Union[str, Path],
-    sharded_checkpoint_type: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy,
+    sharded_checkpoint_type: Optional[ShardedCheckpointerType] = None,
     model_only: bool = False,
     safe_tensors: bool = False,
     use_shared_mem_impl: bool = False,
@@ -32,15 +27,11 @@ def main(
     output_dir.mkdir(parents=True, exist_ok=True)
 
     config = TrainConfig.load(input_dir / "config.yaml", validate_paths=False)
-    checkpointer: Checkpointer
-    if sharded_checkpoint_type == ShardedCheckpointerType.torch_legacy:
-        checkpointer = TorchLegacyShardedCheckpointer(config, use_shared_mem_impl=use_shared_mem_impl)
-    elif sharded_checkpoint_type == ShardedCheckpointerType.local:
-        checkpointer = LocalShardedCheckpointer(config)
-    elif sharded_checkpoint_type == ShardedCheckpointerType.olmo_core:
-        checkpointer = OlmoCoreCheckpointer(config)
-    else:
-        raise NotImplementedError(sharded_checkpoint_type)
+
+    sharded_checkpoint_type = sharded_checkpoint_type or config.sharded_checkpointer
+    checkpointer = build_sharded_checkpointer(
+        config, name=sharded_checkpoint_type, use_shared_mem_impl=use_shared_mem_impl
+    )
 
     model_state_dict, optim_state_dict, trainer_state_dict = checkpointer.unshard_checkpoint(
         input_dir,
@@ -92,8 +83,8 @@ def main(
     parser.add_argument(
         "--type",
         choices=list(ShardedCheckpointerType),
-        default=ShardedCheckpointerType.torch_legacy,
-        help="""The sharded checkpoint type.""",
+        default=None,
+        help="""The sharded checkpoint type. Defaults to the sharded checkpoint type set in config.""",
     )
     parser.add_argument(
         "--model-only",
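The net effect on the script: when no type is supplied, `main` resolves it from the `sharded_checkpointer` field of the checkpoint's own `config.yaml` instead of assuming `torch_legacy`. A hedged sketch of calling the updated `main` directly (the paths are illustrative, and importing the script as a module assumes `scripts/` is on the Python path):

```python
from pathlib import Path

from unshard import main  # scripts/unshard.py, assuming scripts/ is importable

# sharded_checkpoint_type is omitted (None), so main() falls back to the
# sharded_checkpointer recorded in the checkpoint's config.yaml.
main(
    input_dir=Path("checkpoints/step1000"),   # hypothetical sharded checkpoint dir
    output_dir=Path("unsharded/step1000"),    # hypothetical destination
    model_only=True,
    safe_tensors=False,
)
```

Equivalently, the `--type` flag of `scripts/unshard.py` can now simply be left off on the command line.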