add first_val_step to mcore scheduler (#8150)
* add first_val_step for mcore schedules

Signed-off-by: jiemingz <[email protected]>

* fix if non fp8

Signed-off-by: jiemingz <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add mcore installation to ci

Signed-off-by: jiemingz <[email protected]>

* Update Jenkinsfile

Signed-off-by: JimmyZhang12 <[email protected]>

* Fix SFT missing arg

Signed-off-by: jiemingz <[email protected]>

---------

Signed-off-by: jiemingz <[email protected]>
Signed-off-by: JimmyZhang12 <[email protected]>
Co-authored-by: jiemingz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
4 people authored Jan 25, 2024
1 parent 6143f6b commit f25be00
Showing 3 changed files with 25 additions and 12 deletions.
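
For readers skimming the diff: the change threads a new `first_val_step` argument from NeMo's training and validation steps into the Megatron core pipeline schedules, and only computes it when `fp8` is enabled, presumably so the FP8 amax/scale bookkeeping can treat the first validation microbatch after a training step specially. Below is a hedged sketch of the call pattern the diff introduces, not the actual NeMo code; the helper name `fwd_bwd_step_sketch`, the plain `cfg` dictionary, and the explicit `num_microbatches` parameter are assumptions for this illustration, and it presumes a Megatron-LM version whose forward-backward functions accept `first_val_step` (such as the commit pinned in the Jenkinsfile below).

# Hedged sketch of how the new flag reaches the mcore scheduler (illustrative only).
from megatron.core.pipeline_parallel import get_forward_backward_func


def fwd_bwd_step_sketch(
    forward_step_func, data_iterator, model, cfg, num_microbatches, forward_only, first_val_step=None
):
    # Simplified stand-in for the fwd_bwd_step methods changed in this commit.
    fwd_bwd_function = get_forward_backward_func()
    losses_reduced_per_micro_batch = fwd_bwd_function(
        forward_step_func=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=num_microbatches,
        forward_only=forward_only,
        seq_length=cfg['encoder_seq_length'],
        micro_batch_size=cfg['micro_batch_size'],
        # New in this commit: None during training; True only for the first
        # validation microbatch after a training step when fp8 is enabled.
        first_val_step=first_val_step,
    )
    return losses_reduced_per_micro_batch
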
18 changes: 9 additions & 9 deletions Jenkinsfile
@@ -81,14 +81,14 @@ pipeline {
 
     // pip package should be working with main, if not we can update the commit here
     // until the pip package is updated
-    // stage('Megatron Core installation') {
-    //   steps {
-    //     sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
-    //       cd Megatron-LM && \
-    //       git checkout 973330e9c3681604703bf1eb6b5a265d1b9b9b38 && \
-    //       pip install .'
-    //   }
-    // }
+    stage('Megatron Core installation') {
+      steps {
+        sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
+            cd Megatron-LM && \
+            git checkout bed60a881f4b238b1c14b6c6a64997cc636e77b6 && \
+            pip install .'
+      }
+    }
 
     stage('PyTorch Lightning version') {
       steps {
@@ -5268,4 +5268,4 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
cleanWs()
}
}
}
}
16 changes: 14 additions & 2 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -227,6 +227,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
 
         self.mcore_gpt = cfg.get('mcore_gpt', False)
         self.spec_name = cfg.get('name', '')
+        if cfg.get('fp8', False):
+            self.prev_step_training = True
 
         self.rampup_batch_size = self.cfg.get('rampup_batch_size', None)
         if self.rampup_batch_size:
@@ -498,7 +500,7 @@ def forward(self, tokens, text_position_ids, attention_mask, labels):
         output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels)
         return output_tensor
 
-    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
+    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only, first_val_step=None):
 
         # handle asynchronous grad reduction
         no_sync_func = None
@@ -528,6 +530,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             forward_only=forward_only,
             seq_length=self.cfg.encoder_seq_length,
             micro_batch_size=self.cfg.micro_batch_size,
+            first_val_step=first_val_step,
         )
 
         # only the last stages of the pipeline return losses
@@ -622,6 +625,9 @@ def training_step(self, dataloader_iter, batch_idx):
 
         loss_mean = self.fwd_bwd_step(dataloader_iter, batch_idx, False)
 
+        if self.cfg.get('fp8', False):
+            self.prev_step_training = self.training
+
         # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
         if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False):
             self.allreduce_sequence_parallel_gradients()
@@ -1047,7 +1053,13 @@ def validation_step(self, dataloader_iter, batch_idx):
             for model_module in self.model:
                 model_module.eval()
 
-        loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True)
+        if self.cfg.get('fp8', False):
+            first_val_step = self.prev_step_training and not self.training
+            self.prev_step_training = self.training
+        else:
+            first_val_step = None
+
+        loss = self.fwd_bwd_step(dataloader_iter, batch_idx, True, first_val_step)
 
         if isinstance(self.model, list):
             for model_module in self.model:
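
Distilled from the diff above: with fp8 enabled the model keeps a single extra piece of state, `prev_step_training`, so that only the first validation microbatch right after training sees `first_val_step=True`. A minimal standalone sketch of that bookkeeping (plain Python, not NeMo code; the class name `Fp8FirstValStepTracker` and the `module_training` argument are invented here and stand in for the model and its `self.training` flag):

class Fp8FirstValStepTracker:
    # Illustration of the prev_step_training bookkeeping added in the diff above.
    def __init__(self):
        # Mirrors `self.prev_step_training = True` set in __init__ when fp8 is on.
        self.prev_step_training = True

    def training_step(self):
        # In the real model this is `self.prev_step_training = self.training`,
        # and self.training is True while training.
        self.prev_step_training = True

    def validation_step(self, module_training=False):
        # First validation batch after a training step -> True; afterwards -> False.
        first_val_step = self.prev_step_training and not module_training
        self.prev_step_training = module_training
        return first_val_step


tracker = Fp8FirstValStepTracker()
tracker.training_step()
print(tracker.validation_step())  # True: first validation microbatch after training
print(tracker.validation_step())  # False: subsequent validation batches

In the real model `module_training` corresponds to `self.training`, the usual torch.nn.Module flag, which is normally False inside validation_step.
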
3 changes: 2 additions & 1 deletion nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -320,7 +320,7 @@ def _determine_log_key(self, data_config, dataloader_idx, metric_name, mode):
         else:
             return base_key + f"dataloader{dataloader_idx}"
 
-    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
+    def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only, first_val_step=None):
         batch = next(dataloader_iter)
 
         log_token_counts = self.cfg.get('log_token_counts', False)
@@ -360,6 +360,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             forward_only=forward_only,
             seq_length=seq_length,
             micro_batch_size=get_micro_batch_size(),
+            first_val_step=first_val_step,
         )
 
         # only the last stages of the pipeline return losses
