diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 752696ac8faa..db5224b1e787 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -327,6 +327,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
         self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)
         self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
+        self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
         self.loss_broadcast_src_rank = None
         self.inference_params = None
 
@@ -686,6 +687,12 @@ def training_step(self, dataloader_iter, batch_idx):
             self.allreduce_first_last_embeddings()
             self.megatron_timer_stop('allreduce_first_last_embeddings')
 
+        if self.log_memory_usage:
+            mem_reserved = torch.cuda.max_memory_reserved()
+            self.log(
+                'peak_memory_usage', mem_reserved, prog_bar=True, rank_zero_only=True, batch_size=1,
+            )
+
         ## logging
         if self.log_train_loss:
             # When using pipeline parallelism, loss is calculated only in the last pipeline stage and
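
Outside the PR context, the added logging reduces to the short standalone sketch below. It is only an illustration of the pattern (an environment-variable gate plus a query of the CUDA caching allocator's peak reserved bytes); the `print` call stands in for Lightning's `self.log`, which is an assumption made to keep the example self-contained. Enabling it at runtime amounts to setting `NEMO_LOG_MEMORY_USAGE=1` in the environment before launching training.

```python
# Minimal standalone sketch of the pattern the diff adds: gate peak-memory
# logging behind an environment variable and query the CUDA caching allocator.
# Assumption: print() stands in for LightningModule.log() used in the PR.
import os

import torch

log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))

if log_memory_usage and torch.cuda.is_available():
    # Peak bytes reserved by the caching allocator since program start
    # (or since the last torch.cuda.reset_peak_memory_stats() call).
    mem_reserved = torch.cuda.max_memory_reserved()
    print(f"peak_memory_usage: {mem_reserved / 2**20:.1f} MiB")
```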