From 07f9779ff47d0af960cfd54c8e9819f31a740ab4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:28:49 -0700 Subject: [PATCH] Enable megatron core loggers for GPT pretraining (#8354) (#8384) * Logging changes tested for gpt_pretraining * Additional args * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Aishwarya Bhandare Co-authored-by: ashbhandare Co-authored-by: Aishwarya Bhandare Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Harper Signed-off-by: ataghibakhsh --- .../conf/megatron_gpt_config.yaml | 7 ++++ .../conf/megatron_model_base_config.yaml | 3 +- .../language_modeling/megatron_base_model.py | 40 ++++++++++++++++++- .../language_modeling/megatron_gpt_model.py | 6 +++ 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index 63d2297838c3..004e8b584a13 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -211,6 +211,13 @@ model: ## Network sharp: False # Enable the use of SHARP for NCCL data-parallel communications. This is going to be ignored if the network doesn't support SHARP. + ## Megatron timers + enable_megatron_timers: False + megatron_timer_kwargs: + log_every_n_steps: 10 + log_mode: minmax + barrier: False + data: # Path to data must be specified by the user. # Supports List, String and Dictionary diff --git a/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml b/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml index 4da8177685a1..235bf3d3f227 100644 --- a/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml @@ -37,4 +37,5 @@ normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sq num_moe_experts: 1 # When >1, FFNs are changed to MoE layers moe_frequency: 1 # every Nth ffn layer will be made MoE moe_dropout: 0.0 # Dropout value for MoE layers -use_flash_attention: false # Use flash attention in self-attention module \ No newline at end of file +use_flash_attention: false # Use flash attention in self-attention module +enable_megatron_timers: false # Megatron timers \ No newline at end of file diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 269279d8e856..5321a307b2c4 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -69,6 +69,13 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core import Timers + + HAVE_MEGATRON_CORE_TIMERS = True +except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE_TIMERS = False + __all__ = ["MegatronBaseModel"] @@ -124,6 +131,17 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): else torch.float32 ) + self.megatron_timers = None + if self.cfg.get('enable_megatron_timers', False) and HAVE_MEGATRON_CORE_TIMERS: + self.megatron_timers_cfg = dict(self.cfg.get('megatron_timer_kwargs', dict())) + if 'log_every_n_steps' not in self.megatron_timers_cfg: + self.megatron_timers_cfg['log_every_n_steps'] = self.trainer.log_every_n_steps + if 'log_option' not in self.megatron_timers_cfg: + self.megatron_timers_cfg['log_option'] = 'minmax' # minmax, max, all + if 'barrier' not in self.megatron_timers_cfg: + self.megatron_timers_cfg['barrier'] = False + self.megatron_timers = Timers(log_level=2, log_option=self.megatron_timers_cfg['log_option']) + # set the megatron core model parallel config self.model_parallel_config: ModelParallelConfig = self.build_model_parallel_config() @@ -615,6 +633,13 @@ def sync_overlap_parameters(self, params=None): def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unused: Optional[int] = 0) -> None: super().on_train_batch_end(outputs, dataloader_iter, batch_idx) + # Megatron Timers + if self.megatron_timers: + if self.global_step % self.megatron_timers_cfg["log_every_n_steps"] == 0: + logging.info( + "\n " + self.megatron_timers.get_all_timers_string(barrier=self.megatron_timers_cfg["barrier"]) + ) + # TODO: Replace with newer override for scheduler.step() instead of # search for plugins for fp16 GradScalar if self.trainer.precision_plugin is not None and isinstance( @@ -1044,7 +1069,7 @@ def build_model_parallel_config(self) -> ModelParallelConfig: and megatron_amp_O2, # NeMo does not currently support fp16 training with megatron amp O2, eval and inference is supported "bf16": self.torch_dtype == torch.bfloat16 and megatron_amp_O2, "params_dtype": self.params_dtype, - "timers": None, # NeMo does not currently support megatron core timers + "timers": self.megatron_timers, "async_tensor_model_parallel_allreduce": self.cfg.get('tensor_model_parallel_world_size', 1) > 1 and not self.cfg.get('sequence_parallel', False), "pipeline_dtype": pipeline_dtype, @@ -1157,3 +1182,16 @@ def configure_sharded_model(self): # Move the CPU-initialized model (with `use_cpu_initialization=True`) to GPU, which is to avoid # out-of-memory carash before sharding. In case of GPU-initialized model, this is no-op. self.model = self.model.cuda(torch.cuda.current_device()) + + def megatron_timer_start(self, name, log_level): + if self.megatron_timers: + self.megatron_timers(name, log_level).start(barrier=False) + + def megatron_timer_stop(self, name): + if self.megatron_timers: + self.megatron_timers(name).stop() + + def optimizer_step(self, *args, **kwargs): + self.megatron_timer_start('optimizer', log_level=1) + super().optimizer_step(*args, **kwargs) + self.megatron_timer_stop('optimizer') diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 9c3657d4c4ef..2770090a7c1e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -654,8 +654,11 @@ def training_step(self, dataloader_iter, batch_idx): # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False): + self.megatron_timer_start('allreduce_sequence_parallel_gradients', log_level=1) self.allreduce_sequence_parallel_gradients() + self.megatron_timer_stop('allreduce_sequence_parallel_gradients') + self.megatron_timer_start('gradient_allreduce', log_level=1) if self.use_fsdp: # Reduce the gradients omitted from FSDP-sharding self.allreduce_fsdp_sharding_omitted_gradients() @@ -673,12 +676,15 @@ def training_step(self, dataloader_iter, batch_idx): # async grad allreduce is not currently implemented for O1/autocasting mixed precision training # so we all-reduce gradients after the pipeline self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf) + self.megatron_timer_stop('gradient_allreduce') if self.cfg.get('pipeline_model_parallel_size', 1) > 1 and self.cfg.get( 'share_embeddings_and_output_weights', True ): + self.megatron_timer_start('allreduce_first_last_embeddings', log_level=1) # when using pipeline parallelism the first and last stage must keep embeddings in sync self.allreduce_first_last_embeddings() + self.megatron_timer_stop('allreduce_first_last_embeddings') ## logging if self.log_train_loss: