diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index faae84979..1303e151b 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -542,7 +542,6 @@ def unshard_checkpoint(
 
         Note this is not marked abstract because child classes are not required to implemented this.
         """
-        del load_path, local_cache, load_optimizer_state, load_trainer_state, device
         raise NotImplementedError
 
     @contextmanager
@@ -1914,13 +1913,13 @@ def unshard_checkpoint(
 
 
 def build_sharded_checkpointer(
-    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None
+    cfg: TrainConfig, *, name: Optional[ShardedCheckpointerType] = None, use_shared_mem_impl: bool = False
 ) -> Checkpointer:
     name = name or cfg.sharded_checkpointer
     if name == ShardedCheckpointerType.torch_new:
         return TorchNewStyleShardedCheckpointer(cfg)
     elif name == ShardedCheckpointerType.torch_legacy:
-        return TorchLegacyShardedCheckpointer(cfg)
+        return TorchLegacyShardedCheckpointer(cfg, use_shared_mem_impl=use_shared_mem_impl)
     elif name == ShardedCheckpointerType.local:
         return LocalShardedCheckpointer(cfg)
     elif name == ShardedCheckpointerType.olmo_core:
diff --git a/scripts/unshard.py b/scripts/unshard.py
index 1063e2f4f..41cd83c51 100644
--- a/scripts/unshard.py
+++ b/scripts/unshard.py
@@ -1,16 +1,11 @@
 import logging
 import shutil
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union
 
 import torch
 
-from olmo.checkpoint import (
-    Checkpointer,
-    LocalShardedCheckpointer,
-    OlmoCoreCheckpointer,
-    TorchLegacyShardedCheckpointer,
-)
+from olmo.checkpoint import build_sharded_checkpointer
 from olmo.config import ShardedCheckpointerType, TrainConfig
 from olmo.safetensors_util import state_dict_to_safetensors_file
 
@@ -20,7 +15,7 @@
 def main(
     input_dir: Union[str, Path],
     output_dir: Union[str, Path],
-    sharded_checkpoint_type: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy,
+    sharded_checkpoint_type: Optional[ShardedCheckpointerType] = None,
     model_only: bool = False,
     safe_tensors: bool = False,
     use_shared_mem_impl: bool = False,
@@ -32,15 +27,11 @@ def main(
     output_dir.mkdir(parents=True, exist_ok=True)
 
     config = TrainConfig.load(input_dir / "config.yaml", validate_paths=False)
-    checkpointer: Checkpointer
-    if sharded_checkpoint_type == ShardedCheckpointerType.torch_legacy:
-        checkpointer = TorchLegacyShardedCheckpointer(config, use_shared_mem_impl=use_shared_mem_impl)
-    elif sharded_checkpoint_type == ShardedCheckpointerType.local:
-        checkpointer = LocalShardedCheckpointer(config)
-    elif sharded_checkpoint_type == ShardedCheckpointerType.olmo_core:
-        checkpointer = OlmoCoreCheckpointer(config)
-    else:
-        raise NotImplementedError(sharded_checkpoint_type)
+
+    sharded_checkpoint_type = sharded_checkpoint_type or config.sharded_checkpointer
+    checkpointer = build_sharded_checkpointer(
+        config, name=sharded_checkpoint_type, use_shared_mem_impl=use_shared_mem_impl
+    )
 
     model_state_dict, optim_state_dict, trainer_state_dict = checkpointer.unshard_checkpoint(
         input_dir,
@@ -92,8 +83,8 @@ def main(
     parser.add_argument(
         "--type",
         choices=list(ShardedCheckpointerType),
-        default=ShardedCheckpointerType.torch_legacy,
-        help="""The sharded checkpoint type.""",
+        default=None,
+        help="""The sharded checkpoint type. Defaults to the sharded checkpoint type set in config.""",
     )
     parser.add_argument(
         "--model-only",
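
Below is a minimal usage sketch of the refactored factory. It is not part of the diff: the paths are illustrative, and the unshard_checkpoint keyword names are taken from the base-class parameters visible in the first hunk. It shows a caller obtaining a checkpointer through build_sharded_checkpointer, which now forwards use_shared_mem_impl to TorchLegacyShardedCheckpointer, instead of branching on the checkpointer type itself as scripts/unshard.py used to do.

# Usage sketch (assumption, not part of this diff); paths are illustrative.
from olmo.checkpoint import build_sharded_checkpointer
from olmo.config import TrainConfig

# Load the training config stored alongside the sharded checkpoint.
config = TrainConfig.load("checkpoint_dir/config.yaml", validate_paths=False)

# name defaults to config.sharded_checkpointer; use_shared_mem_impl only
# affects the torch_legacy checkpointer and is ignored otherwise.
checkpointer = build_sharded_checkpointer(config, use_shared_mem_impl=True)

model_sd, optim_sd, trainer_sd = checkpointer.unshard_checkpoint(
    "checkpoint_dir",  # illustrative load path
    load_optimizer_state=True,
    load_trainer_state=True,
)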