Checkpoint resuming compatible for 2403 container #9199

Merged · 11 commits · May 17, 2024
68 changes: 68 additions & 0 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -98,6 +98,7 @@
try:
from megatron.core import dist_checkpointing, parallel_state
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace
from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject
from megatron.core.dist_checkpointing.optimizer import (
get_param_id_to_sharded_param_map,
make_sharded_optimizer_tensor,
@@ -415,6 +416,70 @@ def _fix_device(t):

return dict_list_map_outplace(_fix_device, ckpt)

def _get_param_group(self, state_dict: Dict[str, Any]):
"""Return the param groups in the state dict"""
return (
state_dict['optimizer_states'][0]['param_groups']
if 'optimizer' not in state_dict['optimizer_states'][0]
else state_dict['optimizer_states'][0]['optimizer']['param_groups']
)
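# Illustrative sketch (editor's addition, not part of this diff): the two
# optimizer_states layouts this helper distinguishes. The assumption here is that
# distributed-optimizer checkpoints nest their groups under an 'optimizer' key,
# while others keep 'param_groups' at the top level; the toy values are placeholders.
plain_state = {'optimizer_states': [{'param_groups': [{'lr': 1e-4, 'params': [0, 1]}]}]}
nested_state = {'optimizer_states': [{'optimizer': {'param_groups': [{'lr': 1e-4, 'params': [0, 1]}]}}]}
for sd in (plain_state, nested_state):
    opt0 = sd['optimizer_states'][0]
    groups = opt0['param_groups'] if 'optimizer' not in opt0 else opt0['optimizer']['param_groups']
    assert len(groups) == 1  # both layouts yield the same list of param groups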

def _check_param_groups_mismatch(self, checkpoint_path: Union[str, Path], sharded_state_dict: Dict[str, Any]):
"""
Check whether the number of param groups in the checkpoint matches the one in the sharded_state_dict
Returns:
bool: True if the number of param groups does not match
"""
common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path)
model_param_groups = self._get_param_group(common_state_dict)
checkpoint_param_groups = self._get_param_group(sharded_state_dict)
return len(model_param_groups) != len(checkpoint_param_groups)
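# Illustrative sketch (editor's addition, not part of this diff): a 24.03-style
# checkpoint carrying an extra expert param group is flagged as a mismatch against
# a model configured without expert parallelism; the group contents are toy values.
checkpoint_groups = [{'is_expert': False, 'params': [0]}, {'is_expert': True, 'params': [1]}]
model_groups = [{'is_expert': False, 'params': [0]}]
assert len(model_groups) != len(checkpoint_groups)  # mismatch -> the _fix_param_groups path below is taken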

def _fix_param_groups(
self, checkpoint_path: Union[str, Path], sharded_state_dict: Dict[str, Any]
) -> Dict[str, Any]:
"""
Try to fix the param groups in the checkpoint.
This works around a bug in the 24.03 container where every checkpoint stores an expert-parallel (EP) param group, regardless of whether EP is actually used.
This function makes sure such checkpoints remain compatible for loading.
Returns:
loaded_state_dict: dictionary returned by the distributed load function, with any temporary EP placeholder removed
"""
common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path)
model_param_groups = self._get_param_group(sharded_state_dict)
checkpoint_param_groups = self._get_param_group(common_state_dict)

model_has_expert_param = any(param.get('is_expert', False) for param in model_param_groups)
checkpoint_has_expert_param = any(param.get('is_expert', False) for param in checkpoint_param_groups)

expert_index = None
if checkpoint_has_expert_param and not model_has_expert_param:
logging.warning(
'Currently training the model without expert parallelism while restored checkpoint has EP params. Ignoring the EP params for restoring.'
)
expert_index = next(
(index for index, entry in enumerate(checkpoint_param_groups) if entry.get('is_expert', False)),
None,
)
if expert_index is not None:
# Temporary empty params so that loading doesn't fail
model_param_groups.insert(expert_index, {'params': LocalNonpersitentObject([]), 'is_expert': True})
if 'optimizer' in sharded_state_dict['optimizer_states'][0]:
sharded_state_dict['optimizer_states'][0]['optimizer']['param_groups'] = model_param_groups
else:
sharded_state_dict['optimizer_states'][0]['param_groups'] = model_param_groups
else:
raise ValueError('Cannot find expert param in the checkpoint.')

loaded_state_dict = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict)
if expert_index is not None:
# Remove the temporary empty params added above
if 'optimizer' in loaded_state_dict['optimizer_states'][0]:
loaded_state_dict['optimizer_states'][0]['optimizer']['param_groups'].pop(expert_index)
else:
loaded_state_dict['optimizer_states'][0]['param_groups'].pop(expert_index)
return loaded_state_dict
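# Illustrative sketch (editor's addition, not part of this diff): the insert-then-pop
# placeholder pattern used above, with a plain empty list standing in for the
# LocalNonpersitentObject([]) wrapper so the example stays self-contained and runnable.
model_groups = [{'is_expert': False, 'params': ['dense']}]
checkpoint_groups = [{'is_expert': False, 'params': ['dense']}, {'is_expert': True, 'params': ['expert']}]
expert_index = next(i for i, g in enumerate(checkpoint_groups) if g.get('is_expert', False))
model_groups.insert(expert_index, {'params': [], 'is_expert': True})  # pad so group counts line up for loading
# ... the distributed load would run here against the padded param groups ...
model_groups.pop(expert_index)  # drop the placeholder once loading is done
assert all(not g.get('is_expert', False) for g in model_groups)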

def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
"""PTL method which we override to integrate distributed checkpoints for model parallel models.
In order to load distributed checkpoints we need to provide the sharded_state_dict to
@@ -437,6 +502,9 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
# after dist_checkpointing.load, sharded tensors will be replaced with tensors
checkpoint['state_dict'] = sharded_state_dict
checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()]

if self._check_param_groups_mismatch(checkpoint_path, checkpoint):
return self._fix_param_groups(checkpoint_path, checkpoint)
return self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=checkpoint)
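# Illustrative sketch (editor's addition, not part of this diff): the dispatch the
# override adds when resuming, with hypothetical stand-in callables for the methods
# above so the control flow is visible in isolation; none of these names are NeMo APIs.
def _toy_resume(path, checkpoint, groups_mismatch, fix_param_groups, plain_load):
    # Route checkpoints whose param-group counts differ through the compatibility fix;
    # otherwise load the sharded state dict directly, as before.
    if groups_mismatch(path, checkpoint):
        return fix_param_groups(path, checkpoint)
    return plain_load(path, sharded_state_dict=checkpoint)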

# Legacy model parallel checkpointing logic, does not use megatron core