Handle float limit_val_batches (#8426)
* Handle float limit_val_batches

Signed-off-by: Abhishree <[email protected]>

* Rectify reconfiguration of float limit_val_batches

Signed-off-by: Abhishree <[email protected]>

* Remove unused imports

Signed-off-by: Abhishree <[email protected]>

* Scale len(val_dataloader) with float limit_val_batches

Signed-off-by: Abhishree <[email protected]>

* Return len(dataloader) in microbatches

Signed-off-by: Abhishree <[email protected]>

* Add back resetting of num val samples

Signed-off-by: Abhishree <[email protected]>

* Fix to ensure float limit_val_batches is multiple of num_micro_batches

Signed-off-by: Abhishree <[email protected]>

* Remove forcing eval samples to 1 for float limit_val_batches

Signed-off-by: Abhishree <[email protected]>

* Fix bug wrt 0 limit_val_batches

Signed-off-by: Abhishree <[email protected]>

* Add missing mock_dataset line

Signed-off-by: Abhishree <[email protected]>

* Avoid ensuring limit_val_batches is a multiple of microbatches for 1.0

Signed-off-by: Abhishree <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Restore the hack forcing number of validation and test epochs to 1

Signed-off-by: Jan Baczek <[email protected]>

* Change limit_val_batches to 1.0 for GPT pretraining test. The integer value is covered in other tests

Signed-off-by: Jan Baczek <[email protected]>

---------

Signed-off-by: Abhishree <[email protected]>
Signed-off-by: Jan Baczek <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jan Baczek <[email protected]>
Signed-off-by: ataghibakhsh <[email protected]>
3 people authored and JRD971000 committed Mar 15, 2024
1 parent 1ffab5a commit 698239d
Showing 4 changed files with 51 additions and 22 deletions.
4 changes: 2 additions & 2 deletions Jenkinsfile
@@ -3584,7 +3584,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
 trainer.accelerator=gpu \
 trainer.log_every_n_steps=1 \
 trainer.val_check_interval=2 \
-trainer.limit_val_batches=2 \
+trainer.limit_val_batches=1.0 \
 trainer.accumulate_grad_batches=1 \
 trainer.max_steps=3 \
 trainer.precision=16 \
@@ -3619,7 +3619,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
 trainer.accelerator=gpu \
 trainer.log_every_n_steps=1 \
 trainer.val_check_interval=2 \
-trainer.limit_val_batches=2 \
+trainer.limit_val_batches=1.0 \
 trainer.accumulate_grad_batches=1 \
 trainer.max_steps=6 \
 trainer.precision=16 \
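For reference, PyTorch Lightning treats an integer limit_val_batches as an absolute number of validation batches and a float as a fraction of the validation dataloader, which is the behavior the CI change above switches to. A minimal illustration of the two forms using plain Lightning (not the NeMo test command itself):

from pytorch_lightning import Trainer

# Integer: run exactly two validation batches per validation pass.
trainer_int = Trainer(limit_val_batches=2)

# Float: run this fraction of the validation dataloader; 1.0 means the whole dataloader.
trainer_frac = Trainer(limit_val_batches=1.0)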
Second changed file (path not captured in this copy): batch-sampler __len__ methods.
@@ -81,9 +81,12 @@ def __len__(self):
         num_available_samples: int = self.total_samples - self.consumed_samples
         if self.global_batch_size is not None:
             if self.drop_last:
-                return num_available_samples // self.global_batch_size
+                num_global_batches = num_available_samples // self.global_batch_size
             else:
-                return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
+            # num of batches fetched (as training step fetches in terms of micro batches)
+            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
         else:
             return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1

@@ -162,9 +165,12 @@ def __len__(self):
         num_available_samples = active_total_samples - self.consumed_samples % active_total_samples
         if self.global_batch_size is not None:
             if self.drop_last:
-                return num_available_samples // self.global_batch_size
+                num_global_batches = num_available_samples // self.global_batch_size
             else:
-                return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+                num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
+            # return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
+            # num of batches fetched (as training step fetches in terms of micro batches)
+            return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
         else:
             if self.drop_last:
                 return num_available_samples // self.micro_batch_times_data_parallel_size
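The two hunks above make __len__ report length in micro batches rather than global batches, so it matches what the training step actually fetches. A standalone sketch of that arithmetic, with made-up numbers and plain function arguments standing in for the sampler's attributes (illustration only, not the NeMo class):

def len_in_micro_batches(total_samples, consumed_samples, global_batch_size,
                         micro_batch_times_data_parallel_size, drop_last=True):
    # Mirror of the __len__ arithmetic in the diff above, written as a free function.
    num_available_samples = total_samples - consumed_samples
    if drop_last:
        num_global_batches = num_available_samples // global_batch_size
    else:
        num_global_batches = (num_available_samples + global_batch_size - 1) // global_batch_size
    # Each global batch is consumed as global_batch_size // micro_batch_times_data_parallel_size
    # micro batches per data-parallel rank, so report the length in those units.
    return num_global_batches * (global_batch_size // micro_batch_times_data_parallel_size)

# Example: 1000 samples, global batch 64, micro batch 8 on 2 data-parallel ranks (8 * 2 = 16):
# 15 global batches * 4 micro batches each = 60 micro batches.
print(len_in_micro_batches(1000, 0, 64, 16))  # 60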
Third changed file (path not captured in this copy): model code containing _reconfigure_val_batches.
@@ -27,6 +27,7 @@
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
 from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 from nemo.collections.nlp.models.nlp_model import NLPModel
 from nemo.collections.nlp.modules.common.megatron.attention import HAVE_FLASH_ATTENTION
@@ -322,9 +323,37 @@ def _reconfigure_val_batches(self):
         """
         Reconfigure trainer.limit_val_batches for pretraining
         """
-        # Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
         if isinstance(self.trainer.limit_val_batches, int):
+            # Override limit_val_batches to be a multiple of num microbatches and so there are limit_val_batches//num_micro_batches num of global batches
             self.trainer.limit_val_batches *= get_num_microbatches()
+        else:
+            assert isinstance(self.trainer.limit_val_batches, float)
+            # Don't reconfigure if limit_val_batches is 0.0
+            if self.trainer.limit_val_batches == 0.0:
+                return
+            # len(self._validation_dl) returns len as num of microbatches
+            val_len_in_micro_batches = len(self._validation_dl)
+            if self._validation_ds is not None and len(self._validation_dl) != float("inf"):
+                if self.trainer.limit_val_batches == 1.0:
+                    self.trainer.limit_val_batches = val_len_in_micro_batches
+                else:
+                    limit_val_micro_batches = int(val_len_in_micro_batches * self.trainer.limit_val_batches)
+                    if limit_val_micro_batches == 0 and self.trainer.limit_val_batches > 0.0:
+                        min_percentage = 1.0 / len(self._validation_dl)
+                        raise MisconfigurationException(
+                            f"You requested to check {self.trainer.limit_val_batches} of the val_dataloader but"
+                            f" {self.trainer.limit_val_batches} * {len(self._validation_dl)} < 1. Please increase the"
+                            f" `limit_val_batches` argument. Try at least"
+                            f" `limit_val_batches={min_percentage}`"
+                        )
+                    # Make sure trainer.limit_val_batches is a multiple of num of microbatches
+                    if limit_val_micro_batches < get_num_microbatches():
+                        self.trainer.limit_val_batches = get_num_microbatches()
+                    else:
+                        self.trainer.limit_val_batches = (
+                            limit_val_micro_batches - limit_val_micro_batches % get_num_microbatches()
+                        )

         # Override num sanity steps to be a multiple of num of microbatches
         self.trainer.num_sanity_val_steps *= get_num_microbatches()
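A self-contained sketch of the rounding rule that the new else branch implements, with plain integers standing in for trainer.limit_val_batches, len(self._validation_dl), and get_num_microbatches() (illustration only, not the NeMo method):

def resolve_float_limit_val_batches(limit_val_batches, val_len_in_micro_batches, num_micro_batches):
    # 0.0 disables validation, so there is nothing to reconfigure.
    if limit_val_batches == 0.0:
        return 0
    # 1.0 means the whole validation dataloader, already expressed in micro batches.
    if limit_val_batches == 1.0:
        return val_len_in_micro_batches
    limit_val_micro_batches = int(val_len_in_micro_batches * limit_val_batches)
    if limit_val_micro_batches == 0:
        raise ValueError(f"Fraction too small for this dataloader; try at least {1.0 / val_len_in_micro_batches}")
    # Never drop below one global batch worth of micro batches ...
    if limit_val_micro_batches < num_micro_batches:
        return num_micro_batches
    # ... otherwise round down to a multiple of the micro batches per global batch,
    # so validation never stops partway through a global batch.
    return limit_val_micro_batches - limit_val_micro_batches % num_micro_batches

# Example: a 100-micro-batch validation dataloader, limit_val_batches=0.25, 8 micro batches per global batch:
print(resolve_float_limit_val_batches(0.25, 100, 8))  # 24 (25 rounded down to a multiple of 8)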

Fourth changed file (path not captured in this copy): GPT model dataset-building and setup code.
@@ -1176,32 +1176,24 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor):
         return loss

     def build_train_valid_test_datasets(self):
-        # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
-        self._reconfigure_val_batches()
-        logging.info('Building GPT datasets.')
+        if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float):
+            raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.")
+        logging.info('Building GPT datasets.')
         global_batch_size = self.cfg.global_batch_size
         max_train_steps = self.trainer.max_steps
         eval_iters = (max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches
         test_iters = self.trainer.limit_test_batches

         # Add extra FIM tokens to tokenizer
         if self.cfg.data.get('add_fim', False) and self.cfg.tokenizer.library == 'megatron':
             fim_tokens = self.cfg.data.fim.extra_tokens
             fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod]
             self.tokenizer.add_special_tokens({'additional_special_tokens': fim_tokens})

-        train_valid_test_num_samples = [
-            max_train_steps * global_batch_size,
-            eval_iters * global_batch_size,
-            test_iters * global_batch_size,
-        ]
-
-        if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float):
-            train_valid_test_num_samples[
-                1
-            ] = 1  # This is to make sure we only have one epoch on every validation iteration
+        # The line below exploits a quirk in mcore dataset construction, to make number of epochs for validation and test equal to 1
+        # The mcore dataset implementation uses the number N we provide via train_valid_test_num_samples to derive parameter E such that
+        # E = argmin_e e * N_d >= N, or equivalently E = ceildiv(N, N_d)
+        # Where N_d is the total number of samples in a dataset (files), and N is the requested number of samples (provided for every split in the list below).
+        # Setting N = 1 we force E to be 1 as well
+        train_valid_test_num_samples = [max_train_steps * global_batch_size, 1, 1]

         mock_dataset = self.cfg.data.get("mock_dataset", False)
         kwargs = {
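A toy sketch of the ceiling-division quirk described in the comment block above (hypothetical sample counts, not the mcore implementation):

def mcore_epochs(requested_samples, samples_in_dataset):
    # E = argmin_e such that e * N_d >= N, i.e. E = ceildiv(N, N_d).
    return -(-requested_samples // samples_in_dataset)

print(mcore_epochs(1, 50_000))        # 1 -> requesting N = 1 forces a single validation/test epoch
print(mcore_epochs(120_000, 50_000))  # 3 -> a larger request would span multiple epochs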
@@ -1329,6 +1321,8 @@ def setup(self, stage=None):
             self.setup_training_data(self.cfg.data)
             self.setup_validation_data(self.cfg.data)
             self.setup_test_data(self.cfg.data)
+            # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step
+            self._reconfigure_val_batches()

         if stage == 'fit':
             self.initialize_last_rank_embeddings()
