
Commit

load index 0 non aggregation checkpoint file, add trainer options validation
baijumeswani committed Dec 17, 2020
1 parent 4bce5f2 commit 7b5307c
Showing 3 changed files with 34 additions and 9 deletions.
29 changes: 29 additions & 0 deletions orttraining/orttraining/python/training/checkpoint.py
@@ -275,6 +275,15 @@ def aggregate_checkpoints(paths, pytorch_format=True):
     state_dict = {}
     sharded_states_original_dims = {}
     world_rank = _utils.state_dict_trainer_options_world_rank_key()
+    mixed_precision = _utils.state_dict_trainer_options_mixed_precision_key()
+    zero_stage = _utils.state_dict_trainer_options_zero_stage_key()
+    world_size = _utils.state_dict_trainer_options_world_size_key()
+    optimizer_name = _utils.state_dict_trainer_options_optimizer_name_key()
+
+    loaded_mixed_precision = None
+    loaded_world_size = None
+    loaded_zero_stage = None
+    loaded_optimizer_name = None
 
     for rank, path in enumerate(ordered_paths):
         rank_state_dict = _checkpoint_storage.load(path)
@@ -284,6 +293,26 @@ def aggregate_checkpoints(paths, pytorch_format=True):
         assert rank == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank], \
             "Unexpected rank in file at path {}. Expected {}, got {}".\
             format(path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank])
+        if loaded_mixed_precision is None:
+            loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision]
+        else:
+            assert loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision], \
+                "Mixed precision state mismatch among checkpoint files. File: {}".format(path)
+        if loaded_world_size is None:
+            loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size]
+        else:
+            assert loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size], \
+                "World size state mismatch among checkpoint files. File: {}".format(path)
+        if loaded_zero_stage is None:
+            loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage]
+        else:
+            assert loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage], \
+                "Zero stage mismatch among checkpoint files. File: {}".format(path)
+        if loaded_optimizer_name is None:
+            loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name]
+        else:
+            assert loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name], \
+                "Optimizer name mismatch among checkpoint files. File: {}".format(path)
 
         # aggregate all model states
         _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict)
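The validation added above repeats the same first-file/subsequent-file pattern for each trainer option (mixed precision, world size, zero stage, optimizer name). As a minimal sketch only, the same check could be written once and applied per option; the helper name and its arguments below are hypothetical and not part of this commit:

# Hypothetical helper, shown only to illustrate the validation pattern above;
# 'loaded_values' would be a dict carried across the loop over checkpoint files.
def _validate_trainer_option(loaded_values, option_name, trainer_options, option_key, path):
    """Remember the value from the first checkpoint file and assert that every
    subsequent file reports the same value for this trainer option."""
    if option_name not in loaded_values:
        loaded_values[option_name] = trainer_options[option_key]
    else:
        assert loaded_values[option_name] == trainer_options[option_key], \
            "{} mismatch among checkpoint files. File: {}".format(option_name, path)

The committed code keeps four explicit if/else blocks instead, which is more verbose but preserves the exact wording of each assertion message.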
7 changes: 3 additions & 4 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -1290,10 +1290,9 @@ def load_checkpoint(self, *paths, strict=True):
             # if aggregation is required, aggregation logic must be run on the saved checkpoints
             state_dict = checkpoint.aggregate_checkpoints(paths, pytorch_format=False)
         else:
-            # if aggregation is not required, reorder the checkpoints in order of ascending rank,
-            # and load the checkpoint for the current ORTTrainer rank.
-            ordered_paths = checkpoint._order_paths(paths)
-            state_dict = _checkpoint_storage.load(ordered_paths[0])
+            # if aggregation is not required, there must only be a single file that needs to be loaded
+            assert len(paths) == 1, "Expected number of files to load: 1, got {}".format(len(paths))
+            state_dict = _checkpoint_storage.load(paths[0])
 
         # extract user dict from the saved checkpoint
         user_dict = {}
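A hedged usage sketch of the behavior this change enforces; 'trainer' stands for an already-constructed ORTTrainer and the path names are placeholders. Whether aggregation is required is decided inside load_checkpoint from the saved trainer options, not from the number of paths:

# Placeholder paths; both calls assume the saved trainer options do not require aggregation.

# Exactly one file: it is loaded directly via _checkpoint_storage.load(paths[0]).
trainer.load_checkpoint('checkpoint_rank0')

# More than one file while aggregation is not required: the new assert raises
# AssertionError("Expected number of files to load: 1, got 2").
trainer.load_checkpoint('checkpoint_rank0', 'checkpoint_rank1')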
@@ -503,7 +503,7 @@ def test_load_checkpoint(aggregate_checkpoints_mock, load_mock):
     }
     trainer.load_state_dict = Mock()
 
-    load_mock.side_effect = [trainer_options, trainer_options, state_dict]
+    load_mock.side_effect = [trainer_options, state_dict]
     trainer.load_checkpoint('abc')
 
     args_list = load_mock.call_args_list
@@ -512,9 +512,6 @@ def test_load_checkpoint(aggregate_checkpoints_mock, load_mock):
     assert load_kwargs['key'] == 'trainer_options'
     load_args, load_kwargs = args_list[1]
     assert load_args[0] == 'abc'
-    assert load_kwargs['key'] == 'trainer_options'
-    load_args, load_kwargs = args_list[2]
-    assert load_args[0] == 'abc'
     assert 'key' not in load_kwargs
     assert not aggregate_checkpoints_mock.called
 
@@ -578,7 +575,7 @@ def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock):
     }
     trainer.load_state_dict = Mock()
 
-    load_mock.side_effect = [trainer_options, trainer_options, state_dict]
+    load_mock.side_effect = [trainer_options, state_dict]
    user_dict = trainer.load_checkpoint('abc')
 
     assert torch.all(torch.eq(user_dict['array'], torch.tensor(np.arange(5))))
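The test updates mirror the orttrainer.py change: the non-aggregation path now issues only two _checkpoint_storage.load calls (a keyed read of 'trainer_options' followed by a full read of the checkpoint), so the mocked side_effect list drops from three entries to two and the third call-args check disappears. A minimal, self-contained sketch of that mocking pattern, with placeholder dicts standing in for the real fixtures:

from unittest.mock import Mock

# Placeholder fixtures; the real tests build richer trainer_options/state_dict dicts.
trainer_options = {'world_rank': 0, 'world_size': 1, 'mixed_precision': False}
state_dict = {'model': {}, 'optimizer': {}, 'trainer_options': trainer_options}

load_mock = Mock()
# One return value per expected load call, in order.
load_mock.side_effect = [trainer_options, state_dict]

assert load_mock('abc', key='trainer_options') is trainer_options  # first call: keyed read
assert load_mock('abc') is state_dict                              # second call: full read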
