From 1cfecc97a3417e06c339ef3f4fc08c09c318d465 Mon Sep 17 00:00:00 2001 From: Nithin Rao Date: Sun, 10 Nov 2024 15:01:56 -0500 Subject: [PATCH 01/24] Timestamps to transcribe (#10950) * inital version Signed-off-by: Nithin Rao Koluguri * Support for RNNT, TDT, Hybrid Models Signed-off-by: Nithin Rao Koluguri * move change of decoder stratery from mixin to individual model class Signed-off-by: Nithin Rao Koluguri * Apply isort and black reformatting Signed-off-by: nithinraok * update transcribe_speech.py Signed-off-by: Nithin Rao Koluguri * uncomment Signed-off-by: Nithin Rao Koluguri * Apply isort and black reformatting Signed-off-by: nithinraok * add docs Signed-off-by: Nithin Rao Koluguri * fix docs Signed-off-by: Nithin Rao Koluguri * Apply isort and black reformatting Signed-off-by: nithinraok * codeql fixes Signed-off-by: Nithin Rao Koluguri * unit tests Signed-off-by: Nithin Rao Koluguri * minor rebase fix Signed-off-by: Nithin Rao Koluguri * Apply isort and black reformatting Signed-off-by: nithinraok * add None case to restore the state set outside using decoding_stratergy() Signed-off-by: Nithin Rao Koluguri * Apply isort and black reformatting Signed-off-by: nithinraok * remove ipdb traces Signed-off-by: Nithin Rao Koluguri * updates doc for transcription.py Signed-off-by: Nithin Rao Koluguri * remove preserve alignment for AED models as it doesn;t support it Signed-off-by: Nithin Rao Koluguri * lint warnings Signed-off-by: Nithin Rao Koluguri * Apply isort and black reformatting Signed-off-by: nithinraok --------- Signed-off-by: Nithin Rao Koluguri Signed-off-by: nithinraok Co-authored-by: Nithin Rao Koluguri Co-authored-by: nithinraok --- docs/source/asr/intro.rst | 35 ++++- .../aed/speech_to_text_aed_chunked_infer.py | 22 ++- .../ctc/speech_to_text_buffered_infer_ctc.py | 13 +- .../speech_to_text_buffered_infer_rnnt.py | 12 +- .../speech_translation/translate_speech.py | 12 +- examples/asr/transcribe_speech.py | 55 ++++---- .../asr/models/aed_multitask_models.py | 52 +++++-- nemo/collections/asr/models/ctc_bpe_models.py | 21 +-- nemo/collections/asr/models/ctc_models.py | 62 ++++++-- .../asr/models/hybrid_rnnt_ctc_bpe_models.py | 26 ++-- .../asr/models/hybrid_rnnt_ctc_models.py | 66 +++++++-- .../collections/asr/models/rnnt_bpe_models.py | 21 +-- nemo/collections/asr/models/rnnt_models.py | 59 ++++++-- nemo/collections/asr/modules/conv_asr.py | 19 ++- .../asr/parts/mixins/transcription.py | 12 +- .../asr/parts/submodules/rnnt_decoding.py | 132 ++++++++++-------- .../asr/parts/utils/streaming_utils.py | 38 ++--- .../asr/parts/utils/transcribe_utils.py | 95 +++++++++++-- tests/collections/asr/conftest.py | 17 +++ .../asr/mixins/test_transcription.py | 107 +++++++++----- 20 files changed, 623 insertions(+), 253 deletions(-) diff --git a/docs/source/asr/intro.rst b/docs/source/asr/intro.rst index aae372765a8a..ade767e541a0 100644 --- a/docs/source/asr/intro.rst +++ b/docs/source/asr/intro.rst @@ -16,10 +16,39 @@ After :ref:`installing NeMo`, you can transcribe an audio file as asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_fastconformer_transducer_large") transcript = asr_model.transcribe(["path/to/audio_file.wav"]) -Obtain word/segment timestamps -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Obtain timestamps +^^^^^^^^^^^^^^^^^ -You can also obtain timestamps for each word or segment in the transcription as follows: +Obtaining char(token), word or segment timestamps is also possible with NeMo ASR Models. + +Currently, timestamps are available for Parakeet Models with all types of decoders (CTC/RNNT/TDT). Support for AED models would be added soon. + +There are two ways to obtain timestamps: +1. By using the `timestamps=True` flag in the `transcribe` method. +2. For more control over the timestamps, you can update the decoding config to mention type of timestamps (char, word, segment) and also specify the segment seperators or word seperator for segment and word level timestamps. + +With the `timestamps=True` flag, you can obtain timestamps for each character in the transcription as follows: + +.. code-block:: python + + # import nemo_asr and instantiate asr_model as above + import nemo.collections.asr as nemo_asr + asr_model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt_ctc-110m") + + # specify flag `timestamps=True` + hypotheses = asr_model.transcribe(["path/to/audio_file.wav"], timestamps=True) + + # by default, timestamps are enabled for char, word and segment level + word_timestamps = hypotheses[0][0].timestep['word'] # word level timestamps for first sample + segment_timestamps = hypotheses[0][0].timestep['segment'] # segment level timestamps + char_timestamps = hypotheses[0][0].timestep['char'] # char level timestamps + + for stamp in segment_timestamps: + print(f"{stamp['start']}s - {stamp['end']}s : {stamp['segment']}") + + # segment level timestamps (if model supports Punctuation and Capitalization, segment level timestamps are displayed based on punctuation otherwise complete transcription is considered as a single segment) + +For more control over the timestamps, you can update the decoding config to mention type of timestamps (char, word, segment) and also specify the segment seperators or word seperator for segment and word level timestamps as follows: .. code-block:: python diff --git a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py index 0417522885b9..8188bcced14d 100644 --- a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py +++ b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py @@ -13,11 +13,13 @@ # limitations under the License. """ -This script chunks long audios into non-overlapping segments of `chunk_len_in_secs` seconds and performs inference on each +This script chunks long audios into non-overlapping segments of `chunk_len_in_secs` +seconds and performs inference on each segment individually. The results are then concatenated to form the final output. Below is an example of how to run this script with the Canary-1b model. -It's recommended to use manifest input, otherwise the model will perform English ASR with punctuations and capitalizations. +It's recommended to use manifest input, otherwise the model will perform English ASR +with punctuations and capitalizations. An example manifest line: { "audio_filepath": "/path/to/audio.wav", # path to the audio file @@ -41,11 +43,10 @@ """ -import contextlib import copy import glob import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass from typing import Optional import pytorch_lightning as pl @@ -67,6 +68,10 @@ @dataclass class TranscriptionConfig: + """ + Transcription config + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -116,6 +121,10 @@ class TranscriptionConfig: @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> TranscriptionConfig: + """ + Transcribes the input audio and can be used to infer long audio files by chunking + them into smaller segments. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') torch.set_grad_enabled(False) @@ -160,7 +169,8 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: if model_cfg.preprocessor.normalize != "per_feature": logging.error( - "Only EncDecMultiTaskModel models trained with per_feature normalization are supported currently" + "Only EncDecMultiTaskModel models trained with per_feature normalization are supported \ + currently" ) # Disable config overwriting @@ -206,7 +216,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) output_filename, pred_text_attr_name = write_transcription( - hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False + hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False ) logging.info(f"Finished writing predictions to {output_filename}!") diff --git a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py index 77b97e0ab516..87370d278f98 100644 --- a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py +++ b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py @@ -35,12 +35,11 @@ You can use `DEBUG=1 python speech_to_text_buffered_infer_ctc.py ...` to print out the predictions of the model, and ground-truth text if presents in manifest. """ -import contextlib import copy import glob import math import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass from typing import Optional import pytorch_lightning as pl @@ -65,6 +64,10 @@ @dataclass class TranscriptionConfig: + """ + Transcription Configuration for buffered inference. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -114,6 +117,10 @@ class TranscriptionConfig: @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> TranscriptionConfig: + """ + Transcribes the input audio and can be used to infer long audio files by chunking + them into smaller segments. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') torch.set_grad_enabled(False) @@ -221,7 +228,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: filepaths, ) output_filename, pred_text_attr_name = write_transcription( - hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False + hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False ) logging.info(f"Finished writing predictions to {output_filename}!") diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py index 501ca525c1ed..e6e84cdfa6c4 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py @@ -61,7 +61,7 @@ import glob import math import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass from typing import Optional import pytorch_lightning as pl @@ -87,6 +87,10 @@ @dataclass class TranscriptionConfig: + """ + Transcription Configuration for buffered inference. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -143,6 +147,10 @@ class TranscriptionConfig: @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> TranscriptionConfig: + """ + Transcribes the input audio and can be used to infer long audio files by chunking + them into smaller segments. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') torch.set_grad_enabled(False) @@ -274,7 +282,7 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: ) output_filename, pred_text_attr_name = write_transcription( - hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, compute_timestamps=False + hyps, cfg, model_name, filepaths=filepaths, compute_langs=False, timestamps=False ) logging.info(f"Finished writing predictions to {output_filename}!") diff --git a/examples/asr/speech_translation/translate_speech.py b/examples/asr/speech_translation/translate_speech.py index 47717f562774..53599e1b3511 100644 --- a/examples/asr/speech_translation/translate_speech.py +++ b/examples/asr/speech_translation/translate_speech.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import json import os from dataclasses import dataclass, is_dataclass @@ -65,13 +64,19 @@ @dataclass class ModelChangeConfig: + """ + Sub-config for changes specific to the Conformer Encoder + """ - # Sub-config for changes specific to the Conformer Encoder conformer: ConformerChangeConfig = ConformerChangeConfig() @dataclass class TranslationConfig: + """ + Translation Configuration for audio to text translation. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -106,6 +111,9 @@ class TranslationConfig: @hydra_runner(config_name="TranslationConfig", schema=TranslationConfig) def main(cfg: TranslationConfig) -> Union[TranslationConfig, List[str]]: + """ + Main function to translate audio to text using a pretrained/finetuned model. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') for key in cfg: diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index e529c988779a..a543fcf5e252 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import json import os import time @@ -48,14 +47,9 @@ model_path: path to .nemo ASR checkpoint pretrained_name: name of pretrained ASR model (from NGC registry) audio_dir: path to directory with audio files - dataset_manifest: path to dataset JSON manifest file (in NeMo format) - - compute_timestamps: Bool to request greedy time stamp information (if the model supports it) + dataset_manifest: path to dataset JSON manifest file (in NeMo formats compute_langs: Bool to request language ID information (if the model supports it) - - (Optionally: You can limit the type of timestamp computations using below overrides) - ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word]) - rnnt_decoding.rnnt_timestamp_type="all" # (default all, can be [all, char, word]) + timestamps: Bool to request greedy time stamp information (if the model supports it) by default None (Optionally: You can limit the type of timestamp computations using below overrides) ctc_decoding.ctc_timestamp_type="all" # (default all, can be [all, char, word, segment]) @@ -98,7 +92,7 @@ clean_groundtruth_text=True \ langid='en' \ batch_size=32 \ - compute_timestamps=False \ + timestamps=False \ compute_langs=False \ cuda=0 \ amp=True \ @@ -109,13 +103,19 @@ @dataclass class ModelChangeConfig: + """ + Sub-config for changes specific to the Conformer Encoder + """ - # Sub-config for changes specific to the Conformer Encoder conformer: ConformerChangeConfig = field(default_factory=ConformerChangeConfig) @dataclass class TranscriptionConfig: + """ + Transcription Configuration for audio to text transcription. + """ + # Required configs model_path: Optional[str] = None # Path to a .nemo file pretrained_name: Optional[str] = None # Name of a pretrained model @@ -136,10 +136,11 @@ class TranscriptionConfig: pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one. random_seed: Optional[int] = None # seed number going to be used in seed_everything() - # Set to True to output greedy timestamp information (only supported models) - compute_timestamps: bool = False - # set to True if need to return full alignment information - preserve_alignment: bool = False + # Set to True to output greedy timestamp information (only supported models) and returns full alignment hypotheses + timestamps: Optional[bool] = None + + # Set to True to return hypotheses instead of text from the transcribe function + return_hypotheses: bool = False # Set to True to output language ID information compute_langs: bool = False @@ -171,7 +172,8 @@ class TranscriptionConfig: # Implicit single-turn assuming default role='user' (works with Canary-1B) # +prompt.source_lang=en +prompt.target_lang=es +prompt.task=asr +prompt.pnc=yes # Explicit single-turn prompt: - # +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes + # +prompt.role=user +prompt.slots.source_lang=en +prompt.slots.target_lang=es + # +prompt.slots.task=s2t_translation +prompt.slots.pnc=yes # Explicit multi-turn prompt: # +prompt.turns='[{role:user,slots:{source_lang:en,target_lang:es,task:asr,pnc:yes}}]' prompt: dict = field(default_factory=dict) @@ -194,9 +196,6 @@ class TranscriptionConfig: # if True, will also skip writing anything to the output file return_transcriptions: bool = False - # Set to False to return text instead of hypotheses from the transcribe function, so as to save memory - return_hypotheses: bool = True - # key for groundtruth text in manifest gt_text_attr_name: str = "text" gt_lang_attr_name: str = "lang" @@ -208,6 +207,9 @@ class TranscriptionConfig: @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis]]: + """ + Transcribes the input audio and can be used to infer with Encoder-Decoder models. + """ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') for key in cfg: @@ -272,10 +274,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis asr_model.to(getattr(torch, cfg.compute_dtype)) # we will adjust this flag if the model does not support it - compute_timestamps = cfg.compute_timestamps compute_langs = cfg.compute_langs - # has to be True if timestamps are required - preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment # Check whether model and decoder type match if isinstance(asr_model, EncDecCTCModel): @@ -295,7 +294,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'): if isinstance(asr_model.decoding, MultiTaskDecoding): cfg.multitask_decoding.compute_langs = cfg.compute_langs - cfg.multitask_decoding.preserve_alignments = cfg.preserve_alignment if cfg.extract_nbest: cfg.multitask_decoding.beam.return_best_hypothesis = False cfg.return_hypotheses = True @@ -309,9 +307,6 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis if cfg.extract_nbest: decoding_cfg.beam.return_best_hypothesis = False cfg.return_hypotheses = True - decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it - if 'preserve_alignments' in decoding_cfg: - decoding_cfg.preserve_alignments = preserve_alignment if 'compute_langs' in decoding_cfg: decoding_cfg.compute_langs = cfg.compute_langs if hasattr(asr_model, 'cur_decoder'): @@ -325,16 +320,12 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis cfg.rnnt_decoding.beam.return_best_hypothesis = False cfg.return_hypotheses = True cfg.rnnt_decoding.fused_batch_size = -1 - cfg.rnnt_decoding.compute_timestamps = cfg.compute_timestamps cfg.rnnt_decoding.compute_langs = cfg.compute_langs - if 'preserve_alignments' in cfg.rnnt_decoding: - cfg.rnnt_decoding.preserve_alignments = preserve_alignment asr_model.change_decoding_strategy(cfg.rnnt_decoding) else: if cfg.compute_langs: raise ValueError("CTC models do not support `compute_langs` at the moment.") - cfg.ctc_decoding.compute_timestamps = cfg.compute_timestamps if cfg.extract_nbest: cfg.ctc_decoding.beam.return_best_hypothesis = False cfg.return_hypotheses = True @@ -379,7 +370,8 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis item = json.loads(line) if "duration" not in item: raise ValueError( - f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field." + f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} \ + lacks a 'duration' field." ) total_duration += item["duration"] @@ -396,6 +388,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis override_cfg.augmentor = augmentor override_cfg.text_field = cfg.gt_text_attr_name override_cfg.lang_field = cfg.gt_lang_attr_name + override_cfg.timestamps = cfg.timestamps if hasattr(override_cfg, "prompt"): override_cfg.prompt = parse_multitask_prompt(OmegaConf.to_container(cfg.prompt)) @@ -433,7 +426,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis model_name, filepaths=filepaths, compute_langs=compute_langs, - compute_timestamps=compute_timestamps, + timestamps=cfg.timestamps, ) logging.info(f"Finished writing predictions to {output_filename}!") diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 268438c2e09d..f18fe02d2ed8 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -40,7 +40,6 @@ from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.multitask_decoding import MultiTaskDecoding, MultiTaskDecodingConfig from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier -from nemo.collections.asr.parts.utils import manifest_utils from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.common import tokenizers from nemo.collections.common.data.lhotse.dataloader import get_lhotse_dataloader_from_config @@ -68,6 +67,9 @@ def lens_to_mask(lens, max_length): + """ + Create a mask from a tensor of lengths. + """ batch_size = lens.shape[0] mask = torch.arange(max_length).repeat(batch_size, 1).to(lens.device) < lens[:, None] return mask @@ -222,7 +224,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.val_loss = GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True) - # TODO: PytorchMetrics lets you join two metrics together to save compute. But need to make wer and bleu have same outputs first + # TODO: PytorchMetrics lets you join two metrics together to save compute. + # But need to make wer and bleu have same outputs first self.wer = WER(self.decoding, log_prediction=self.cfg.get("log_prediction")) self.bleu = BLEU( self.decoding, tokenize=self.cfg.get('bleu_tokenizer', "13a"), log_prediction=False @@ -270,13 +273,15 @@ def change_vocabulary( prompt_format: Optional[str] = None, ): """ - Changes vocabulary used during AED decoding process. Use this method when fine-tuning on from pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during AED decoding process. Use this method when fine-tuning on + from pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder when + fine-tuning on data in another language, or when you'd need model to learn capitalization, + punctuation and/or special characters. Args: - new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) + new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer + (if the tokenizer type is `agg`) new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`. decoding_cfg: A config for the decoding, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. @@ -291,7 +296,8 @@ def change_vocabulary( new_tokenizer_cfg = new_tokenizer_dir else: raise ValueError( - f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}' + f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this\ + tokenizer type is: {new_tokenizer_type}' ) else: new_tokenizer_cfg = None @@ -457,13 +463,15 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: Optional[bool] = None, override_config: Optional[MultiTaskTranscriptionConfig] = None, **prompt, ) -> Union[List[str], List[Hypothesis]]: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path to a manifest file. + audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path + to a manifest file. Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. @@ -472,15 +480,30 @@ def transcribe( return_hypotheses: (bool) Either return hypotheses or text With hypotheses can do some postprocessing like getting timestamp or rescoring num_workers: (int) number of workers for DataLoader - channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels + from multi-channel audio. If set to `'average'`, it performs averaging across channels. + Disabled if set to `None`. Defaults to `None`. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis + object (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class + for more details. Default is None and would retain the previous state set by using + self.change_decoding_strategy(). + Note: Currently its not supported for AED models. verbose: (bool) whether to display tqdm progress bar - override_config: (Optional[MultiTaskTranscriptionConfig]) A config to override the default config. - **prompt: Optional input to construct the prompts for the model. Accepted formats are: 1) legacy Canary-1B API source_lang=, target_lang=, etc. 2) explicit single-turn role=, slots={: , ...} 3) explicit multi-turn: turns=[{"role": , "slots": {: , ...}}] + override_config: (Optional[MultiTaskTranscriptionConfig]) A config to override the + default config. + **prompt: Optional input to construct the prompts for the model. Accepted formats are: + 1) legacy Canary-1B API source_lang=, target_lang=, etc. + 2) explicit single-turn role=, slots={: , ...} + 3) explicit multi-turn: turns=[{"role": , "slots": {: , ...}}] Returns: - A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order + as paths2audio_files """ + if timestamps: + raise NotImplementedError("Computing timestamps are not supported for this model yet.") + if override_config is None: trcfg = MultiTaskTranscriptionConfig( batch_size=batch_size, @@ -889,7 +912,8 @@ def _transcribe_forward( ) @deprecated( - explanation='The return type of args will be updated in the upcoming release to ensure a consistent output format across all decoder types, such that a Hypothesis object is always returned.' + explanation='The return type of args will be updated in the upcoming release to ensure a consistent \ + output format across all decoder types, such that a Hypothesis object is always returned.' ) def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType: """ diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py index 2e313ce3c928..79c22794de01 100644 --- a/nemo/collections/asr/models/ctc_bpe_models.py +++ b/nemo/collections/asr/models/ctc_bpe_models.py @@ -209,12 +209,14 @@ def change_vocabulary( """ Changes vocabulary of the tokenizer used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. + For example, you would use it if you want to use pretrained encoder when fine-tuning on a + data in another language, or when you'd need model to learn capitalization, punctuation + and/or special characters. Args: - new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) + new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer + (if the tokenizer type is `agg`) new_tokenizer_type: Either `agg`, `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers, whereas `wpe` is used for `BertTokenizer`. new_tokenizer_cfg: A config for the new tokenizer. if provided, pre-empts the dir and type @@ -227,7 +229,8 @@ def change_vocabulary( new_tokenizer_cfg = new_tokenizer_dir else: raise ValueError( - f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}' + f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer \ + type is: {new_tokenizer_type}' ) else: new_tokenizer_cfg = None @@ -307,13 +310,14 @@ def change_vocabulary( logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig): + def change_decoding_strategy(self, decoding_cfg: DictConfig, verbose: bool = True): """ Changes decoding strategy used during CTC decoding process. Args: decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + verbose: Whether to print the new config or not. """ if decoding_cfg is None: # Assume same decoding config as before @@ -343,7 +347,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") @classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: @@ -378,7 +383,7 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: model = PretrainedModelInfo( pretrained_model_name="stt_en_citrinet_256_gamma_0_25", - description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_en_citrinet_256_gamma_0_25", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:\nemo:stt_en_citrinet_256_gamma_0_25", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_en_citrinet_256_gamma_0_25/versions/1.0.0/files/stt_en_citrinet_256_gamma_0_25.nemo", ) results.append(model) diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index edf4f84a9f9b..993c7dc6b298 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -import json import os -import tempfile from math import ceil from typing import Any, Dict, List, Optional, Union @@ -23,7 +21,6 @@ from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer from torch.utils.data import DataLoader -from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import _AudioTextDataset @@ -37,6 +34,7 @@ from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler +from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.classes.common import PretrainedModelInfo, typecheck @@ -45,6 +43,7 @@ from nemo.utils import logging from nemo.utils.decorators import deprecated + __all__ = ['EncDecCTCModel'] @@ -128,13 +127,15 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: Optional[bool] = None, override_config: Optional[TranscribeConfig] = None, ) -> TranscriptionReturnType: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path to a manifest file. + audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or + path to a manifest file. Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. @@ -143,16 +144,41 @@ def transcribe( return_hypotheses: (bool) Either return hypotheses or text With hypotheses can do some postprocessing like getting timestamp or rescoring num_workers: (int) number of workers for DataLoader - channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels + from multi-channel audio. If set to `'average'`, it performs averaging across channels. + Disabled if set to `None`. Defaults to `None`. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis + object (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class + for more details. Default is None and would retain the previous state set by + using self.change_decoding_strategy(). verbose: (bool) whether to display tqdm progress bar override_config: (Optional[TranscribeConfig]) override transcription config pre-defined by the user. **Note**: All other arguments in the function will be ignored if override_config is passed. You should call this argument as `model.transcribe(audio, override_config=TranscribeConfig(...))`. Returns: - A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as + paths2audio_files """ + if timestamps is not None: + # else retain the decoder state (users can set it using change_decoding_strategy) + if timestamps or (override_config is not None and override_config.timestamps): + logging.info( + "Timestamps requested, setting decoding timestamps to True. Capture them in Hypothesis object, \ + with output[idx].timestep['word'/'segment'/'char']" + ) + return_hypotheses = True + with open_dict(self.cfg.decoding): + self.cfg.decoding.compute_timestamps = True + self.cfg.decoding.preserve_alignments = True + self.change_decoding_strategy(self.cfg.decoding, verbose=False) + else: # This is done to ensure the state is preserved when decoding_strategy is set outside + with open_dict(self.cfg.decoding): + self.cfg.decoding.compute_timestamps = self.cfg.decoding.get('compute_timestamps', False) + self.cfg.decoding.preserve_alignments = self.cfg.decoding.get('preserve_alignments', False) + self.change_decoding_strategy(self.cfg.decoding, verbose=False) + return super().transcribe( audio=audio, batch_size=batch_size, @@ -161,6 +187,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + timestamps=timestamps, override_config=override_config, ) @@ -235,13 +262,14 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig): + def change_decoding_strategy(self, decoding_cfg: DictConfig, verbose: bool = True): """ Changes decoding strategy used during CTC decoding process. Args: decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + verbose: (bool) whether to display logging information """ if decoding_cfg is None: # Assume same decoding config as before @@ -270,7 +298,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") def _setup_dataloader_from_config(self, config: Optional[Dict]): # Automatically inject args from model config to dataloader config @@ -670,7 +699,8 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): return output @deprecated( - explanation='The return type of args will be updated in the upcoming release to ensure a consistent output format across all decoder types, such that a Hypothesis object is always returned.' + explanation='The return type of args will be updated in the upcoming release to ensure a consistent output \ + format across all decoder types, such that a Hypothesis object is always returned.' ) def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> GenericTranscriptionType: logits = outputs.pop('logits') @@ -705,6 +735,14 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen # cleanup memory del logits, logits_len + if trcfg.timestamps: + current_hypotheses = process_timestamp_outputs( + current_hypotheses, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + hypotheses = [] if all_hyp is None: hypotheses += current_hypotheses @@ -767,7 +805,11 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: model = PretrainedModelInfo( pretrained_model_name="QuartzNet15x5Base-En", - description="QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.", + description="QuartzNet15x5 model trained on six datasets: LibriSpeech, Mozilla Common Voice \ + (validated clips from en_1488h_2019-12-10), WSJ, Fisher, Switchboard, and NSC Singapore English. \ + It was trained with Apex/Amp optimization level O1 for 600 epochs. The model achieves a WER of \ + 3.79% on LibriSpeech dev-clean, and a WER of 10.05% on dev-other. Please visit \ + https://ngc.nvidia.com/catalog/models/nvidia:nemospeechmodels for further details.", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemospeechmodels/versions/1.0.0a5/files/QuartzNet15x5Base-En.nemo", ) results.append(model) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 089c34d98884..1d437a19a86b 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -253,10 +253,11 @@ def change_vocabulary( ctc_decoding_cfg: Optional[DictConfig] = None, ): """ - Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning on from pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning on + from pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder when + fine-tuning on data in another language, or when you'd need model to learn capitalization, + punctuation and/or special characters. Args: new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) @@ -415,7 +416,9 @@ def change_vocabulary( logging.info(f"Changed tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None): + def change_decoding_strategy( + self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True + ): """ Changes decoding strategy used during RNNT decoding process. Args: @@ -424,6 +427,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type decoder_type: (str) Can be set to 'rnnt' or 'ctc' to switch between appropriate decoder in a model having both RNN-T and CTC decoders. Defaults to None, in which case RNN-T decoder is used. If set to 'ctc', it raises error if 'ctc_decoder' is not an attribute of the model. + verbose: bool whether to display change of decoder config or not. """ if decoder_type is None or decoder_type == 'rnnt': if decoding_cfg is None: @@ -466,7 +470,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type self.cfg.decoding = decoding_cfg self.cur_decoder = "rnnt" - logging.info(f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info( + f"Changed decoding strategy of the RNNT decoder to \n{OmegaConf.to_yaml(self.cfg.decoding)}" + ) elif decoder_type == 'ctc': if not hasattr(self, 'ctc_decoding'): @@ -497,9 +504,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type self.cfg.aux_ctc.decoding = decoding_cfg self.cur_decoder = "ctc" - logging.info( - f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}" - ) + if verbose: + logging.info( + f"Changed decoding strategy of the CTC decoder to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}" + ) else: raise ValueError(f"decoder_type={decoder_type} is not supported. Supported values: [ctc,rnnt]") diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index c14265325985..028073d7ca7f 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -31,6 +31,7 @@ from nemo.collections.asr.parts.mixins.transcription import TranscriptionReturnType from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig +from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs from nemo.core.classes.common import PretrainedModelInfo from nemo.core.classes.mixins import AccessMixin from nemo.utils import logging, model_utils @@ -104,6 +105,7 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: bool = False, override_config: Optional[TranscribeConfig] = None, ) -> TranscriptionReturnType: """ @@ -120,8 +122,13 @@ def transcribe( return_hypotheses: (bool) Either return hypotheses or text With hypotheses can do some postprocessing like getting timestamp or rescoring num_workers: (int) number of workers for DataLoader - channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + channel_selector (int | Iterable[int] | str): select a single channel or a subset of + channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. + Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis object + (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class for more details. + Default is None and would retain the previous state set by using self.change_decoding_strategy(). verbose: (bool) whether to display tqdm progress bar logprobs: (bool) whether to return ctc logits insted of hypotheses @@ -130,10 +137,29 @@ def transcribe( * A list of greedy transcript texts / Hypothesis * An optional list of beam search transcript texts / Hypothesis / NBestHypothesis. """ - if self.cur_decoder not in ["ctc", "rnnt"]: - raise ValueError( - f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']" - ) + + if timestamps is not None: + if self.cur_decoder not in ["ctc", "rnnt"]: + raise ValueError( + f"{self.cur_decoder} is not supported for cur_decoder. Supported values are ['ctc', 'rnnt']" + ) + decoding_cfg = self.cfg.aux_ctc.decoding if self.cur_decoder == "ctc" else self.cfg.decoding + if timestamps or (override_config is not None and override_config.timestamps): + logging.info( + "Timestamps requested, setting decoding timestamps to True. Capture them in Hypothesis object, \ + with output[idx].timestep['word'/'segment'/'char']" + ) + return_hypotheses = True + with open_dict(decoding_cfg): + decoding_cfg.compute_timestamps = True + decoding_cfg.preserve_alignments = True + self.change_decoding_strategy(decoding_cfg, decoder_type=self.cur_decoder, verbose=False) + else: + return_hypotheses = False + with open_dict(decoding_cfg): + decoding_cfg.compute_timestamps = False + decoding_cfg.preserve_alignments = False + self.change_decoding_strategy(decoding_cfg, decoder_type=self.cur_decoder, verbose=False) return super().transcribe( audio=audio, @@ -144,6 +170,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + timestamps=timestamps, override_config=override_config, ) @@ -201,6 +228,14 @@ def _transcribe_output_processing( # for logit, elen in zip(logits, encoded_len): # logits_list.append(logit[:elen]) + if trcfg.timestamps: + best_hyp = process_timestamp_outputs( + best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + del logits, encoded_len hypotheses = [] @@ -221,10 +256,11 @@ def change_vocabulary( ctc_decoding_cfg: Optional[DictConfig] = None, ): """ - Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a + pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder + when fine-tuning on data in another language, or when you'd need model to learn capitalization, + punctuation and/or special characters. Args: new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ @@ -295,7 +331,9 @@ def change_vocabulary( logging.info(f"Changed the tokenizer of the CTC decoder to {self.ctc_decoder.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type: str = None): + def change_decoding_strategy( + self, decoding_cfg: DictConfig = None, decoder_type: str = None, verbose: bool = True + ): """ Changes decoding strategy used during RNNT decoding process. @@ -305,10 +343,11 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type decoder_type: (str) Can be set to 'rnnt' or 'ctc' to switch between appropriate decoder in a model having RNN-T and CTC decoders. Defaults to None, in which case RNN-T decoder is used. If set to 'ctc', it raises error if 'ctc_decoder' is not an attribute of the model. + verbose: (bool) whether to display logging information """ if decoder_type is None or decoder_type == 'rnnt': self.cur_decoder = "rnnt" - return super().change_decoding_strategy(decoding_cfg=decoding_cfg) + return super().change_decoding_strategy(decoding_cfg=decoding_cfg, verbose=verbose) assert decoder_type == 'ctc' and hasattr(self, 'ctc_decoder') if decoding_cfg is None: @@ -337,7 +376,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig = None, decoder_type self.cfg.aux_ctc.decoding = decoding_cfg self.cur_decoder = "ctc" - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.aux_ctc.decoding)}") + + return None # PTL-specific methods def training_step(self, batch, batch_nb): diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 9e09acd21a5d..25890ec716c8 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -344,13 +344,15 @@ def change_vocabulary( decoding_cfg: Optional[DictConfig] = None, ): """ - Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning on from pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning + on from pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder when fine-tuning + on data in another language, or when you'd need model to learn capitalization, punctuation + and/or special characters. Args: - new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`) + new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer + (if the tokenizer type is `agg`) new_tokenizer_type: Type of tokenizer. Can be either `agg`, `bpe` or `wpe`. decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. @@ -363,7 +365,8 @@ def change_vocabulary( new_tokenizer_cfg = new_tokenizer_dir else: raise ValueError( - f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}' + f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer \ + type is: {new_tokenizer_type}' ) else: new_tokenizer_cfg = None @@ -451,13 +454,14 @@ def change_vocabulary( logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig): + def change_decoding_strategy(self, decoding_cfg: DictConfig, verbose: bool = True): """ Changes decoding strategy used during RNNT decoding process. Args: decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + verbose: A flag to enable/disable logging. """ if decoding_cfg is None: # Assume same decoding config as before @@ -498,7 +502,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index 2b319a3c7dec..ce3b6bc89bce 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -40,6 +40,7 @@ from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecoding, RNNTDecodingConfig from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler +from nemo.collections.asr.parts.utils.transcribe_utils import process_timestamp_outputs from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.classes.common import PretrainedModelInfo, typecheck @@ -247,13 +248,15 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: Optional[bool] = None, override_config: Optional[TranscribeConfig] = None, ) -> TranscriptionReturnType: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. Args: - audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path to a manifest file. + audio: (a single or list) of paths to audio files or a np.ndarray/tensor audio array or path + to a manifest file. Can also be a dataloader object that provides values that can be consumed by the model. Recommended length per file is between 5 and 25 seconds. \ But it is possible to pass a few hours long file if enough GPU memory is available. @@ -265,9 +268,14 @@ def transcribe( decoding. This is useful for streaming rnnt decoding. If this is not None, then the length of this list should be equal to the length of the audio list. num_workers: (int) number of workers for DataLoader - channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels + from multi-channel audio. If set to `'average'`, it performs averaging across channels. + Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. verbose: (bool) whether to display tqdm progress bar + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis object + (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class for more details. + Default is None and would retain the previous state set by using self.change_decoding_strategy(). override_config: (Optional[TranscribeConfig]) override transcription config pre-defined by the user. **Note**: All other arguments in the function will be ignored if override_config is passed. You should call this argument as `model.transcribe(audio, override_config=TranscribeConfig(...))`. @@ -277,6 +285,25 @@ def transcribe( * A list of greedy transcript texts / Hypothesis * An optional list of beam search transcript texts / Hypothesis / NBestHypothesis. """ + + if timestamps is not None: + if timestamps or (override_config is not None and override_config.timestamps): + logging.info( + "Timestamps requested, setting decoding timestamps to True. Capture them in Hypothesis object, \ + with output[0][idx].timestep['word'/'segment'/'char']" + ) + return_hypotheses = True + with open_dict(self.cfg.decoding): + self.cfg.decoding.compute_timestamps = True + self.cfg.decoding.preserve_alignments = True + self.change_decoding_strategy(self.cfg.decoding, verbose=False) + else: + return_hypotheses = False + with open_dict(self.cfg.decoding): + self.cfg.decoding.compute_timestamps = False + self.cfg.decoding.preserve_alignments = False + self.change_decoding_strategy(self.cfg.decoding, verbose=False) + return super().transcribe( audio=audio, batch_size=batch_size, @@ -285,6 +312,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + timestamps=timestamps, override_config=override_config, # Additional arguments partial_hypothesis=partial_hypothesis, @@ -292,10 +320,11 @@ def transcribe( def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[DictConfig] = None): """ - Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a pre-trained model. - This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would - use it if you want to use pretrained encoder when fine-tuning on data in another language, or when you'd need - model to learn capitalization, punctuation and/or special characters. + Changes vocabulary used during RNNT decoding process. Use this method when fine-tuning a + pre-trained model. This method changes only decoder and leaves encoder and pre-processing + modules unchanged. For example, you would use it if you want to use pretrained encoder when + fine-tuning on data in another language, or when you'd need model to learn capitalization, + punctuation and/or special characters. Args: new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ @@ -381,13 +410,14 @@ def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[Di logging.info(f"Changed decoder to output to {self.joint.vocabulary} vocabulary.") - def change_decoding_strategy(self, decoding_cfg: DictConfig): + def change_decoding_strategy(self, decoding_cfg: DictConfig, verbose=True): """ Changes decoding strategy used during RNNT decoding process. Args: decoding_cfg: A config for the decoder, which is optional. If the decoding type needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + verbose: (bool) whether to display logging information """ if decoding_cfg is None: # Assume same decoding config as before @@ -428,7 +458,8 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): with open_dict(self.cfg.decoding): self.cfg.decoding = decoding_cfg - logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + if verbose: + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") def _setup_dataloader_from_config(self, config: Optional[Dict]): # Automatically inject args from model config to dataloader config @@ -901,7 +932,8 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): return output @deprecated( - explanation='The return type of args will be updated in the upcoming release to ensure a consistent output format across all decoder types, such that a "Hypothesis" object is always returned.' + explanation='The return type of args will be updated in the upcoming release to ensure a consistent \ + output format across all decoder types, such that a "Hypothesis" object is always returned.' ) def _transcribe_output_processing( self, outputs, trcfg: TranscribeConfig @@ -915,10 +947,17 @@ def _transcribe_output_processing( return_hypotheses=trcfg.return_hypotheses, partial_hypotheses=trcfg.partial_hypothesis, ) - # cleanup memory del encoded, encoded_len + if trcfg.timestamps: + best_hyp = process_timestamp_outputs( + best_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + all_hyp = process_timestamp_outputs( + all_hyp, self.encoder.subsampling_factor, self.cfg['preprocessor']['window_stride'] + ) + hypotheses = [] all_hypotheses = [] diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 3cb9ec13109b..e48d76a9b7a3 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -133,6 +133,7 @@ def __init__( residual_panes = [] encoder_layers = [] self.dense_residual = False + self._subsampling_factor = 1 for layer_idx, lcfg in enumerate(jasper): dense_res = [] if lcfg.get('residual_dense', False): @@ -181,6 +182,9 @@ def __init__( ) ) feat_in = lcfg['filters'] + self._subsampling_factor *= ( + int(lcfg['stride'][0]) if isinstance(lcfg['stride'], List) else int(lcfg['stride']) + ) self._feat_out = feat_in @@ -199,7 +203,9 @@ def forward(self, audio_signal, length): return s_input[-1], length def update_max_sequence_length(self, seq_length: int, device): - # Find global max audio length across all nodes + """ + Find global max audio length across all nodes in distributed training and update the max_audio_length + """ if torch.distributed.is_initialized(): global_max_len = torch.tensor([seq_length], dtype=torch.float32, device=device) @@ -229,6 +235,10 @@ def update_max_sequence_length(self, seq_length: int, device): elif isinstance(m, SqueezeExcite): m.set_max_len(self.max_audio_length, seq_range=self.seq_range) + @property + def subsampling_factor(self) -> int: + return self._subsampling_factor + class ParallelConvASREncoder(NeuralModule, Exportable): """ @@ -426,7 +436,8 @@ def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary= if vocabulary is not None: if num_classes != len(vocabulary): raise ValueError( - f"If vocabulary is specified, it's length should be equal to the num_classes. Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}" + f"If vocabulary is specified, it's length should be equal to the num_classes. \ + Instead got: num_classes={num_classes} and len(vocabulary)={len(vocabulary)}" ) self.__vocabulary = vocabulary self._feat_in = feat_in @@ -765,8 +776,8 @@ class SpeakerDecoder(NeuralModule, Exportable): Args: feat_in (int): Number of channels being input to this module num_classes (int): Number of unique speakers in dataset - emb_sizes (list) : shapes of intermediate embedding layers (we consider speaker embbeddings from 1st of this layers) - Defaults to [1024,1024] + emb_sizes (list) : shapes of intermediate embedding layers (we consider speaker embbeddings + from 1st of this layers). Defaults to [1024,1024] pool_mode (str) : Pooling strategy type. options are 'xvector','tap', 'attention' Defaults to 'xvector (mean and variance)' tap (temporal average pooling: just mean) diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index 104e6bff81af..ac928fe99272 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -16,8 +16,7 @@ import os import tempfile from abc import ABC, abstractmethod -from collections.abc import Iterable -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -61,6 +60,7 @@ class TranscribeConfig: num_workers: Optional[int] = None channel_selector: ChannelSelectorType = None augmentor: Optional[DictConfig] = None + timestamps: Optional[bool] = None # returns timestamps for each word and segments if model supports punctuations verbose: bool = True # Utility @@ -86,7 +86,8 @@ def get_value_from_transcription_config(trcfg, key, default): return getattr(trcfg, key) else: logging.debug( - f"Using default value of {default} for {key} because it is not present in the transcription config {trcfg}." + f"Using default value of {default} for {key} because it is not present \ + in the transcription config {trcfg}." ) return default @@ -179,6 +180,7 @@ def transcribe( channel_selector: Optional[ChannelSelectorType] = None, augmentor: DictConfig = None, verbose: bool = True, + timestamps: Optional[bool] = None, override_config: Optional[TranscribeConfig] = None, **config_kwargs, ) -> GenericTranscriptionType: @@ -200,6 +202,9 @@ def transcribe( to `None`. Defaults to `None`. Uses zero-based indexing. augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. verbose: (bool) whether to display tqdm progress bar + timestamps: Optional(Bool): timestamps will be returned if set to True as part of hypothesis object + (output.timestep['segment']/output.timestep['word']). Refer to `Hypothesis` class for more details. + Default is None and would retain the previous state set by using self.change_decoding_strategy(). override_config: (Optional[TranscribeConfig]) override transcription config pre-defined by the user. **Note**: All other arguments in the function will be ignored if override_config is passed. You should call this argument as `model.transcribe(audio, override_config=TranscribeConfig(...))`. @@ -229,6 +234,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + timestamps=timestamps, **config_kwargs, ) else: diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index e52c3f46423e..da280a0c6b3c 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -67,13 +67,15 @@ class AbstractRNNTDecoding(ConfidenceMixin): rnnt_timestamp_type: A str value, which represents the types of timestamps that should be calculated. Can take the following values - "char" for character/subword time stamps, "word" for word level - time stamps, "segment" for segment level time stamps and "all" (default), for character, word and segment level time stamps. + time stamps, "segment" for segment level time stamps and "all" (default), for character, + word and segment level time stamps. word_seperator: Str token representing the seperator between words. segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary + for forming the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -102,10 +104,10 @@ class AbstractRNNTDecoding(ConfidenceMixin): The length of the list corresponds to the number of recognized words. exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. - aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. - Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and - attached to the regular frame confidence, + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word + confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated + and attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -177,22 +179,23 @@ class AbstractRNNTDecoding(ConfidenceMixin): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 - in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the next step. - - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. - Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed - but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, - thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally - tuned on a validation set. + and affects the speed of inference since large values will perform large beam search in the + next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the + expansions. The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set + and max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin + of additional tokens which can be potential candidates for expansion apart from the "most likely" + candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, + thereby improving speed but hurting accuracy). Higher values will increase the number of expansions + (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). + This is a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -887,7 +890,7 @@ def _compute_offsets( # Construct the start and end indices brackets end_indices = np.asarray(token_repetitions).cumsum() - start_indices = np.concatenate(([start_index], end_indices[:-1])) + start_indices = np.concatenate(([int(start_index)], end_indices[:-1])) # Process the TxU dangling alignment tensor, containing pairs of (logits, label) alignment_labels = [al_logits_labels for al_logits_labels in hypothesis.text[1]] @@ -950,7 +953,8 @@ def _refine_timestamps_tdt( # Check if token is a punctuation mark # If so, set its start and end offset as start and end of the previous token - # This is done because there was observed a behaviour, when punctuation marks are predicted long after preceding token (i.e. after silence) + # This is done because there was observed a behaviour, when punctuation marks are predicted long + # after preceding token (i.e. after silence) if offset['char'][0] in supported_punctuation and i > 0: encoded_char_offsets[i]['start_offset'] = offset['start_offset'] = char_offsets[i - 1]['end_offset'] encoded_char_offsets[i]['end_offset'] = offset['end_offset'] = offset['start_offset'] @@ -1237,10 +1241,10 @@ class RNNTDecoding(AbstractRNNTDecoding): The length of the list corresponds to the number of recognized words. exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. - aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. - Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and - attached to the regular frame confidence, + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word + confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated + and attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1313,8 +1317,8 @@ class RNNTDecoding(AbstractRNNTDecoding): per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, at increased cost to execution time. - alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. - If an integer is provided, it can decode sequences of that particular maximum length. + alsd_max_target_len: optional int or float, determines the potential maximum target sequence + length. If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). @@ -1326,22 +1330,24 @@ class RNNTDecoding(AbstractRNNTDecoding): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 - in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the next step. - - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. - Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed - but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, - thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally - tuned on a validation set. + and affects the speed of inference since large values will perform large beam search in the + next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the + expansions. + The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and + max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of + additional tokens which can be potential candidates for expansion apart from the "most likely" + candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, + thereby improving speed but hurting accuracy). Higher values will increase the number of + expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving + accuracy). This is a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -1492,7 +1498,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for + forming the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -1521,10 +1528,10 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): The length of the list corresponds to the number of recognized words. exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. - aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. - Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and - attached to the regular frame confidence, + aggregation: Which aggregation type to use for collapsing per-token confidence into per-word + confidence. Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be + calculated and attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1594,8 +1601,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): per timestep of the acoustic model. Larger values will allow longer sentences to be decoded, at increased cost to execution time. - alsd_max_target_len: optional int or float, determines the potential maximum target sequence length. - If an integer is provided, it can decode sequences of that particular maximum length. + alsd_max_target_len: optional int or float, determines the potential maximum target sequence + length. If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). @@ -1607,22 +1614,24 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 - in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the next step. - - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. - Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed - but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, - thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally - tuned on a validation set. + and affects the speed of inference since large values will perform large beam search in the + next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when + computing the expansions. The default (2.3) is selected from the paper. It performs a + comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the + Vocab set and max_log_prob is the "most" likely token to be predicted. Gamma therefore + provides a margin of additional tokens which can be potential candidates for expansion + apart from the "most likely" candidate. Lower values will reduce the number of expansions + (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher + values will increase the number of expansions (by reducing pruning-by-value, thereby + reducing speed but potentially improving accuracy). This is a hyper parameter to be + experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -1750,7 +1759,8 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp hypotheses[ind].langs_chars = self.decode_ids_to_langs(prediction) else: logging.warning( - "Ignoring request for lang output in hypotheses since the model does not use an aggregate tokenizer" + "Ignoring request for lang output in hypotheses since the model does not use an aggregate\ + tokenizer" ) return hypotheses diff --git a/nemo/collections/asr/parts/utils/streaming_utils.py b/nemo/collections/asr/parts/utils/streaming_utils.py index 415096a0c9d5..cb272e3d0462 100644 --- a/nemo/collections/asr/parts/utils/streaming_utils.py +++ b/nemo/collections/asr/parts/utils/streaming_utils.py @@ -22,7 +22,7 @@ from torch.utils.data import DataLoader from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextMiniBatch -from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE +from nemo.collections.asr.models import ASRModel from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder from nemo.collections.asr.parts.preprocessing.features import normalize_batch from nemo.collections.asr.parts.preprocessing.segment import get_samples @@ -79,8 +79,8 @@ def longest_common_subsequence_merge(X, Y, filepath=None): Assumption is that the two chunks are consecutive chunks, and there exists at least small overlap acoustically. - It is a sub-word token merge algorithm, operating on the abstract notion of integer ids representing the subword ids. - It is independent of text or character encoding. + It is a sub-word token merge algorithm, operating on the abstract notion of integer ids representing + the subword ids. It is independent of text or character encoding. Since the algorithm is merge based, and depends on consecutive buffers, the very first buffer is processes using the "middle tokens" algorithm. @@ -292,8 +292,8 @@ def lcs_alignment_merge_buffer(buffer, data, delay, model, max_steps_per_timeste Merges the new text from the current frame with the previous text contained in the buffer. The alignment is based on a Longest Common Subsequence algorithm, with some additional heuristics leveraging - the notion that the chunk size is >= the context window. In case this assumptio is violated, the results of the merge - will be incorrect (or at least obtain worse WER overall). + the notion that the chunk size is >= the context window. In case this assumptio is violated, the results of the + merge will be incorrect (or at least obtain worse WER overall). """ # If delay timesteps is 0, that means no future context was used. Simply concatenate the buffer with new data. if delay < 1: @@ -327,8 +327,8 @@ def inplace_buffer_merge(buffer, data, timesteps, model): Merges the new text from the current frame with the previous text contained in the buffer. The alignment is based on a Longest Common Subsequence algorithm, with some additional heuristics leveraging - the notion that the chunk size is >= the context window. In case this assumptio is violated, the results of the merge - will be incorrect (or at least obtain worse WER overall). + the notion that the chunk size is >= the context window. In case this assumptio is violated, the results of + the merge will be incorrect (or at least obtain worse WER overall). """ # If delay timesteps is 0, that means no future context was used. Simply concatenate the buffer with new data. if timesteps < 1: @@ -391,7 +391,7 @@ def __init__(self, asr_model, chunk_size, buffer_size): cfg.preprocessor.dither = 0.0 cfg.preprocessor.pad_to = 0 cfg.preprocessor.normalize = "None" - self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor) + self.raw_preprocessor = ASRModel.from_config_dict(cfg.preprocessor) self.raw_preprocessor.to(asr_model.device) def reset(self): @@ -756,7 +756,7 @@ def __init__( cfg.preprocessor.dither = 0.0 cfg.preprocessor.pad_to = 0 cfg.preprocessor.normalize = "None" - self.raw_preprocessor = EncDecCTCModelBPE.from_config_dict(cfg.preprocessor) + self.raw_preprocessor = ASRModel.from_config_dict(cfg.preprocessor) self.raw_preprocessor.to(asr_model.device) self.preprocessor = self.raw_preprocessor @@ -1091,12 +1091,15 @@ def _get_batch_preds(self): - For all samples, determine if signal has finished. - If so, skip calculation of mel-specs. - If not, compute mel spec and length - - Perform Encoder forward over this sub-batch of samples. Maintain the indices of samples that were processed. - - If performing stateful decoding, prior to decoder forward, remove the states of samples that were not processed. + - Perform Encoder forward over this sub-batch of samples. Maintain the indices of samples that + were processed. + - If performing stateful decoding, prior to decoder forward, remove the states of samples that + were not processed. - Perform Decoder + Joint forward for samples that were processed. - For all output RNNT alignment matrix of the joint do: - If signal has ended previously (this was last buffer of padding), skip alignment - - Otherwise, recalculate global index of this sample from the sub-batch index, and preserve alignment. + - Otherwise, recalculate global index of this sample from the sub-batch index, and preserve + alignment. - Same for preds - Update indices of sub-batch with global index map. - Redo steps until all samples were processed (sub-batch size == 0). @@ -1362,15 +1365,17 @@ def transcribe( class CacheAwareStreamingAudioBuffer: """ - A buffer to be used for cache-aware streaming. It can load a single or multiple audio files/processed signals, split them in chunks and return one on one. - It can be used to simulate streaming audio or audios. + A buffer to be used for cache-aware streaming. It can load a single or multiple audio + files/processed signals, split them in chunks and return one on one. It can be used to + simulate streaming audio or audios. """ def __init__(self, model, online_normalization=None, pad_and_drop_preencoded=False): ''' Args: model: An ASR model. - online_normalization (bool): whether to perform online normalization per chunk or normalize the whole audio before chunking + online_normalization (bool): whether to perform online normalization per chunk or + normalize the whole audio before chunking pad_and_drop_preencoded (bool): if true pad first audio chunk and always drop preencoded ''' self.model = model @@ -1430,7 +1435,8 @@ def __iter__(self): audio_chunk = self.buffer[:, :, self.buffer_idx : self.buffer_idx + chunk_size] if self.sampling_frames is not None: - # checking to make sure the audio chunk has enough frames to produce at least one output after downsampling + # checking to make sure the audio chunk has enough frames to produce at least one output after + # downsampling if self.buffer_idx == 0 and isinstance(self.sampling_frames, list): cur_sampling_frames = self.sampling_frames[0] else: diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 0d4f4c895bcf..189d98537d3f 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -199,7 +199,8 @@ def get_buffered_pred_feat_multitaskAED( if filepaths: logging.info( - "Deteced audio files as input, default to English ASR with Punctuation and Capitalization output. Please use manifest input for other options." + "Deteced audio files as input, default to English ASR with Punctuation and Capitalization output. \ + Please use manifest input for other options." ) for audio_file in tqdm(filepaths, desc="Transcribing:", total=len(filepaths), ncols=80): meta = { @@ -281,12 +282,16 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: - append_pred (bool): Flag indicating whether to append predictions to an existing dataset. - audio_type (str): Type of audio files to consider. - dataset_manifest (str): Path to the dataset manifest file. - - audio_key (str, optional): Key in the manifest file specifying the audio file path. Defaults to 'audio_filepath'. - - presort_manifest (bool, optional): Flag indicating whether to presort the manifest file. Defaults to True. + - audio_key (str, optional): Key in the manifest file specifying the audio file path. + Defaults to 'audio_filepath'. + - presort_manifest (bool, optional): Flag indicating whether to presort the manifest file. + Defaults to True. Returns: Tuple[List[str], bool]: A tuple containing the following: - - filepaths (List[str]): List of filepaths to the audio files if path to the directory containing audio files is provided. - - sorted_manifest_path (bool): Path to the sorted manifest file if path to the dataset manifest file is provided. + - filepaths (List[str]): List of filepaths to the audio files if path to the directory + containing audio files is provided. + - sorted_manifest_path (bool): Path to the sorted manifest file if path to the dataset + manifest file is provided. """ filepaths = None @@ -308,7 +313,8 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: item[audio_key] = get_full_path(item[audio_key], cfg.dataset_manifest) if item.get("duration") is None and cfg.presort_manifest: raise ValueError( - f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field." + f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} \ + lacks a 'duration' field." ) with NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: @@ -388,7 +394,7 @@ def write_transcription( model_name: str, filepaths: List[str] = None, compute_langs: bool = False, - compute_timestamps: bool = False, + timestamps: bool = False, ) -> Tuple[str, str]: """Write generated transcription to output file.""" if cfg.append_pred: @@ -433,7 +439,7 @@ def write_transcription( else: # transcription is Hypothesis item = {'audio_filepath': filepaths[idx], pred_text_attr_name: transcription.text} - if compute_timestamps: + if timestamps: timestamps = transcription.timestep if timestamps is not None and isinstance(timestamps, dict): timestamps.pop( @@ -441,7 +447,7 @@ def write_transcription( ) # Pytorch tensor calculating index of each token, not needed. for key in timestamps.keys(): values = normalize_timestamp_output(timestamps[key]) - item[f'timestamps_{key}'] = values + item[f'{key}'] = values if compute_langs: item['pred_lang'] = transcription.langs @@ -458,7 +464,7 @@ def write_transcription( else: # transcription is Hypothesis item[pred_text_attr_name] = best_hyps[idx].text - if compute_timestamps: + if timestamps: timestamps = best_hyps[idx].timestep if timestamps is not None and isinstance(timestamps, dict): timestamps.pop( @@ -466,7 +472,7 @@ def write_transcription( ) # Pytorch tensor calculating index of each token, not needed. for key in timestamps.keys(): values = normalize_timestamp_output(timestamps[key]) - item[f'timestamps_{key}'] = values + item[f'{key}'] = values if compute_langs: item['pred_lang'] = best_hyps[idx].langs @@ -492,10 +498,14 @@ def compute_metrics_per_sample( Args: manifest_path: str, Required - path to dataset JSON manifest file (in NeMo format) - reference_field: str, Optional - name of field in .json manifest with the reference text ("text" by default). - hypothesis_field: str, Optional - name of field in .json manifest with the hypothesis text ("pred_text" by default). - metrics: list[str], Optional - list of metrics to be computed (currently supported "wer", "cer", "punct_er") - punctuation_marks: list[str], Optional - list of punctuation marks for computing punctuation error rate ([".", ",", "?"] by default). + reference_field: str, Optional - name of field in .json manifest with the reference text + ("text" by default). + hypothesis_field: str, Optional - name of field in .json manifest with the hypothesis text + ("pred_text" by default). + metrics: list[str], Optional - list of metrics to be computed + (currently supported "wer", "cer", "punct_er") + punctuation_marks: list[str], Optional - list of punctuation marks for computing + punctuation error rate ([".", ",", "?"] by default). output_manifest_path: str, Optional - path where .json manifest with calculated metrics will be saved. Returns: @@ -568,6 +578,61 @@ def compute_metrics_per_sample( return samples_with_metrics +def process_timestamp_outputs(outputs, subsampling_factor: int = 1, window_stride: float = 0.01): + """ + Process the timestamps from list of hypothesis to user friendly format. + Converts the start and end duration from frames to seconds. + Args: + outputs: List of Hypothesis objects. + subsampling_factor: int, Subsampling factor used in the model. + window_stride: float, Window stride used in the model. (sometimes referred to as hop length/shift) + Returns: + List of Hypothesis objects with processed timestamps + + """ + + if outputs is None: + return outputs + + if isinstance(outputs, rnnt_utils.Hypothesis): + outputs = [outputs] + + if not isinstance(outputs[0], rnnt_utils.Hypothesis): + raise ValueError(f"Expected Hypothesis object, got {type(outputs[0])}") + + def process_timestamp(timestamp, subsampling_factor, window_stride): + """ + Process the timestamp for a single hypothesis. + return the start and end duration in seconds. + """ + for idx, val in enumerate(timestamp): + start_offset = val['start_offset'] + end_offset = val['end_offset'] + start = start_offset * window_stride * subsampling_factor + end = end_offset * window_stride * subsampling_factor + val['start'] = start + val['end'] = end + + return timestamp + + for idx, hyp in enumerate(outputs): + if not hasattr(hyp, 'timestep'): + raise ValueError( + f"Expected Hypothesis object to have 'timestep' attribute, when compute_timestamps is \ + enabled but got {hyp}" + ) + timestep = hyp.timestep + if 'word' in timestep: + outputs[idx].timestep['word'] = process_timestamp(timestep['word'], subsampling_factor, window_stride) + if 'char' in timestep: + outputs[idx].timestep['char'] = process_timestamp(timestep['char'], subsampling_factor, window_stride) + if 'segment' in timestep: + outputs[idx].timestep['segment'] = process_timestamp( + timestep['segment'], subsampling_factor, window_stride + ) + return outputs + + class PunctuationCapitalization: def __init__(self, punctuation_marks: str): """ diff --git a/tests/collections/asr/conftest.py b/tests/collections/asr/conftest.py index dba29f949fb0..a9bc13153164 100644 --- a/tests/collections/asr/conftest.py +++ b/tests/collections/asr/conftest.py @@ -19,6 +19,8 @@ import pytest import torch +from nemo.collections.asr.models import ASRModel + class RNNTTestHelper: @staticmethod @@ -353,3 +355,18 @@ def rnnt_test_helper() -> Type[RNNTTestHelper]: @pytest.fixture(scope="session") def rnn_loss_sample_data() -> Type[RnntLossSampleData]: return RnntLossSampleData + + +@pytest.fixture(scope='session') +def fast_conformer_transducer_model(): + return ASRModel.from_pretrained("stt_en_fastconformer_transducer_large") + + +@pytest.fixture(scope='session') +def fast_conformer_ctc_model(): + return ASRModel.from_pretrained("stt_en_fastconformer_ctc_large") + + +@pytest.fixture(scope='session') +def fast_conformer_hybrid_model(): + return ASRModel.from_pretrained("parakeet-tdt_ctc-110m") diff --git a/tests/collections/asr/mixins/test_transcription.py b/tests/collections/asr/mixins/test_transcription.py index 1a6f38681d0c..6e2d5fe16c68 100644 --- a/tests/collections/asr/mixins/test_transcription.py +++ b/tests/collections/asr/mixins/test_transcription.py @@ -23,7 +23,6 @@ from torch.utils.data import DataLoader, Dataset from nemo.collections.asr.data.audio_to_text import _speech_collate_fn -from nemo.collections.asr.models import ASRModel from nemo.collections.asr.parts.mixins import TranscribeConfig, TranscriptionMixin from nemo.collections.asr.parts.mixins.transcription import GenericTranscriptionType from nemo.collections.asr.parts.utils import Hypothesis @@ -44,6 +43,23 @@ def forward(self, x): return out +@pytest.mark.with_downloads() +@pytest.fixture() +def audio_files(test_data_dir): + """ + Returns a list of audio files for testing. + """ + import soundfile as sf + + audio_file1 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") + audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an104-mrcb-b.wav") + + audio1, _ = sf.read(audio_file1, dtype='float32') + audio2, _ = sf.read(audio_file2, dtype='float32') + + return audio1, audio2 + + class TranscribableDummy(DummyModel, TranscriptionMixin): def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): super()._transcribe_on_begin(audio, trcfg) @@ -297,12 +313,11 @@ class OverrideConfig(TranscribeConfig): pytest.mark.with_downloads() @pytest.mark.unit - def test_transcribe_return_hypothesis(self, test_data_dir): - model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") + def test_transcribe_return_hypothesis(self, test_data_dir, fast_conformer_ctc_model): audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") - # Numpy array test - outputs = model.transcribe(audio_file, batch_size=1, return_hypotheses=True) + # Audio file test + outputs = fast_conformer_ctc_model.transcribe(audio_file, batch_size=1, return_hypotheses=True) assert len(outputs) == 1 assert isinstance(outputs[0], Hypothesis) @@ -313,62 +328,82 @@ def test_transcribe_return_hypothesis(self, test_data_dir): @pytest.mark.with_downloads() @pytest.mark.unit - def test_transcribe_tensor(self, test_data_dir): - model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") - - # Load audio file - import soundfile as sf - - audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") - audio, sr = sf.read(audio_file, dtype='float32') + def test_transcribe_tensor(self, audio_files, fast_conformer_ctc_model): + audio, _ = audio_files # Numpy array test - outputs = model.transcribe(audio, batch_size=1) + outputs = fast_conformer_ctc_model.transcribe(audio, batch_size=1) assert len(outputs) == 1 assert isinstance(outputs[0], str) @pytest.mark.with_downloads() @pytest.mark.unit - def test_transcribe_multiple_tensor(self, test_data_dir): - model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") - - # Load audio file - import soundfile as sf - - audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") - audio, sr = sf.read(audio_file, dtype='float32') + def test_transcribe_multiple_tensor(self, audio_files, fast_conformer_ctc_model): - audio_file_2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an104-mrcb-b.wav") - audio_2, sr = sf.read(audio_file_2, dtype='float32') + audio, audio_2 = audio_files # Mix second audio to torch.tensor() audio_2 = torch.tensor(audio_2) # Numpy array test - outputs = model.transcribe([audio, audio_2], batch_size=2) + outputs = fast_conformer_ctc_model.transcribe([audio, audio_2], batch_size=2) assert len(outputs) == 2 assert isinstance(outputs[0], str) assert isinstance(outputs[1], str) @pytest.mark.with_downloads() @pytest.mark.unit - def test_transcribe_dataloader(self, test_data_dir): - model = ASRModel.from_pretrained("stt_en_conformer_ctc_small") - - # Load audio file - import soundfile as sf - - audio_file = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an46-mmap-b.wav") - audio, sr = sf.read(audio_file, dtype='float32') + def test_transcribe_dataloader(self, audio_files, fast_conformer_ctc_model): - audio_file2 = os.path.join(test_data_dir, "asr", "train", "an4", "wav", "an152-mwhw-b.wav") - audio2, sr = sf.read(audio_file2, dtype='float32') + audio, audio2 = audio_files dataset = DummyDataset([audio, audio2]) collate_fn = lambda x: _speech_collate_fn(x, pad_id=0) dataloader = DataLoader(dataset, batch_size=2, shuffle=False, num_workers=0, collate_fn=collate_fn) # DataLoader test - outputs = model.transcribe(dataloader, batch_size=1) + outputs = fast_conformer_ctc_model.transcribe(dataloader, batch_size=1) assert len(outputs) == 2 assert isinstance(outputs[0], str) assert isinstance(outputs[1], str) + + @pytest.mark.with_downloads() + @pytest.mark.unit + def test_timestamps_with_transcribe(self, audio_files, fast_conformer_ctc_model): + audio1, audio2 = audio_files + + output = fast_conformer_ctc_model.transcribe([audio1, audio2], timestamps=True) + + # check len of output + assert len(output) == 2 + + # check hypothesis object + assert isinstance(output[0], Hypothesis) + # check transcript + assert output[0].text == 'stop' + assert output[1].text == 'start' + + # check timestamp + assert output[0].timestep['segment'][0]['start'] == pytest.approx(0.4) + assert output[0].timestep['segment'][0]['end'] == pytest.approx(0.48) + + @pytest.mark.with_downloads() + @pytest.mark.unit + def test_timestamps_with_transcribe_hybrid(self, audio_files, fast_conformer_hybrid_model): + audio1, audio2 = audio_files + + output = fast_conformer_hybrid_model.transcribe([audio1, audio2], timestamps=True) + + # check len of output + assert len(output) == 2 + + output = output[1] # Transducer returns tuple + + # check hypothesis object + assert isinstance(output[0], Hypothesis) + # check transcript + assert output[0].text == 'Stop?' + assert output[1].text == 'Start.' + + # check timestamp + assert output[0].timestep['segment'][0]['start'] == pytest.approx(0.48) + assert output[0].timestep['segment'][0]['end'] == pytest.approx(0.72) From 5c5b023da44a19ab94a954f2e00cad16e3e807ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Mon, 11 Nov 2024 10:25:13 +0100 Subject: [PATCH 02/24] =?UTF-8?q?[=F0=9F=A4=A0]:=20Howdy=20folks,=20let's?= =?UTF-8?q?=20bump=20`Dockerfile.ci`=20to=201b8fce7=20!=20(#11247)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Dockerfile.ci | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index dddae1a8ec9f..4b259cc953bc 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -54,7 +54,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.19.0 -ARG MCORE_TAG=bc8c4f356240ea4ccadce426251171e6e430c9d3 +ARG MCORE_TAG=1b8fce7e17e7f945c1f59d06744a2e126bedc015 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ From 66766b18709fc5e81fa93f3214472d28da17e4bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Tue, 12 Nov 2024 08:39:09 +0100 Subject: [PATCH 03/24] =?UTF-8?q?[=F0=9F=A4=A0]:=20Howdy=20folks,=20let's?= =?UTF-8?q?=20bump=20`Dockerfile.ci`=20to=2047ff44e=20!=20(#11254)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Dockerfile.ci | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index 4b259cc953bc..ee04c79cd2ba 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -54,7 +54,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.19.0 -ARG MCORE_TAG=1b8fce7e17e7f945c1f59d06744a2e126bedc015 +ARG MCORE_TAG=47ff44e5b98061bf81295ce7df899ee62529d5e3 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ From d32c664e461956eefdaeea8058fafbfb4b7a3dbd Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Tue, 12 Nov 2024 17:45:13 +0100 Subject: [PATCH 04/24] Handling tokenizer in PTQ for Nemo 2.0 (#11237) * Handling tokenizer in PTQ for Nemo 2.0 Signed-off-by: Jan Lasek * Print log msg and enable overriding Signed-off-by: Jan Lasek * Warning for legacy tokenizer config Signed-off-by: Jan Lasek * Save HF tokenizer to make tokenizer_config.yaml (almost) redundant Signed-off-by: Jan Lasek * Handle tokenizer in a unified way Signed-off-by: Jan Lasek * Move saving context within export Signed-off-by: Jan Lasek * Fix typo in get_tokenzier Signed-off-by: Jan Lasek * Reduce diff Signed-off-by: Jan Lasek * Drop unused import Signed-off-by: Jan Lasek --------- Signed-off-by: Jan Lasek --- .../collections/llm/quantization/quantizer.py | 21 +++++++++++-------- nemo/export/tensorrt_llm.py | 15 +++++++++---- .../trt_llm/nemo_ckpt_loader/nemo_file.py | 9 ++++---- nemo/export/trt_llm/qnemo/tokenizer_utils.py | 5 ----- nemo/utils/model_utils.py | 4 ++++ scripts/llm/ptq.py | 2 +- 6 files changed, 33 insertions(+), 23 deletions(-) diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 15367cb25aba..2f3e0e1e986e 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import shutil from dataclasses import dataclass from typing import Optional, Union @@ -22,6 +23,7 @@ from tqdm import tqdm from nemo.collections import llm +from nemo.lightning.ckpt_utils import CONTEXT_PATH from nemo.utils import logging from .utils import get_unwrapped_mcore_model @@ -259,7 +261,7 @@ def loop(model): return loop - def export(self, model: llm.GPTModel) -> None: + def export(self, model: llm.GPTModel, model_dir: str) -> None: assert self.export_config is not None, "Export config is not set" # TODO: Add sample generate # TODO: Support megatron_amp_O2 @@ -277,15 +279,16 @@ def export(self, model: llm.GPTModel) -> None: use_nfs_workspace=use_nfs_workspace, ) - dist.barrier() # Wait until all ranks complete export_model_config step - logging.info(f"Export succeeded, model has been exported to {export_dir}. Saving tokenizer if possible...") + # Save the model context in order to restore its tokenizer later. The destination + # path is "nemo_context" as this name is used in nemo.export to setup tokenizer. + shutil.copytree( + os.path.join(model_dir, CONTEXT_PATH), + os.path.join(export_dir, "nemo_context"), + dirs_exist_ok=True, + ) + logging.info(f"Model context saved.") - if dist.get_rank() == 0: - try: - tokenizer_dst = os.path.join(export_dir, 'tokenizer') - model.tokenizer.tokenizer.save_pretrained(tokenizer_dst) - except Exception as err: - logging.warning("Could not save the tokenizer: " + str(err)) + logging.info(f"Export succeeded, model has been exported to {export_dir}.") def get_calib_data_iter( diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index fb43224d59a9..08b0b822cad4 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -37,12 +37,12 @@ from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import ( build_tokenizer, - get_tokenzier, + get_tokenizer, is_nemo_file, load_nemo_model, ) from nemo.export.trt_llm.qnemo import qnemo_to_tensorrt_llm -from nemo.export.trt_llm.qnemo.tokenizer_utils import get_nmt_tokenizer +from nemo.export.trt_llm.qnemo.tokenizer_utils import TOKENIZER_CONFIG_FILE, get_nmt_tokenizer from nemo.export.trt_llm.qnemo.utils import is_qnemo_checkpoint from nemo.export.trt_llm.tensorrt_llm_build import build_and_save_engine from nemo.export.trt_llm.tensorrt_llm_run import ( @@ -294,7 +294,14 @@ def export( else: unpack_tarball(nemo_checkpoint_path, tmp_dir.name) nemo_checkpoint_path = tmp_dir.name - self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path) + + if os.path.exists(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)): + # Instantiate tokenizer for a legacy "Nemo 1" quantized checkpoint from a tokenizer config. + # Note that using the config is deprecated and it will be removed in future releases. + LOGGER.warning("Detected legacy tokenizer_config.yaml, using it to build tokenizer.") + self.tokenizer = get_nmt_tokenizer(nemo_checkpoint_path) + else: + self.tokenizer = get_tokenizer(nemo_checkpoint_path) qnemo_to_tensorrt_llm( nemo_checkpoint_path=nemo_checkpoint_path, @@ -1092,7 +1099,7 @@ def _load(self): if len(folders) > 0: try: self._load_config_file() - self.tokenizer = get_tokenzier(Path(os.path.join(self.model_dir))) + self.tokenizer = get_tokenizer(self.model_dir) self.model = load( tokenizer=self.tokenizer, engine_dir=self.model_dir, diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 407a7ce600c9..23d227d32acf 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -283,16 +283,17 @@ def copy_tokenizer_files(config, out_dir): outfile.write(infile.read()) -def get_tokenzier(tokenizer_dir_or_path: Path) -> PreTrainedTokenizer: - """Loads the tokenizer from the decoded NEMO weights dir.""" +def get_tokenizer(tokenizer_dir_or_path: Union[str, Path]) -> PreTrainedTokenizer: + """Loads the tokenizer from the decoded NeMo weights dir.""" + tokenizer_dir_or_path = Path(tokenizer_dir_or_path) if (tokenizer_dir_or_path / "nemo_context").exists(): from nemo.lightning import io tokenizer_spec = io.load_context((tokenizer_dir_or_path / "nemo_context"), subpath="model.tokenizer") return build_tokenizer(tokenizer_spec) else: - if os.path.isdir(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer")): - return AutoTokenizer.from_pretrained(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer")) + if (tokenizer_dir_or_path / "huggingface_tokenizer").is_dir(): + return AutoTokenizer.from_pretrained(tokenizer_dir_or_path / "huggingface_tokenizer") model_path = ( tokenizer_dir_or_path / "tokenizer.model" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index 36efa9259f9d..beca40bcd3d7 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -29,11 +29,6 @@ def get_nmt_tokenizer(nemo_checkpoint_path: str): """Build tokenizer from Nemo tokenizer config.""" - tokenizer_dir = os.path.join(nemo_checkpoint_path, TOKENIZER_DIR) - if os.path.exists(tokenizer_dir): - print(f"Initializing tokenizer from {TOKENIZER_DIR} directory") - return AutoTokenizer.from_pretrained(tokenizer_dir) - print(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}") tokenizer_cfg = OmegaConf.load(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)) diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index b417c088b22e..5d7d019c6099 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -724,6 +724,10 @@ def save_artifacts(model, output_dir: str, use_abspath: bool = False) -> None: app_state = AppState() model_file = app_state.model_restore_path model_cfg = copy.deepcopy(model.cfg) + + if model_cfg.tokenizer.library == "huggingface": + model.tokenizer.save_pretrained(os.path.join(output_dir, "huggingface_tokenizer")) + if not hasattr(model, "artifacts"): if hasattr(model_cfg, "tokenizer"): OmegaConf.save(model_cfg.tokenizer, os.path.join(output_dir, "tokenizer_config.yaml")) diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py index 0fd2c5682e8a..c04d32290e5f 100644 --- a/scripts/llm/ptq.py +++ b/scripts/llm/ptq.py @@ -92,7 +92,7 @@ def main(): quantizer = quantization.Quantizer(quantization_config, export_config) model = quantization.load_with_modelopt_layer_spec(args.nemo_checkpoint, args.calib_tp, args.calib_pp) model = quantizer.quantize(model) - quantizer.export(model) + quantizer.export(model, args.nemo_checkpoint) if __name__ == '__main__': From 34c303284b64c3fe7131b0a7b7fa917f1b758f2b Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Tue, 12 Nov 2024 11:51:44 -0500 Subject: [PATCH 05/24] Fix finetuning datamodule resume (#11187) * fix datamodule resume Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * fix subclass Signed-off-by: Chen Cui * docstrings and formats Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx --------- Signed-off-by: Chen Cui Signed-off-by: cuichenx Co-authored-by: cuichenx --- examples/llm/sft/hf.py | 2 +- nemo/collections/llm/gpt/data/fine_tuning.py | 98 +++++++++++++++++--- 2 files changed, 84 insertions(+), 16 deletions(-) diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py index 39efe87de368..7d4cde7866a2 100644 --- a/examples/llm/sft/hf.py +++ b/examples/llm/sft/hf.py @@ -22,7 +22,7 @@ class SquadDataModuleWithPthDataloader(llm.SquadDataModule): - def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader: return DataLoader( dataset, num_workers=self.num_workers, diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py index d7ed08a01ed4..9d16ea8aa021 100644 --- a/nemo/collections/llm/gpt/data/fine_tuning.py +++ b/nemo/collections/llm/gpt/data/fine_tuning.py @@ -22,6 +22,7 @@ from nemo.collections.common.tokenizers import AutoTokenizer from nemo.collections.llm.gpt.data.core import create_sft_dataset +from nemo.lightning.data import WrappedDataLoader from nemo.lightning.pytorch.plugins import MegatronDataSampler from nemo.utils import logging @@ -34,22 +35,26 @@ class FineTuningDataModule(pl.LightningDataModule): """Base class for fine-tuning an LLM. This class provides a foundation for building custom data modules for fine-tuning Nemo NLP models. It inherits from - `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch creation - for training, validation, and testing. + `pl.LightningDataModule` from the PyTorch Lightning library and handles data loading, preprocessing, and batch + creation for training, validation, and testing. Args: dataset_root (Union[str, Path]): The root directory containing the training, validation, and test data. seq_length (int, optional): The maximum sequence length for the input and output text. Defaults to 2048. - tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text. Defaults to None. + tokenizer (Optional[TokenizerSpec], optional): The tokenizer to use for preprocessing the text. If not provided, a Megatron GPT2 BPE tokenizer will be used. micro_batch_size (int, optional): The micro batch size for training. Defaults to 4. global_batch_size (int, optional): The global batch size for training. Defaults to 8. - rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training. Defaults to None. + rampup_batch_size (Optional[List[int]], optional): A list of batch sizes for ramping up during training. + Defaults to None. seed (int, optional): The random seed for data shuffling. Defaults to 1234. - memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset. Defaults to 1. + memmap_workers (int, optional): The number of worker processes for loading data using TextMemMapDataset. + Defaults to 1. num_workers (int, optional): The number of worker processes for data loading. Defaults to 8. - pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training. Defaults to True. - persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. Defaults to False. + pin_memory (bool, optional): Whether to pin memory during data loading for faster GPU training. + Defaults to True. + persistent_workers (bool, optional): Whether to keep data loading workers persistent across epochs. + Defaults to False. packed_sequence_specs (PackedSequenceSpecs, optional): See PackedSequenceSpecs for details dataset_kwargs (Optional[Dict[str, Any]], optional): Keyword arguments to pass into the GPTSFTDataset class """ @@ -90,18 +95,28 @@ def __init__( self.dataset_kwargs = dataset_kwargs or {} def validate_batch_size_for_packed_sequence(self): + """ + Validate that micro batch size must be 1 when using packed sequence. + """ if self.packed_sequence_size > 0 and self.micro_batch_size > 1: raise ValueError( "Micro batch size should be 1 when training with packed sequence, but your micro batch size " f"is {self.micro_batch_size}. \nThe following config is equivalent to your current setting for " f"a packed dataset. Please update your config to the following: \n" f"Set micro batch size to 1 (currently {self.micro_batch_size})\n" - f"Set global batch size to {self.global_batch_size // self.micro_batch_size} (currently {self.global_batch_size}) \n" - f"Set packed sequence length to {self.packed_sequence_size*self.micro_batch_size} (currently {self.packed_sequence_size}) \n" - f"For details please visit https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/sequence_packing.html" + f"Set global batch size to {self.global_batch_size // self.micro_batch_size} " + f"(currently {self.global_batch_size}) \n" + f"Set packed sequence length to {self.packed_sequence_size*self.micro_batch_size} " + f"(currently {self.packed_sequence_size}) \n" + f"For details please visit " + f"https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/optimizations/" + f"sequence_packing.html" ) def prepare_data(self) -> None: + """ + Prepare packed sequence data + """ if self.packed_sequence_size > 0 and not self.train_path_packed.is_file(): from nemo.collections.llm.gpt.data.packed_sequence import prepare_packed_sequence_data @@ -115,6 +130,9 @@ def prepare_data(self) -> None: ) def setup(self, stage: str): + """Called by pytorch lightning in datamodule setup""" + + # data_sampler is used in `setup_data_sampler` in MegatronStrategy.setup self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, micro_batch_size=self.micro_batch_size, @@ -127,36 +145,78 @@ def setup(self, stage: str): # base_dataset_utils.get_datasets_weights_and_num_samples self.max_train_samples = int(math.ceil(self.global_batch_size * self.trainer.max_steps * 1.005)) + def state_dict(self) -> Dict[str, Any]: + """Called when saving a checkpoint, implement to generate and save datamodule state. + + Returns: + A dictionary containing datamodule state. + + """ + consumed_samples = self.data_sampler.compute_consumed_samples( + self.trainer.global_step - self.data_sampler.init_global_step + ) + return {"consumed_samples": consumed_samples} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat + + Args: + state_dict: the datamodule state returned by ``state_dict``. + + """ + try: + from megatron.core.num_microbatches_calculator import update_num_microbatches + + except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import update_num_microbatches + consumed_samples = state_dict["consumed_samples"] + self.data_sampler.init_consumed_samples = consumed_samples + self.data_sampler.prev_consumed_samples = consumed_samples + + update_num_microbatches( + consumed_samples=consumed_samples, + consistency_check=False, + ) + self.data_sampler.if_first_step = 1 + def train_dataloader(self) -> DataLoader: + # pylint: disable=C0115,C0116 return self._create_dataloader( self._create_dataset( self.train_path if self.packed_sequence_size <= 0 else self.train_path_packed, max_num_samples=self.max_train_samples, **self.dataset_kwargs, - ) + ), + mode="train", ) def val_dataloader(self) -> DataLoader: + # pylint: disable=C0115,C0116 return self._create_dataloader( self._create_dataset( self.validation_path, is_test=True, **self.dataset_kwargs, ), + mode="validation", ) def test_dataloader(self) -> DataLoader: + # pylint: disable=C0115,C0116 return self._create_dataloader( self._create_dataset( self.test_path, tokens_to_generate=32, is_test=True, **self.dataset_kwargs, - ) + ), + mode="test", ) @lru_cache def _create_dataset(self, path, is_test=False, **kwargs): + # pylint: disable=C0115,C0116 return create_sft_dataset( path, tokenizer=self.tokenizer, @@ -167,9 +227,11 @@ def _create_dataset(self, path, is_test=False, **kwargs): **kwargs, ) - def _create_dataloader(self, dataset, **kwargs) -> DataLoader: - return DataLoader( - dataset, + def _create_dataloader(self, dataset, mode, **kwargs) -> DataLoader: + # pylint: disable=C0115,C0116 + return WrappedDataLoader( + mode=mode, + dataset=dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, @@ -179,10 +241,13 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader: @property def train_path(self) -> Path: + """Path to training dataset file""" return self.dataset_root / "training.jsonl" @property def train_path_packed(self) -> Path: + """Path to training dataset file for packed sequence. The file path contains a reference to the + tokenizer/model name since packed sequence dataset consists of tokenized indices.""" if self.packed_sequence_size > 0: if self.packed_sequence_specs.packed_data_path is not None: return self.packed_sequence_specs.packed_data_path @@ -195,13 +260,16 @@ def train_path_packed(self) -> Path: @property def validation_path(self) -> Path: + """Path to validation dataset file""" return self.dataset_root / "validation.jsonl" @property def test_path(self) -> Path: + """Path to test dataset file""" return self.dataset_root / "test.jsonl" def _extract_tokenizer_model_name(self) -> str: + """Automatically get the model name from model path.""" if self.packed_sequence_specs.tokenizer_model_name is not None: tokenizer_model_name = self.packed_sequence_specs.tokenizer_model_name elif isinstance(self.tokenizer, AutoTokenizer): From d363e5dff6e888ce79f4aa7a7cd1f4cf51e6b0cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Tue, 12 Nov 2024 18:01:11 +0100 Subject: [PATCH 06/24] ci: Move `bump mcore` to templates (#11229) * ci: Move `bump mcore` to templates Signed-off-by: Oliver Koenig * fix Signed-off-by: Oliver Koenig * fix Signed-off-by: Oliver Koenig * fix Signed-off-by: Oliver Koenig * final Signed-off-by: Oliver Koenig --------- Signed-off-by: Oliver Koenig --- .github/workflows/mcore-tag-bump-bot.yml | 63 +++++------------------- 1 file changed, 12 insertions(+), 51 deletions(-) diff --git a/.github/workflows/mcore-tag-bump-bot.yml b/.github/workflows/mcore-tag-bump-bot.yml index 13f4059a3a6b..1b0712924101 100644 --- a/.github/workflows/mcore-tag-bump-bot.yml +++ b/.github/workflows/mcore-tag-bump-bot.yml @@ -6,54 +6,15 @@ on: - cron: 0 0 * * * jobs: - main: - runs-on: ubuntu-latest - environment: main - steps: - - name: Checkout NVIDIA/Megatron-LM - uses: actions/checkout@v4 - with: - repository: NVIDIA/Megatron-LM - ref: main - path: ${{ github.run_id }} - - - name: Get latest mcore commit - id: ref - run: | - cd ${{ github.run_id }} - sha=$(git rev-parse origin/main) - echo "sha=${sha}" >> "$GITHUB_OUTPUT" - echo "short_sha=${sha:0:7}" >> "$GITHUB_OUTPUT" - echo "date=$(date +%F)" >> "$GITHUB_OUTPUT" - - - name: Checkout ${{ github.repository }} - uses: actions/checkout@v4 - with: - path: ${{ github.run_id }} - token: ${{ secrets.PAT }} - - - name: Bump MCORE_TAG - run: | - cd ${{ github.run_id }} - sed -i 's/^ARG MCORE_TAG=.*$/ARG MCORE_TAG=${{ steps.ref.outputs.sha }}/' Dockerfile.ci - - - name: Create Bump PR - uses: peter-evans/create-pull-request@v6 - id: create-pull-request - with: - path: ${{ github.run_id }} - branch: bump-ci-container-${{ steps.ref.outputs.date }} - base: main - title: 'Bump `Dockerfile.ci` (${{ steps.ref.outputs.date }})' - token: ${{ secrets.PAT }} - body: | - 🚀 PR to Bump `Dockerfile.ci`. - - 📝 Please remember the following to-do's before merge: - - [ ] Verify the presubmit CI - - 🙏 Please merge this PR only if the CI workflow completed successfully. - commit-message: "[🤠]: Howdy folks, let's bump `Dockerfile.ci` to ${{ steps.ref.outputs.short_sha }} !" - signoff: true - reviewers: 'pablo-garay' - labels: 'Run CICD' + mcore: + uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_bump_dockerfile.yml@v0.11.0 + with: + source-repository: NVIDIA/Megatron-LM + source-ref: main + build-arg: MCORE_TAG + dockerfile: Dockerfile.ci + base-branch: main + cicd-label: Run CICD + pr-reviewers: 'pablo-garay' + secrets: + PAT: ${{ secrets.PAT }} \ No newline at end of file From 77c8e91d90ab4808ba8fa380e31e666d8e29ee75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Tue, 12 Nov 2024 18:17:06 +0100 Subject: [PATCH 07/24] fix: Update baseline (#11205) Signed-off-by: Oliver Koenig --- .github/workflows/secrets-detector.yml | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/.github/workflows/secrets-detector.yml b/.github/workflows/secrets-detector.yml index cf8ccc189ab6..d81b5638e31f 100644 --- a/.github/workflows/secrets-detector.yml +++ b/.github/workflows/secrets-detector.yml @@ -25,13 +25,24 @@ jobs: - name: Checkout repository uses: actions/checkout@v4 with: - path: ${{ github.run_id }} + # setup repository and ref for PRs, see + # https://github.com/EndBug/add-and-commit?tab=readme-ov-file#working-with-prs + repository: ${{ github.event.pull_request.head.repo.full_name }} + ref: ${{ github.event.pull_request.head.ref }} + # custom token is required to trigger actions after reformatting + pushing fetch-depth: 0 + token: ${{ secrets.NEMO_REFORMAT_TOKEN }} - name: Install secrets detector run: pip install detect-secrets - name: Run on change-set run: | - cd ${{ github.run_id }} - git diff --name-only --diff-filter=d --merge-base origin/main -z | xargs -0 detect-secrets-hook --baseline .secrets.baseline \ No newline at end of file + git diff --name-only --diff-filter=d --merge-base origin/main -z | xargs -0 detect-secrets-hook --baseline .secrets.baseline + + - uses: EndBug/add-and-commit@v9 + # Commit changes. Nothing is committed if no changes. + if: always() + with: + message: Update baseline + commit: --signoff From b26c220d636c65230d31e2df1357e405c455d913 Mon Sep 17 00:00:00 2001 From: Jan Lasek Date: Tue, 12 Nov 2024 18:56:20 +0100 Subject: [PATCH 08/24] Remove deprecated builder_opt param from build command (#11259) Signed-off-by: Jan Lasek --- nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py index 7a1f7a6cc31d..eac1ab743849 100644 --- a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py +++ b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py @@ -77,8 +77,6 @@ def qnemo_to_tensorrt_llm( use_qdq = quant_algo in ["FP8", "W8A8_SQ_PER_CHANNEL"] - builder_opt = 4 if "RecurrentGemma" not in config.architecture else 0 - speculative_decoding_mode = "medusa" if "Medusa" in config.architecture else None build_cmd = "trtllm-build " @@ -90,7 +88,6 @@ def qnemo_to_tensorrt_llm( build_cmd += f"--max_input_len {max_input_len} " build_cmd += f"--max_beam_width {max_beam_width} " build_cmd += f"--max_prompt_embedding_table_size {max_prompt_embedding_table_size} " - build_cmd += f"--builder_opt {builder_opt} " build_cmd += f"--paged_kv_cache {'enable' if paged_kv_cache else 'disable'} " build_cmd += f"--use_paged_context_fmha {'enable' if paged_context_fmha else 'disable'} " build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} " From 098aa18ab99dfc8076cae9f612b9bde3c16ff365 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Tue, 12 Nov 2024 20:41:59 +0100 Subject: [PATCH 09/24] =?UTF-8?q?chore(beep=20boop=20=F0=9F=A4=96):=20Bump?= =?UTF-8?q?=20`MCORE=5FTAG=3Daded519...`=20(2024-11-12)=20(#11260)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Dockerfile.ci | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index ee04c79cd2ba..5858f0aadf5b 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -54,7 +54,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.19.0 -ARG MCORE_TAG=47ff44e5b98061bf81295ce7df899ee62529d5e3 +ARG MCORE_TAG=aded519cfb1de2abf96f36ca059f992294b7876f ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ From 5670706f9549cc944eff45a97b259ca988d3cd7f Mon Sep 17 00:00:00 2001 From: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> Date: Tue, 12 Nov 2024 11:55:51 -0800 Subject: [PATCH 10/24] [Doc fixes] update file names, installation instructions, bad links (#11045) * rename eval_beamsearch_ngram.py to eval_beamsearch_ngram_ctc.py in docs Signed-off-by: Elena Rastorgueva * replace out of date installation instructions with pointer to NeMo README installation section Signed-off-by: Elena Rastorgueva * point to user guide instead of readme Signed-off-by: Elena Rastorgueva * some link updates Signed-off-by: Elena Rastorgueva * update more links Signed-off-by: Elena Rastorgueva --------- Signed-off-by: Elena Rastorgueva Signed-off-by: Elena Rastorgueva <80532067+erastorgueva-nv@users.noreply.github.com> --- ...sr_language_modeling_and_customization.rst | 34 +++---- docs/source/asr/intro.rst | 4 +- docs/source/core/core.rst | 4 +- docs/source/multimodal/mllm/configs.rst | 4 +- docs/source/multimodal/text2img/imagen.rst | 2 +- docs/source/multimodal/vlm/configs.rst | 2 +- docs/source/multimodal/vlm/datasets.rst | 2 +- .../nn_text_normalization.rst | 2 +- .../wfst/wfst_customization.rst | 4 +- docs/source/starthere/intro.rst | 89 +------------------ docs/source/tools/speech_data_explorer.rst | 10 --- docs/source/tts/datasets.rst | 2 +- docs/source/vision/configs.rst | 3 +- 13 files changed, 32 insertions(+), 130 deletions(-) diff --git a/docs/source/asr/asr_language_modeling_and_customization.rst b/docs/source/asr/asr_language_modeling_and_customization.rst index 0b4f7a7e730f..a9d42772698c 100644 --- a/docs/source/asr/asr_language_modeling_and_customization.rst +++ b/docs/source/asr/asr_language_modeling_and_customization.rst @@ -99,15 +99,15 @@ Evaluate by Beam Search Decoding and N-gram LM NeMo's beam search decoders are capable of using the KenLM's N-gram models to find the best candidates. The script to evaluate an ASR model with beam search decoding and N-gram models can be found at -`scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py `__. +`scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py `__. -This script has a large number of possible argument overrides; therefore, it is recommended that you use ``python eval_beamsearch_ngram.py --help`` to see the full list of arguments. +This script has a large number of possible argument overrides; therefore, it is recommended that you use ``python eval_beamsearch_ngram_ctc.py --help`` to see the full list of arguments. You can evaluate an ASR model using the following: .. code-block:: - python eval_beamsearch_ngram.py nemo_model_file= \ + python eval_beamsearch_ngram_ctc.py nemo_model_file= \ input_manifest= \ beam_width=[] \ @@ -118,18 +118,18 @@ You can evaluate an ASR model using the following: decoding_mode=beamsearch_ngram \ decoding_strategy="" -It can evaluate a model in the following three modes by setting the argument `--decoding_mode`: +It can evaluate a model in the following three modes by setting the argument ``--decoding_mode``: * greedy: Just greedy decoding is done and no beam search decoding is performed. * beamsearch: The beam search decoding is done, but without using the N-gram language model. Final results are equivalent to setting the weight of LM (beam_beta) to zero. * beamsearch_ngram: The beam search decoding is done with N-gram LM. -In `beamsearch` mode, the evaluation is performed using beam search decoding without any language model. The performance is reported in terms of Word Error Rate (WER) and Character Error Rate (CER). Moreover, when the best candidate is selected among the candidates, it is also reported as the best WER/CER. This can serve as an indicator of the quality of the predicted candidates. +In ``beamsearch`` mode, the evaluation is performed using beam search decoding without any language model. The performance is reported in terms of Word Error Rate (WER) and Character Error Rate (CER). Moreover, when the best candidate is selected among the candidates, it is also reported as the best WER/CER. This can serve as an indicator of the quality of the predicted candidates. The script initially loads the ASR model and predicts the outputs of the model's encoder as log probabilities. This part is computed in batches on a device specified by --device, which can be either a CPU (`--device=cpu`) or a single GPU (`--device=cuda:0`). -The batch size for this part is specified by `--acoustic_batch_size`. Using the largest feasible batch size can speed up the calculation of log probabilities. Additionally, you can use `--use_amp` to accelerate the calculation and allow for larger --acoustic_batch_size values. -Currently, multi-GPU support is not available for calculating log probabilities. However, using `--probs_cache_file` can help. This option stores the log probabilities produced by the model’s encoder in a pickle file, allowing you to skip the first step in future runs. +The batch size for this part is specified by ``--acoustic_batch_size``. Using the largest feasible batch size can speed up the calculation of log probabilities. Additionally, you can use `--use_amp` to accelerate the calculation and allow for larger --acoustic_batch_size values. +Currently, multi-GPU support is not available for calculating log probabilities. However, using ``--probs_cache_file`` can help. This option stores the log probabilities produced by the model’s encoder in a pickle file, allowing you to skip the first step in future runs. The following is the list of the important arguments for the evaluation script: @@ -167,7 +167,7 @@ The following is the list of the important arguments for the evaluation script: | decoding_strategy | str | beam | String argument for type of decoding strategy for the model. | +--------------------------------------+----------+------------------+-------------------------------------------------------------------------+ | decoding | Dict | BeamCTC | Subdict of beam search configs. Values found via | -| | Config | InferConfig | python eval_beamsearch_ngram.py --help | +| | Config | InferConfig | python eval_beamsearch_ngram_ctc.py --help | +--------------------------------------+----------+------------------+-------------------------------------------------------------------------+ | text_processing.do_lowercase | bool | ``False`` | Whether to make the training text all lower case. | +--------------------------------------+----------+------------------+-------------------------------------------------------------------------+ @@ -178,11 +178,11 @@ The following is the list of the important arguments for the evaluation script: | text_processing.separate_punctuation | bool | ``True`` | Whether to separate punctuation with the previous word by space. | +--------------------------------------+----------+------------------+-------------------------------------------------------------------------+ -The width of the beam search (`--beam_width`) specifies the number of top candidates or predictions the beam search decoder will consider. Larger beam widths result in more accurate but slower predictions. +The width of the beam search (``--beam_width``) specifies the number of top candidates or predictions the beam search decoder will consider. Larger beam widths result in more accurate but slower predictions. .. note:: - The ``eval_beamsearch_ngram.py`` script contains the entire subconfig used for CTC Beam Decoding. + The ``eval_beamsearch_ngram_ctc.py`` script contains the entire subconfig used for CTC Beam Decoding. Therefore it is possible to forward arguments for various beam search libraries such as ``flashlight`` and ``pyctcdecode`` via the ``decoding`` subconfig. @@ -223,14 +223,14 @@ It supports several advanced features, such as lexicon-based decoding, lexicon-f .. code-block:: # Lexicon-based decoding - python eval_beamsearch_ngram.py ... \ + python eval_beamsearch_ngram_ctc.py ... \ decoding_strategy="flashlight" \ decoding.beam.flashlight_cfg.lexicon_path='/path/to/lexicon.lexicon' \ decoding.beam.flashlight_cfg.beam_size_token = 32 \ decoding.beam.flashlight_cfg.beam_threshold = 25.0 # Lexicon-free decoding - python eval_beamsearch_ngram.py ... \ + python eval_beamsearch_ngram_ctc.py ... \ decoding_strategy="flashlight" \ decoding.beam.flashlight_cfg.beam_size_token = 32 \ decoding.beam.flashlight_cfg.beam_threshold = 25.0 @@ -256,7 +256,7 @@ It has advanced features, such as word boosting, which can be useful for transcr .. code-block:: # PyCTCDecoding - python eval_beamsearch_ngram.py ... \ + python eval_beamsearch_ngram_ctc.py ... \ decoding_strategy="pyctcdecode" \ decoding.beam.pyctcdecode_cfg.beam_prune_logp = -10. \ decoding.beam.pyctcdecode_cfg.token_min_logp = -5. \ @@ -273,7 +273,7 @@ For example, the following set of parameters would result in 212=4 beam search d .. code-block:: - python eval_beamsearch_ngram.py ... \ + python eval_beamsearch_ngram_ctc.py ... \ beam_width=[64,128] \ beam_alpha=[1.0] \ beam_beta=[1.0,0.5] @@ -330,7 +330,7 @@ Given a trained TransformerLMModel `.nemo` file or a pretrained HF model, the sc can be used to re-score beams obtained with ASR model. You need the `.tsv` file containing the candidates produced by the acoustic model and the beam search decoding to use this script. The candidates can be the result of just the beam search decoding or the result of fusion with an N-gram LM. You can generate this file by specifying `--preds_output_folder` for -`scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py `__. +`scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py `__. The neural rescorer would rescore the beams/candidates by using two parameters of `rescorer_alpha` and `rescorer_beta`, as follows: @@ -345,7 +345,7 @@ Use the following steps to evaluate a neural LM: #. Obtain `.tsv` file with beams and their corresponding scores. Scores can be from a regular beam search decoder or in fusion with an N-gram LM scores. For a given beam size `beam_size` and a number of examples for evaluation `num_eval_examples`, it should contain (`num_eval_examples` x `beam_size`) lines of - form `beam_candidate_text \t score`. This file can be generated by `scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram.py `__ + form `beam_candidate_text \t score`. This file can be generated by `scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py `__ #. Rescore the candidates by `scripts/asr_language_modeling/neural_rescorer/eval_neural_rescorer.py `__. @@ -439,7 +439,7 @@ You can then pass this file to your Flashlight config object during decoding: .. code-block:: # Lexicon-based decoding - python eval_beamsearch_ngram.py ... \ + python eval_beamsearch_ngram_ctc.py ... \ decoding_strategy="flashlight" \ decoding.beam.flashlight_cfg.lexicon_path='/path/to/lexicon.lexicon' \ decoding.beam.flashlight_cfg.boost_path='/path/to/my_boost_file.boost' \ diff --git a/docs/source/asr/intro.rst b/docs/source/asr/intro.rst index ade767e541a0..7303d1698422 100644 --- a/docs/source/asr/intro.rst +++ b/docs/source/asr/intro.rst @@ -127,8 +127,8 @@ You can get a good improvement in transcription accuracy even using a simple N-g After :ref:`training ` an N-gram LM, you can use it for transcribing audio as follows: -1. Install the OpenSeq2Seq beam search decoding and KenLM libraries using the `install_beamsearch_decoders script `_. -2. Perform transcription using the `eval_beamsearch_ngram script `_: +1. Install the OpenSeq2Seq beam search decoding and KenLM libraries using the `install_beamsearch_decoders script `_. +2. Perform transcription using the `eval_beamsearch_ngram script `_: .. code-block:: bash diff --git a/docs/source/core/core.rst b/docs/source/core/core.rst index 6bdd18559902..94706b639b5f 100644 --- a/docs/source/core/core.rst +++ b/docs/source/core/core.rst @@ -294,8 +294,8 @@ CLI With NeMo and Hydra, every aspect of model training can be modified from the command-line. This is extremely helpful for running lots of experiments on compute clusters or for quickly testing parameters during development. -All NeMo `examples `_ come with instructions on how to -run the training/inference script from the command-line (see `here `__ +All NeMo `examples `_ come with instructions on how to +run the training/inference script from the command-line (e.g. see `here `__ for an example). With Hydra, arguments are set using the ``=`` operator: diff --git a/docs/source/multimodal/mllm/configs.rst b/docs/source/multimodal/mllm/configs.rst index 6e9f9b2b8d10..53b851867f65 100644 --- a/docs/source/multimodal/mllm/configs.rst +++ b/docs/source/multimodal/mllm/configs.rst @@ -5,14 +5,14 @@ This section provides a detailed overview of the NeMo configuration file setup s Within the configuration files of the NeMo Multimodal Language Model, details concerning dataset(s), augmentation, optimization parameters, and model architectural specifications are central. This page explores each of these aspects. -Discover exemplary configuration files for all NeMo Multimodal Language Model scripts in the `config directory of the examples `_. +Discover exemplary configuration files for all NeMo Multimodal Language Model scripts in the `config directory of the examples `_. Dataset Configuration --------------------- The NeMo multimodal language model currently supports a conversation data format, inspired by and designed from https://github.com/haotian-liu/LLaVA/tree/main. To explore a sample dataset, visit https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md. -The configuration file allows setting any initialization parameter accepted by the Dataset class used in the experiment. For a comprehensive list of Datasets and their parameters, visit the `Datasets <./api.html#Datasets>`__ section of the API. +The configuration file allows setting any initialization parameter accepted by the Dataset class used in the experiment. For a comprehensive list of Datasets and their parameters, visit the :doc:`Datasets <./datasets>` section of the API. A typical training configuration is as follows: diff --git a/docs/source/multimodal/text2img/imagen.rst b/docs/source/multimodal/text2img/imagen.rst index 844f68df747f..3134ffdc2747 100644 --- a/docs/source/multimodal/text2img/imagen.rst +++ b/docs/source/multimodal/text2img/imagen.rst @@ -31,7 +31,7 @@ Imagen has two types of UNet: Regular UNet and EfficientUNet. Regular UNet ~~~~~~~~~~~~ Regular UNet is used for Imagen base64 model. You can also use regular UNet for SR models -(see example config file `sr256-400m-edm.yaml `_), but this typically +(see example config file `sr256-400m-edm.yaml `__), but this typically results in a larger memory footprint during training for the same model size. Recommended UNet size for base64 and SR256 models are listed below: diff --git a/docs/source/multimodal/vlm/configs.rst b/docs/source/multimodal/vlm/configs.rst index cc383cb64b62..711831121bd7 100644 --- a/docs/source/multimodal/vlm/configs.rst +++ b/docs/source/multimodal/vlm/configs.rst @@ -5,7 +5,7 @@ This section provides a detailed overview of the NeMo configuration file setup s Within the configuration files of the NeMo Multimodal Language Model, details concerning dataset(s), augmentation, optimization parameters, and model architectural specifications are central. This page explores each of these aspects. -Discover exemplary configuration files for all NeMo Multimodal Language Model scripts in the `config directory of the examples `_. +Discover exemplary configuration files for all NeMo Multimodal Language Model scripts in the `config directories of the examples `__. Dataset Configuration ===================== diff --git a/docs/source/multimodal/vlm/datasets.rst b/docs/source/multimodal/vlm/datasets.rst index 057c79109b08..0c32210d8b6f 100644 --- a/docs/source/multimodal/vlm/datasets.rst +++ b/docs/source/multimodal/vlm/datasets.rst @@ -32,4 +32,4 @@ For webdatasets already downloaded locally, sub-stages 4-6 can be used to precac For models that encode image and text on-the-fly, only sub-stages 1-3 need to be run. Instruction for configuring each sub-stage is provided as a comment next to each field in -`download_multimodal.yaml `_ +`download_multimodal.yaml `__. diff --git a/docs/source/nlp/text_normalization/nn_text_normalization.rst b/docs/source/nlp/text_normalization/nn_text_normalization.rst index 87530dbcbc29..d4c172a4fab0 100644 --- a/docs/source/nlp/text_normalization/nn_text_normalization.rst +++ b/docs/source/nlp/text_normalization/nn_text_normalization.rst @@ -87,7 +87,7 @@ Data upsampling --------------- Data upsampling is an effective way to increase the training data for better model performance, especially on the long tail of semiotic tokens. -We used upsampling for training an English text normalization model, see `data/en/upsampling.py `__. +We used upsampling for training an English text normalization model, see `data/en/upsampling.py `__. Currently this script only upsamples a few classes, that are diverse in semiotic tokens but at the same time underrepresented in the training data. Of all the input files in `train` folder created by `data/data_split.py `__. this script takes the first file and detects the class patterns that occur in it. For those that are underrepresented, quantitatively defined as lower than `min_number`, the other files are scanned for sentences that have the missing patterns. diff --git a/docs/source/nlp/text_normalization/wfst/wfst_customization.rst b/docs/source/nlp/text_normalization/wfst/wfst_customization.rst index a199c1fb09d0..4af157489480 100644 --- a/docs/source/nlp/text_normalization/wfst/wfst_customization.rst +++ b/docs/source/nlp/text_normalization/wfst/wfst_customization.rst @@ -38,5 +38,5 @@ WFST TN/ITN resources could be found in :doc:`here `. Riva resources -------------- - - `Riva Text Normalization customization for TTS `_. - - `Riva ASR/Inverse Text Normalization customization `_. \ No newline at end of file + - `Riva Text Normalization customization for TTS `_. + - `Riva ASR/Inverse Text Normalization customization `_. \ No newline at end of file diff --git a/docs/source/starthere/intro.rst b/docs/source/starthere/intro.rst index 0cf7146ff1ef..41c10dd5c6ea 100644 --- a/docs/source/starthere/intro.rst +++ b/docs/source/starthere/intro.rst @@ -32,95 +32,8 @@ Before using NeMo, make sure you meet the following prerequisites: Installation ------------ -**Using NVIDIA PyTorch Container** +Refer to the NeMo Framework `User Guide `__ for the latest installation instructions. -To leverage all optimizations for LLM training, including 3D Model Parallel, fused kernels, FP8, and more, we recommend using the NVIDIA PyTorch container. - -.. code-block:: bash - - docker pull nvcr.io/nvidia/pytorch:24.01-py3 - docker run --gpus all -it nvcr.io/nvidia/pytorch:24.01-py3 - -Within the container, you can install NeMo and its dependencies as follows: - -NeMo Installation - -.. code-block:: bash - - apt-get update && apt-get install -y libsndfile1 ffmpeg - pip install Cython - pip install nemo_toolkit['all'] - -Transformer Engine Installation - -This step involves cloning the Transformer Engine repository, checking out a specific commit, and installing it with specific flags. - -.. code-block:: bash - - git clone https://github.com/NVIDIA/TransformerEngine.git && \ - cd TransformerEngine && \ - git fetch origin 8c9abbb80dba196f086b8b602a7cf1bce0040a6a && \ - git checkout FETCH_HEAD && \ - git submodule init && git submodule update && \ - NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . - -Apex Installation - -This step includes a bug fix for Apex in the PyTorch 23.11 container. - -.. code-block:: bash - - git clone https://github.com/NVIDIA/apex.git && \ - cd apex && \ - git checkout c07a4cf67102b9cd3f97d1ba36690f985bae4227 && \ - cp -R apex /usr/local/lib/python3.10/dist-packages - -PyTorch Lightning Installation - -This step involves installing a bug-fixed version of PyTorch Lightning from a specific branch. - -.. code-block:: bash - - git clone -b bug_fix https://github.com/athitten/pytorch-lightning.git && \ - cd pytorch-lightning && \ - PACKAGE_NAME=pytorch pip install -e . - -Megatron Core Installation - -This section details the steps to clone and install the Megatron Core. - -.. code-block:: bash - - git clone https://github.com/NVIDIA/Megatron-LM.git && \ - cd Megatron-LM && \ - git checkout a5415fcfacef2a37416259bd38b7c4b673583675 && \ - pip install . - -TensorRT Model Optimizer Installation - -This final step involves installing the TensorRT Model Optimizer package. - -.. code-block:: bash - - pip install nvidia-modelopt[torch]~=0.19.0 --extra-index-url https://pypi.nvidia.com - - -.. code-block:: bash - - apt-get update && apt-get install -y libsndfile1 ffmpeg - pip install Cython - pip install nemo_toolkit['all'] - -**Conda Installation** - -If you do not use the NVIDIA PyTorch container, we recommend installing NeMo in a clean Conda environment. - -.. code-block:: bash - - conda create --name nemo python==3.10.12 - conda activate nemo - -Refer to the PyTorch configurator for instructions on installing PyTorch. `configurator `_ Quick Start Guide ----------------- diff --git a/docs/source/tools/speech_data_explorer.rst b/docs/source/tools/speech_data_explorer.rst index a57cb442f468..ac13f3936746 100644 --- a/docs/source/tools/speech_data_explorer.rst +++ b/docs/source/tools/speech_data_explorer.rst @@ -18,16 +18,6 @@ Speech Data Explorer (SDE) is a `Dash `__-based web ap | estimation of audio signal parameters [peak level, frequency bandwidth] | +--------------------------------------------------------------------------------------------------------------------------+ -SDE Demo Instance ------------------ - -To demonstrate both the :doc:`CTC-Segmentation <./ctc_segmentation>` and Speech Data Explorer tools, we re-segmenting the development set as of `the LibriSpeech corpus `_. -We concatenated all audio files from the dev-clean split into a single file and set up the CTC-Segmentation tool to cut the long audio file into original utterances. -We used the CTC-based `QuartzNet15x5Base-En ASR model `_. -The segmented corpus has 3.82% WER and contains 300 out of the initial 323 minutes of audio. -The remaining 23 minutes are the silence at the beginning and end of the audio removed during the segmentation. -A `running instance of the SDE `_ demonstrates the re-segmented corpus. - Getting Started --------------- SDE could be found in `NeMo/tools/speech_data_explorer `__. diff --git a/docs/source/tts/datasets.rst b/docs/source/tts/datasets.rst index 7efe116dcccc..e37c2176c41a 100644 --- a/docs/source/tts/datasets.rst +++ b/docs/source/tts/datasets.rst @@ -141,7 +141,7 @@ There are two German neutral datasets released by Thorsten Müller for now, 21.0 HUI Audio Corpus German ~~~~~~~~~~~~~~~~~~~~~~~ -* Dataset URL: https://opendata.iisys.de/datasets.html +* Dataset URL: https://github.com/iisys-hof/HUI-Audio-Corpus-German * Dataset Processing Script: https://github.com/NVIDIA/NeMo/tree/stable/scripts/dataset_processing/tts/hui_acg/get_data.py * Command Line Instruction: diff --git a/docs/source/vision/configs.rst b/docs/source/vision/configs.rst index 92b7e5b45d12..1064f9569c3d 100644 --- a/docs/source/vision/configs.rst +++ b/docs/source/vision/configs.rst @@ -4,7 +4,7 @@ This section provides a detailed overview of the NeMo configuration file setup s Within the configuration files of the NeMo vision models, details concerning dataset(s), augmentation, optimization parameters, and model architectural specifications are central. This page explores each of these aspects. -Discover exemplary configuration files for all NeMo vision models scripts in the `config directory of the examples `_. +Discover exemplary configuration files for all NeMo vision models scripts in the `config directory of the examples `__. Dataset Configuration ===================== @@ -12,7 +12,6 @@ Dataset Configuration The configuration file delineates parameters for dataset path. All initialization parameters supported by the Dataset class utilized in the experiment can be defined in the config file. -.. For a comprehensive list of Datasets and their associated parameters, consult the `Datasets <./api.html#Datasets>`__ section of the API. A representative training configuration appears as: From 2d4f4953881b9e2d118d3ffeba7e64625d827d11 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 12 Nov 2024 12:22:53 -0800 Subject: [PATCH 11/24] fix(export): GPT models w/ bias=False convert properly (#11255) Signed-off-by: Terry Kong --- nemo/export/trt_llm/tensorrt_llm_build.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index cdf8eaac6b1c..4be2d42ebe4d 100755 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -118,6 +118,14 @@ def build_and_save_engine( build_config.lora_config = lora_config model = model_cls.from_config(model_config) + if not model_config.bias and model_config.architecture == 'GPTForCausalLM': + # NOTE: GPT models in megatron-core that set bias=False sets the bias false globally + # whereas bias=False in TRTLLM GPT models sets it false everywhere except + # LayerNorm. This change makes TRTLLM's implementation match megatron-core. + for name, module in model.named_modules(): + if isinstance(module, tensorrt_llm.layers.normalization.LayerNorm): + module.bias = None + module.register_parameter('bias', None) model = optimize_model( model, use_parallel_embedding=model_config.use_parallel_embedding, From 24e28716bf658ed62c99759aea9bc5db936898fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Tue, 12 Nov 2024 21:51:06 +0100 Subject: [PATCH 12/24] ci: Run secrets detector on `pull_request_target` (#11263) Signed-off-by: Oliver Koenig --- .github/workflows/secrets-detector.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/secrets-detector.yml b/.github/workflows/secrets-detector.yml index d81b5638e31f..825ae7a653fc 100644 --- a/.github/workflows/secrets-detector.yml +++ b/.github/workflows/secrets-detector.yml @@ -14,7 +14,7 @@ name: Secrets detector on: - pull_request: + pull_request_target: branches: - 'main' From 085e957f460bbc083479f01bf14acb2fae6a5b1d Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Tue, 12 Nov 2024 13:29:56 -0800 Subject: [PATCH 13/24] fix(export): update API for disabling device reassignment in TRTLLM for Aligner (#10863) * fix(export): update API for disabling device reassignment in TRTLLM for Aligner [feat] Upgrade nemo-export path for aligner to TRTLLM-v12 and use python runtime Signed-off-by: Terry Kong fix: forgot to always set _disable_torch_cuda_device_set Signed-off-by: Terry Kong Signed-off-by: Terry Kong Apply isort and black reformatting Signed-off-by: terrykong invert torch device set Signed-off-by: Terry Kong * remove comment Signed-off-by: Terry Kong --------- Signed-off-by: Terry Kong --- nemo/export/trt_llm/tensorrt_llm_run.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index 1772c071a745..bd7b8abd5f9e 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -32,17 +32,23 @@ from tensorrt_llm.lora_manager import LoraManager from tensorrt_llm.mapping import Mapping from tensorrt_llm.quantization import QuantMode -from tensorrt_llm.runtime import GenerationSession, ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig +from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig from transformers import PreTrainedTokenizer LOGGER = logging.getLogger("NeMo") use_trtllm_bindings = True try: - from tensorrt_llm.bindings import GptJsonConfig, KvCacheConfig, WorldConfig + from tensorrt_llm.bindings import GptJsonConfig except Exception as e: use_trtllm_bindings = False +TRTLLM_SUPPORTS_DEVICE_DISABLE = True +try: + from tensorrt_llm.runtime.generation import DISABLE_TORCH_DEVICE_SET +except (ImportError, ModuleNotFoundError): + TRTLLM_SUPPORTS_DEVICE_DISABLE = False + @dataclass class TensorrtLLMHostContext: @@ -494,12 +500,20 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node): json_config_str = f.read() engine = Engine.from_buffer(engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank) + + if not TRTLLM_SUPPORTS_DEVICE_DISABLE: + raise RuntimeError( + f"TensorRT-LLM does not support torch device disabling. Please upgrade TensorRT-LLM to make use of this feature." + ) + elif not DISABLE_TORCH_DEVICE_SET: + raise RuntimeError( + f"To use TensorRT-LLM's python ModelRunner API in load_distributed(...) you must set the env var DISABLE_TORCH_DEVICE_SET=1" + ) decoder = ModelRunner.from_engine( engine=engine, # We want the engine to have the mp_rank, but the python runtime to not resassign the device of the current process # So we will set it to the current device rank=torch.cuda.current_device(), - _disable_torch_cuda_device_set=True, ) tensorrt_llm_worker_context.decoder = decoder From 6e8e97499cb029de92ba1583221d4d1eea733886 Mon Sep 17 00:00:00 2001 From: Zeeshan Patel Date: Tue, 12 Nov 2024 20:09:31 -0800 Subject: [PATCH 14/24] new vfm training features (#11246) Signed-off-by: Zeeshan Patel Co-authored-by: Zeeshan Patel --- .../data/diffusion_energon_datamodule.py | 58 ++- .../data/diffusion_fake_datamodule.py | 218 +++++++++++ .../diffusion/data/diffusion_taskencoder.py | 260 ++++++++++++- .../diffusion/models/dit/dit_embeddings.py | 31 +- .../diffusion/models/dit/dit_layer_spec.py | 4 +- .../diffusion/models/dit/dit_model.py | 96 +++-- .../models/dit_llama/dit_llama_layer_spec.py | 70 +++- .../models/dit_llama/dit_llama_model.py | 6 +- nemo/collections/diffusion/models/model.py | 97 ++++- .../diffusion/sampler/edm/edm_pipeline.py | 9 +- nemo/collections/diffusion/train.py | 362 +++++++++++++++++- 11 files changed, 1097 insertions(+), 114 deletions(-) create mode 100644 nemo/collections/diffusion/data/diffusion_fake_datamodule.py diff --git a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py index f18c828d9d45..67a26609dd51 100644 --- a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py +++ b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py @@ -15,7 +15,8 @@ import logging from typing import Any, Dict, Literal -from megatron.energon import DefaultTaskEncoder, get_train_dataset +from megatron.core import parallel_state +from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset from pytorch_lightning.utilities.types import EVAL_DATALOADERS from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule @@ -56,6 +57,9 @@ def __init__( pin_memory: bool = True, task_encoder: DefaultTaskEncoder = None, use_train_split_for_val: bool = False, + virtual_epoch_length: int = 1_000_000_000, # a hack to avoid energon end of epoch warning + packing_buffer_size: int | None = None, + max_samples_per_sequence: int | None = None, ) -> None: """ Initialize the SimpleMultiModalDataModule. @@ -82,6 +86,10 @@ def __init__( task_encoder=task_encoder, ) self.use_train_split_for_val = use_train_split_for_val + self.virtual_epoch_length = virtual_epoch_length + self.num_workers_val = 1 + self.packing_buffer_size = packing_buffer_size + self.max_samples_per_sequence = max_samples_per_sequence def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'): """ @@ -106,29 +114,55 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val batch_size=self.micro_batch_size, task_encoder=self.task_encoder, worker_config=worker_config, - max_samples_per_sequence=None, - shuffle_buffer_size=100, + max_samples_per_sequence=self.max_samples_per_sequence, + shuffle_buffer_size=None, split_part=split, - batch_drop_last=True, - virtual_epoch_length=1_000_000_000, # a hack to avoid energon end of epoch warning + virtual_epoch_length=self.virtual_epoch_length, + packing_buffer_size=self.packing_buffer_size, ) return _dataset def val_dataloader(self) -> EVAL_DATALOADERS: """ - Configure the validation DataLoader. + Initialize and return the validation DataLoader. - This method configures the DataLoader for validation data. - - Parameters: - worker_config: Configuration for the data loader workers. + This method initializes the DataLoader for the validation dataset. It ensures that the parallel state + is initialized correctly for distributed training and returns a configured DataLoader object. Returns: - DataLoader: The DataLoader for validation data. + EVAL_DATALOADERS: The DataLoader for the validation dataset. """ if self.use_train_split_for_val: return self.train_dataloader() - return super().val_dataloader() + if self.val_dataloader_object: + return self.val_dataloader_object + + if not parallel_state.is_initialized(): + message = ( + "Muiltimodal val data loader parallel state is not initialized " + f"using default worker config with no_workers {self.num_workers}" + ) + logging.info(message) + + worker_config = WorkerConfig.default_worker_config(self.num_workers_val) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + + logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}") + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers_val, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + val_dataset = self.datasets_provider(worker_config, split='val') + energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) + self.val_dataloader_object = energon_loader + return self.val_dataloader_object def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ diff --git a/nemo/collections/diffusion/data/diffusion_fake_datamodule.py b/nemo/collections/diffusion/data/diffusion_fake_datamodule.py new file mode 100644 index 000000000000..6cb686c1c305 --- /dev/null +++ b/nemo/collections/diffusion/data/diffusion_fake_datamodule.py @@ -0,0 +1,218 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl +import torch +from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils.data import DataLoader + +from nemo.collections.diffusion.models.model import DiTConfig +from nemo.lightning.pytorch.plugins import MegatronDataSampler + +from .diffusion_taskencoder import pos_id_3d + + +class PosEmb3D: + """Generates and provides 3D positional embeddings for video data.""" + + def __init__(self, *, max_t=96, max_h=960, max_w=960): + self.max_t = max_t + self.max_h = max_h + self.max_w = max_w + self.generate_pos_id() + + def generate_pos_id(self): + """Generates the positional ID grid based on max_t, max_h, and max_w.""" + self.grid = torch.stack( + torch.meshgrid( + torch.arange(self.max_t, device='cpu'), + torch.arange(self.max_h, device='cpu'), + torch.arange(self.max_w, device='cpu'), + ), + dim=-1, + ) + + def get_pos_id_3d(self, *, t, h, w): + """Retrieves a subset of the positional IDs for the specified dimensions. + + Parameters: + t (int): Number of time frames. + h (int): Height dimension. + w (int): Width dimension. + + Returns: + torch.Tensor: The positional IDs tensor with shape (t, h, w, 3). + """ + if t > self.max_t or h > self.max_h or w > self.max_w: + self.max_t = max(self.max_t, t) + self.max_h = max(self.max_h, h) + self.max_w = max(self.max_w, w) + self.generate_pos_id() + return self.grid[:t, :h, :w] + + +class DiTVideoLatentFakeDataset(torch.utils.data.Dataset): + """A fake dataset for generating synthetic video latent data.""" + + def __init__( + self, + n_frames, + max_h, + max_w, + patch_size, + in_channels, + crossattn_emb_size, + max_text_seqlen=512, + seq_length=8192, + ): + self.max_t = n_frames + self.max_height = max_h + self.max_width = max_w + self.patch_size = patch_size + self.in_channels = in_channels + self.text_dim = crossattn_emb_size + self.text_seqlen = max_text_seqlen + self.seq_length = seq_length + + def __len__(self): + """Returns the total number of samples.""" + return 100000000 + + def __getitem__(self, idx): + """Generates a single sample of data. + + Parameters: + idx (int): Index of the data sample. + + Returns: + dict: A dictionary containing video latent data and related information. + """ + t = self.max_t + h = self.max_height + w = self.max_width + p = self.patch_size + c = self.in_channels + + video_latent = torch.ones(self.seq_length, c * p**2, dtype=torch.bfloat16) * 0.5 + text_embedding = torch.randn(self.text_seqlen, self.text_dim, dtype=torch.bfloat16) + pos_emb = pos_id_3d.get_pos_id_3d(t=t, h=h // p, w=w // p).reshape(-1, 3) + + return { + 'video': video_latent, + 't5_text_embeddings': text_embedding, + 'seq_len_q': torch.tensor([video_latent.shape[0]], dtype=torch.int32).squeeze(), + 'seq_len_kv': torch.tensor([self.text_seqlen], dtype=torch.int32).squeeze(), + 'pos_ids': torch.zeros((self.seq_length, 3), dtype=torch.int32), + 'loss_mask': torch.ones(video_latent.shape[0], dtype=torch.bfloat16), + } + + def _collate_fn(self, batch): + """A default implementation of a collation function. + + Users should override this method to define custom data loaders. + """ + return torch.utils.data.dataloader.default_collate(batch) + + def collate_fn(self, batch): + """Method that user passes as a functor to DataLoader. + + The method optionally performs neural type checking and adds types to the outputs. + + Please note, subclasses of Dataset should not implement `input_types`. + + Usage: + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) + + Returns: + Collated batch, with or without types. + """ + return self._collate_fn(batch) + + +class VideoLatentFakeDataModule(pl.LightningDataModule): + """A LightningDataModule for generating fake video latent data for training.""" + + def __init__( + self, + model_config: DiTConfig, + seq_length: int = 2048, + micro_batch_size: int = 1, + global_batch_size: int = 8, + num_workers: int = 1, + pin_memory: bool = True, + task_encoder=None, + use_train_split_for_val: bool = False, + ) -> None: + super().__init__() + self.seq_length = seq_length + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.num_workers = num_workers + self.model_config = model_config + + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + ) + + def setup(self, stage: str = "") -> None: + """Sets up the dataset for training and validation. + + Parameters: + stage (str): Optional stage argument (unused). + """ + self._train_ds = DiTVideoLatentFakeDataset( + n_frames=self.model_config.max_frames, + max_h=self.model_config.max_img_h, + max_w=self.model_config.max_img_w, + patch_size=self.model_config.patch_spatial, + in_channels=self.model_config.in_channels, + crossattn_emb_size=self.model_config.crossattn_emb_size, + ) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """Returns the training DataLoader.""" + if not hasattr(self, "_train_ds"): + self.setup() + return self._create_dataloader(self._train_ds) + + def val_dataloader(self) -> EVAL_DATALOADERS: + """Returns the validation DataLoader.""" + if not hasattr(self, "_train_ds"): + self.setup() + return self._create_dataloader(self._train_ds) + + def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + """Creates a DataLoader for the given dataset. + + Parameters: + dataset (Dataset): The dataset to load. + **kwargs: Additional arguments for DataLoader. + + Returns: + DataLoader: The DataLoader instance. + """ + return DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=True, + persistent_workers=True, + collate_fn=dataset.collate_fn, + **kwargs, + ) diff --git a/nemo/collections/diffusion/data/diffusion_taskencoder.py b/nemo/collections/diffusion/data/diffusion_taskencoder.py index 57e4e4ec8673..2a42b15453b3 100644 --- a/nemo/collections/diffusion/data/diffusion_taskencoder.py +++ b/nemo/collections/diffusion/data/diffusion_taskencoder.py @@ -12,15 +12,96 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings +import random +from dataclasses import dataclass +from typing import Any, List, Optional + import torch import torch.nn.functional as F from einops import rearrange -from megatron.core import parallel_state -from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon import DefaultTaskEncoder, Sample, SkipSample +from megatron.energon.task_encoder.base import stateless from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys from nemo.lightning.io.mixin import IOMixin +from nemo.utils.sequence_packing_utils import first_fit_decreasing + + +@dataclass +class DiffusionSample(Sample): + """ + Data class representing a sample for diffusion tasks. + + Attributes: + video (torch.Tensor): Video latents (C T H W). + t5_text_embeddings (torch.Tensor): Text embeddings (S D). + t5_text_mask (torch.Tensor): Mask for text embeddings. + loss_mask (torch.Tensor): Mask indicating valid positions for loss computation. + image_size (Optional[torch.Tensor]): Tensor containing image dimensions. + fps (Optional[torch.Tensor]): Frame rate of the video. + num_frames (Optional[torch.Tensor]): Number of frames in the video. + padding_mask (Optional[torch.Tensor]): Mask indicating padding positions. + seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings. + seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings. + pos_ids (Optional[torch.Tensor]): Positional IDs. + latent_shape (Optional[torch.Tensor]): Shape of the latent tensor. + """ + + video: torch.Tensor # video latents (C T H W) + t5_text_embeddings: torch.Tensor # (S D) + t5_text_mask: torch.Tensor # 1 + loss_mask: torch.Tensor + image_size: Optional[torch.Tensor] = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + padding_mask: Optional[torch.Tensor] = None + seq_len_q: Optional[torch.Tensor] = None + seq_len_kv: Optional[torch.Tensor] = None + pos_ids: Optional[torch.Tensor] = None + latent_shape: Optional[torch.Tensor] = None + + def to_dict(self) -> dict: + """Converts the sample to a dictionary.""" + return dict( + video=self.video, + t5_text_embeddings=self.t5_text_embeddings, + t5_text_mask=self.t5_text_mask, + loss_mask=self.loss_mask, + image_size=self.image_size, + fps=self.fps, + num_frames=self.num_frames, + padding_mask=self.padding_mask, + seq_len_q=self.seq_len_q, + seq_len_kv=self.seq_len_kv, + pos_ids=self.pos_ids, + latent_shape=self.latent_shape, + ) + + def __add__(self, other: Any) -> int: + """Adds the sequence length of this sample with another sample or integer.""" + if isinstance(other, DiffusionSample): + # Combine the values of the two instances + return self.seq_len_q.item() + other.seq_len_q.item() + elif isinstance(other, int): + # Add an integer to the value + return self.seq_len_q.item() + other + raise NotImplementedError + + def __radd__(self, other: Any) -> int: + """Handles reverse addition for summing with integers.""" + # This is called if sum or other operations start with a non-DiffusionSample object. + # e.g., sum([DiffusionSample(1), DiffusionSample(2)]) -> the 0 + DiffusionSample(1) calls __radd__. + if isinstance(other, int): + return self.seq_len_q.item() + other + raise NotImplementedError + + def __lt__(self, other: Any) -> bool: + """Compares this sample's sequence length with another sample or integer.""" + if isinstance(other, DiffusionSample): + return self.seq_len_q.item() < other.seq_len_q.item() + elif isinstance(other, int): + return self.seq_len_q.item() < other + raise NotImplementedError def cook(sample: dict) -> dict: @@ -75,18 +156,26 @@ def __init__( max_frames: int = None, text_embedding_padding_size: int = 512, seq_length: int = None, + max_seq_length: int = None, patch_spatial: int = 2, patch_temporal: int = 1, + aesthetic_score: float = 0.0, **kwargs, ): super().__init__(*args, **kwargs) self.max_frames = max_frames self.text_embedding_padding_size = text_embedding_padding_size self.seq_length = seq_length + self.max_seq_length = max_seq_length self.patch_spatial = patch_spatial self.patch_temporal = patch_temporal + self.aesthetic_score = aesthetic_score + @stateless(restore_seeds=True) def encode_sample(self, sample: dict) -> dict: + """ + Encodes video / text sample. + """ video_latent = sample['pth'] if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): @@ -95,6 +184,9 @@ def encode_sample(self, sample: dict) -> dict: raise SkipSample() info = sample['json'] + if info['aesthetic_score'] < self.aesthetic_score: + raise SkipSample() + C, T, H, W = video_latent.shape seq_len = ( video_latent.shape[-1] @@ -105,19 +197,14 @@ def encode_sample(self, sample: dict) -> dict: ) is_image = T == 1 - if seq_len > self.seq_length: + if self.seq_length is not None and seq_len > self.seq_length: + raise SkipSample() + if self.max_seq_length is not None and seq_len > self.max_seq_length: raise SkipSample() if self.max_frames is not None: video_latent = video_latent[:, : self.max_frames, :, :] - tpcp_size = parallel_state.get_tensor_model_parallel_world_size() - if parallel_state.get_context_parallel_world_size() > 1: - tpcp_size *= parallel_state.get_context_parallel_world_size() * 2 - if (T * H * W) % tpcp_size != 0: - warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}') - raise SkipSample() - video_latent = rearrange( video_latent, 'C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)', @@ -161,7 +248,7 @@ def encode_sample(self, sample: dict) -> dict: 'T H W d -> (T H W) d', ) - if self.seq_length is not None: + if self.seq_length is not None and self.max_seq_length is None: pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) loss_mask[:seq_len] = 1 @@ -169,7 +256,11 @@ def encode_sample(self, sample: dict) -> dict: else: loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) - return dict( + return DiffusionSample( + __key__=sample['__key__'], + __restore_key__=sample['__restore_key__'], + __subflavor__=None, + __subflavors__=sample['__subflavors__'], video=video_latent, t5_text_embeddings=t5_text_embeddings, t5_text_mask=t5_text_mask, @@ -178,13 +269,86 @@ def encode_sample(self, sample: dict) -> dict: num_frames=num_frames, loss_mask=loss_mask, seq_len_q=torch.tensor(seq_len, dtype=torch.int32), - seq_len_kv=torch.tensor(t5_text_embeddings_seq_length, dtype=torch.int32), + seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32), pos_ids=pos_ids, latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), ) + def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]: + """ + Selects sequences to pack for mixed image-video training. + """ + results = first_fit_decreasing(samples, self.max_seq_length) + random.shuffle(results) + return results + + @stateless + def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample: + """Construct a new Diffusion sample by concatenating the sequences.""" + + def stack(attr): + return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + + def cat(attr): + return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + + video = concat_pad([i.video for i in samples], self.max_seq_length) + loss_mask = concat_pad([i.loss_mask for i in samples], self.max_seq_length) + pos_ids = concat_pad([i.pos_ids for i in samples], self.max_seq_length) + + return DiffusionSample( + __key__=",".join([s.__key__ for s in samples]), + __restore_key__=(), # Will be set by energon based on `samples` + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + video=video, + t5_text_embeddings=cat('t5_text_embeddings'), + t5_text_mask=cat('t5_text_mask'), + # image_size=stack('image_size'), + # fps=stack('fps'), + # num_frames=stack('num_frames'), + loss_mask=loss_mask, + seq_len_q=stack('seq_len_q'), + seq_len_kv=stack('seq_len_kv'), + pos_ids=pos_ids, + latent_shape=stack('latent_shape'), + ) + + @stateless + def batch(self, samples: List[DiffusionSample]) -> dict: + """Return dictionary with data for batch.""" + if self.max_seq_length is None: + # no packing + return super().batch(samples).to_dict() + + # packing + sample = samples[0] + return dict( + video=sample.video.unsqueeze_(0), + t5_text_embeddings=sample.t5_text_embeddings.unsqueeze_(0), + t5_text_mask=sample.t5_text_mask.unsqueeze_(0), + loss_mask=sample.loss_mask.unsqueeze_(0), + # image_size=sample.image_size, + # fps=sample.fps, + # num_frames=sample.num_frames, + # padding_mask=sample.padding_mask.unsqueeze_(0), + seq_len_q=sample.seq_len_q, + seq_len_kv=sample.seq_len_kv, + pos_ids=sample.pos_ids.unsqueeze_(0), + latent_shape=sample.latent_shape, + ) + class PosID3D: + """ + Generates 3D positional IDs for video data. + + Attributes: + max_t (int): Maximum number of time frames. + max_h (int): Maximum height dimension. + max_w (int): Maximum width dimension. + """ + def __init__(self, *, max_t=32, max_h=128, max_w=128): self.max_t = max_t self.max_h = max_h @@ -192,6 +356,7 @@ def __init__(self, *, max_t=32, max_h=128, max_w=128): self.generate_pos_id() def generate_pos_id(self): + """Generates a grid of positional IDs based on max_t, max_h, and max_w.""" self.grid = torch.stack( torch.meshgrid( torch.arange(self.max_t, device='cpu'), @@ -202,6 +367,7 @@ def generate_pos_id(self): ) def get_pos_id_3d(self, *, t, h, w): + """Retrieves positional IDs for specified dimensions.""" if t > self.max_t or h > self.max_h or w > self.max_w: self.max_t = max(self.max_t, t) self.max_h = max(self.max_h, h) @@ -210,4 +376,70 @@ def get_pos_id_3d(self, *, t, h, w): return self.grid[:t, :h, :w] +def pad_divisible(x, padding_value=0): + """ + Pads the input tensor to make its size divisible by a specified value. + + Args: + x (torch.Tensor): Input tensor. + padding_value (int): The value to make the tensor size divisible by. + + Returns: + torch.Tensor: Padded tensor. + """ + if padding_value == 0: + return x + # Get the size of the first dimension + n = x.size(0) + + # Compute the padding needed to make the first dimension divisible by 16 + padding_needed = (padding_value - n % padding_value) % padding_value + + if padding_needed <= 0: + return x + + # Create a new shape with the padded first dimension + new_shape = list(x.shape) + new_shape[0] += padding_needed + + # Create a new tensor filled with zeros + x_padded = torch.zeros(new_shape, dtype=x.dtype, device=x.device) + + # Assign the original tensor to the beginning of the new tensor + x_padded[:n] = x + return x_padded + + +def concat_pad(tensor_list, max_seq_length): + """ + Efficiently concatenates a list of tensors along the first dimension and pads with zeros + to reach max_seq_length. + + Args: + tensor_list (list of torch.Tensor): List of tensors to concatenate and pad. + max_seq_length (int): The desired size of the first dimension of the output tensor. + + Returns: + torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions. + """ + import torch + + # Get common properties from the first tensor + other_shape = tensor_list[0].shape[1:] + dtype = tensor_list[0].dtype + device = tensor_list[0].device + + # Initialize the result tensor with zeros + result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device) + + current_index = 0 + for tensor in tensor_list: + length = tensor.shape[0] + # Directly assign the tensor to the result tensor without checks + result[current_index : current_index + length] = tensor + current_index += length + + return result + + pos_id_3d = PosID3D() diff --git a/nemo/collections/diffusion/models/dit/dit_embeddings.py b/nemo/collections/diffusion/models/dit/dit_embeddings.py index ec8d095cbbd4..6303db43bba1 100644 --- a/nemo/collections/diffusion/models/dit/dit_embeddings.py +++ b/nemo/collections/diffusion/models/dit/dit_embeddings.py @@ -55,6 +55,12 @@ def __init__(self, in_channels: int, time_embed_dim: int, seed=None): self.linear_1.reset_parameters() self.linear_2.reset_parameters() + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + setattr(self.linear_1.weight, "pipeline_parallel", True) + setattr(self.linear_1.bias, "pipeline_parallel", True) + setattr(self.linear_2.weight, "pipeline_parallel", True) + setattr(self.linear_2.bias, "pipeline_parallel", True) + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Computes the positional embeddings for the input tensor. @@ -152,10 +158,27 @@ def __init__( self.emb_h = torch.nn.Embedding(h, config.hidden_size) self.emb_w = torch.nn.Embedding(w, config.hidden_size) - if config.perform_initialization: - config.init_method(self.emb_t.weight) - config.init_method(self.emb_h.weight) - config.init_method(self.emb_w.weight) + if 'seed' in kwargs.keys(): + seed = kwargs['seed'] + with torch.random.fork_rng(): + torch.manual_seed(seed) + if config.perform_initialization: + self.customize_init_param() + else: + self.reset_parameters() + else: + if config.perform_initialization: + self.customize_init_param() + + def customize_init_param(self): + self.config.init_method(self.emb_t.weight) + self.config.init_method(self.emb_h.weight) + self.config.init_method(self.emb_w.weight) + + def reset_parameters(self): + self.emb_t.reset_parameters() + self.emb_h.reset_parameters() + self.emb_w.reset_parameters() def forward(self, pos_ids: torch.Tensor): return self.emb_t(pos_ids[..., 0]) + self.emb_h(pos_ids[..., 1]) + self.emb_w(pos_ids[..., 2]) diff --git a/nemo/collections/diffusion/models/dit/dit_layer_spec.py b/nemo/collections/diffusion/models/dit/dit_layer_spec.py index cb7c520493f0..2233ef3a7354 100644 --- a/nemo/collections/diffusion/models/dit/dit_layer_spec.py +++ b/nemo/collections/diffusion/models/dit/dit_layer_spec.py @@ -733,8 +733,8 @@ def get_stdit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: ) -def get_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec: - params = {"attn_mask_type": AttnMaskType.padding} +def get_dit_adaln_block_with_transformer_engine_spec(attn_mask_type=AttnMaskType.padding) -> ModuleSpec: + params = {"attn_mask_type": attn_mask_type} return ModuleSpec( module=DiTLayerWithAdaLN, submodules=DiTWithAdaLNSubmodules( diff --git a/nemo/collections/diffusion/models/dit/dit_model.py b/nemo/collections/diffusion/models/dit/dit_model.py index 0c1c1abc82f2..24943de6d675 100644 --- a/nemo/collections/diffusion/models/dit/dit_model.py +++ b/nemo/collections/diffusion/models/dit/dit_model.py @@ -141,7 +141,7 @@ def __init__( self.config: TransformerConfig = config - self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() + self.transformer_decoder_layer_spec = transformer_decoder_layer_spec(attn_mask_type=config.attn_mask_type) self.pre_process = pre_process self.post_process = post_process self.add_encoder = True @@ -173,19 +173,33 @@ def __init__( dit_embeddings.ParallelTimestepEmbedding(self.config.hidden_size, self.config.hidden_size, seed=1234), ) + self.fps_embedder = nn.Sequential( + Timesteps(num_channels=256, flip_sin_to_cos=False, downscale_freq_shift=1), + ParallelTimestepEmbedding(256, 256, seed=1234), + ) + if self.pre_process: self.x_embedder = torch.nn.Linear(in_channels * patch_spatial**2, self.config.hidden_size) + if pos_embedder is dit_embeddings.SinCosPosEmb3D: + if self.pre_process: + self.pos_embedder = pos_embedder( + config, + t=max_frames // patch_temporal, + h=max_img_h // patch_spatial, + w=max_img_w // patch_spatial, + ) + else: self.pos_embedder = pos_embedder( config, t=max_frames // patch_temporal, h=max_img_h // patch_spatial, w=max_img_w // patch_spatial, + seed=1234, ) - self.fps_embedder = nn.Sequential( - Timesteps(num_channels=256, flip_sin_to_cos=False, downscale_freq_shift=1), - ParallelTimestepEmbedding(256, 256), - ) + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + for p in self.pos_embedder.parameters(): + setattr(p, "pipeline_parallel", True) if self.post_process: self.final_layer_linear = torch.nn.Linear( @@ -194,6 +208,8 @@ def __init__( ) self.affline_norm = RMSNorm(self.config.hidden_size) + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + setattr(self.affline_norm.weight, "pipeline_parallel", True) def forward( self, @@ -223,6 +239,7 @@ def forward( ] * B, dtype=torch.bfloat16, + device=x.device, ), ).view(-1) if self.pre_process: @@ -234,10 +251,16 @@ def forward( else: pos_emb = self.pos_embedder(pos_ids) pos_emb = rearrange(pos_emb, "B S D -> S B D") - x_S_B_D = rearrange(x_B_S_D, "B S D -> S B D") + x_S_B_D = rearrange(x_B_S_D, "B S D -> S B D").contiguous() else: # intermediate stage of pipeline x_S_B_D = None ### should it take encoder_hidden_states + if (not hasattr(self, "pos_embedder")) or isinstance(self.pos_embedder, dit_embeddings.SinCosPosEmb3D): + pos_emb = None + else: + ## if transformer blocks need pos_emb, then pos_embedder should + ## be replicated across pp ranks. + pos_emb = rearrange(self.pos_embedder(pos_ids), "B S D -> S B D").contiguous() timesteps_B_D = self.t_embedder(timesteps.flatten()).to(torch.bfloat16) # (b d_text_embedding) @@ -245,12 +268,17 @@ def forward( fps_B_D = self.fps_embedder(fps) fps_B_D = nn.functional.pad(fps_B_D, (0, self.config.hidden_size - fps_B_D.shape[1])) affline_emb_B_D += fps_B_D + affline_emb_B_D = self.affline_norm(affline_emb_B_D) - crossattn_emb = rearrange(crossattn_emb, 'B S D -> S B D') + crossattn_emb = rearrange(crossattn_emb, 'B S D -> S B D').contiguous() if self.config.sequence_parallel: if self.pre_process: x_S_B_D = tensor_parallel.scatter_to_sequence_parallel_region(x_S_B_D) + if hasattr(self, "pos_embedder") and isinstance( + self.pos_embedder, dit_embeddings.FactorizedLearnable3DEmbedding + ): + pos_emb = tensor_parallel.scatter_to_sequence_parallel_region(pos_emb) crossattn_emb = tensor_parallel.scatter_to_sequence_parallel_region(crossattn_emb) # `scatter_to_sequence_parallel_region` returns a view, which prevents # the original tensor from being garbage collected. Clone to facilitate GC. @@ -309,51 +337,41 @@ def sharded_state_dict( """ sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) - for param_name, param in self.t_embedder.named_parameters(): - weight_key = f'{prefix}t_embedder.{param_name}' - self.tie_embeddings_weights_state_dict(param, sharded_state_dict, weight_key, weight_key) - - for param_name, param in self.affline_norm.named_parameters(): - weight_key = f'{prefix}affline_norm.{param_name}' - self.tie_embeddings_weights_state_dict(param, sharded_state_dict, weight_key, weight_key) - + for module in ['t_embedder']: + for param_name, param in getattr(self, module).named_parameters(): + weight_key = f'{prefix}{module}.{param_name}' + self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) return sharded_state_dict - def tie_embeddings_weights_state_dict( - self, - tensor, - sharded_state_dict: ShardedStateDict, - output_layer_weight_key: str, - first_stage_word_emb_key: str, + def _set_embedder_weights_replica_id( + self, tensor: Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str ) -> None: - """Ties the embedding and output weights in a given sharded state dict. + """set replica ids of the weights in t_embedder for sharded state dict. Args: sharded_state_dict (ShardedStateDict): state dict with the weight to tie - output_layer_weight_key (str): key of the output layer weight in the state dict. + weight_key (str): key of the weight in the state dict. This entry will be replaced with a tied version - first_stage_word_emb_key (str): this must be the same as the - ShardedTensor.key of the first stage word embeddings. Returns: None, acts in-place """ - if self.pre_process and parallel_state.get_tensor_model_parallel_rank() == 0: - # Output layer is equivalent to the embedding already - return - - # Replace the default output layer with a one sharing the weights with the embedding - del sharded_state_dict[output_layer_weight_key] - last_stage_word_emb_replica_id = ( - 0, # copy of first stage embedding - parallel_state.get_tensor_model_parallel_rank() - + parallel_state.get_pipeline_model_parallel_rank() - * parallel_state.get_pipeline_model_parallel_world_size(), + tp_rank = parallel_state.get_tensor_model_parallel_rank() + vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vpp_rank = vpp_rank if vpp_rank else 0 + vpp_world = parallel_state.get_virtual_pipeline_model_parallel_world_size() + vpp_world = vpp_world if vpp_world else 1 + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + if embedder_weight_key in sharded_state_dict: + del sharded_state_dict[embedder_weight_key] + replica_id = ( + tp_rank, + (vpp_rank + pp_rank * vpp_world), parallel_state.get_data_parallel_rank(with_context_parallel=True), ) - sharded_state_dict[output_layer_weight_key] = make_sharded_tensor_for_checkpoint( + sharded_state_dict[embedder_weight_key] = make_sharded_tensor_for_checkpoint( tensor=tensor, - key=first_stage_word_emb_key, - replica_id=last_stage_word_emb_replica_id, + key=embedder_weight_key, + replica_id=replica_id, allow_shape_mismatch=False, ) diff --git a/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py b/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py index 80bed5878e1b..305db1f2c993 100644 --- a/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py +++ b/nemo/collections/diffusion/models/dit_llama/dit_llama_layer_spec.py @@ -13,7 +13,7 @@ # limitations under the License. import copy -from typing import Literal +from typing import Literal, Optional from megatron.core.transformer.attention import ( CrossAttention, @@ -22,13 +22,18 @@ SelfAttentionSubmodules, ) from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelGroupedLinear, TEColumnParallelLinear, TEDotProductAttention, + TENorm, + TERowParallelGroupedLinear, TERowParallelLinear, ) from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_block import TransformerConfig from megatron.core.transformer.transformer_config import TransformerConfig @@ -78,7 +83,7 @@ def _replace_no_cp_submodules(submodules): layer_number=layer_number, ) - self.adaLN = AdaLN(config=self.config, n_adaln_chunks=6) # , norm=TENorm) + self.adaLN = AdaLN(config=self.config, n_adaln_chunks=6, norm=TENorm) def forward( self, @@ -138,8 +143,57 @@ def forward( return output, context -def get_dit_llama_spec() -> ModuleSpec: - params = {"attn_mask_type": AttnMaskType.padding} +def _get_mlp_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + fp8: Optional[str] = None, +) -> ModuleSpec: + """Helper function to get module spec for MLP/MoE""" + if num_experts is None: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ) + else: + # Mixture of experts with modules in megatron core. + if use_te and moe_grouped_gemm: + linear_fc1 = TEColumnParallelGroupedLinear + linear_fc2 = TERowParallelGroupedLinear + elif use_te and fp8: + linear_fc1 = TEColumnParallelLinear + linear_fc2 = TERowParallelLinear + else: + raise ValueError("Invalid combination of use_te and moe_grouped_gemm") + + use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None + + return ModuleSpec( + module=MoELayer, + submodules=MoESubmodules( + experts=( + MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2) + if not moe_grouped_gemm or use_te_grouped_gemm + else None + ), + shared_experts=ModuleSpec( + module=SharedExpertMLP, + params={"gate": False}, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) + + +def get_dit_llama_spec(num_experts=None, attn_mask_type=AttnMaskType.padding) -> ModuleSpec: + params = {"attn_mask_type": attn_mask_type} return ModuleSpec( module=MoviegGenLayer, submodules=TransformerLayerSubmodules( @@ -162,12 +216,6 @@ def get_dit_llama_spec() -> ModuleSpec: linear_proj=TERowParallelLinear, ), ), - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TEColumnParallelLinear, - linear_fc2=TERowParallelLinear, - ), - ), + mlp=_get_mlp_module_spec(use_te=True, num_experts=num_experts, moe_grouped_gemm=True, fp8=None), ), ) diff --git a/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py b/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py index bfa79e366cac..8ec0c7097c63 100644 --- a/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py +++ b/nemo/collections/diffusion/models/dit_llama/dit_llama_model.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from functools import partial from typing import Literal from megatron.core.transformer.transformer_config import TransformerConfig @@ -54,7 +54,9 @@ def __init__( patch_temporal=patch_temporal, in_channels=in_channels, out_channels=out_channels, - transformer_decoder_layer_spec=get_dit_llama_spec, + transformer_decoder_layer_spec=partial( + get_dit_llama_spec, num_experts=config.num_moe_experts, attn_mask_type=config.attn_mask_type + ), pos_embedder=dit_embeddings.FactorizedLearnable3DEmbedding, **kwargs, ) diff --git a/nemo/collections/diffusion/models/model.py b/nemo/collections/diffusion/models/model.py index 8cc6be860585..9ee0ab441700 100644 --- a/nemo/collections/diffusion/models/model.py +++ b/nemo/collections/diffusion/models/model.py @@ -14,7 +14,7 @@ import importlib import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Callable, Dict, Optional, Tuple import numpy as np @@ -24,6 +24,7 @@ from einops import rearrange from megatron.core import parallel_state from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.transformer_config import TransformerConfig from torch import nn from typing_extensions import override @@ -39,10 +40,12 @@ def dit_forward_step(model, batch) -> torch.Tensor: + """Forward pass of DiT.""" return model(**batch) def dit_data_step(module, dataloader_iter): + """DiT data batch preparation.""" batch = next(dataloader_iter)[0] batch = get_batch_on_this_cp_rank(batch) batch = {k: v.to(device='cuda', non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} @@ -58,12 +61,12 @@ def dit_data_step(module, dataloader_iter): 'self_attention': PackedSeqParams( cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, - qkv_format='sbhd', + qkv_format=module.qkv_format, ), 'cross_attention': PackedSeqParams( cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens_kv, - qkv_format='sbhd', + qkv_format=module.qkv_format, ), } @@ -77,9 +80,7 @@ def get_batch_on_this_cp_rank(data: Dict): cp_size = mpu.get_context_parallel_world_size() cp_rank = mpu.get_context_parallel_rank() - t = 16 if cp_size > 1: - assert t % cp_size == 0, "t must divisibly by cp_size" num_valid_tokens_in_ub = None if 'loss_mask' in data and data['loss_mask'] is not None: num_valid_tokens_in_ub = data['loss_mask'].sum() @@ -88,9 +89,13 @@ def get_batch_on_this_cp_rank(data: Dict): if (value is not None) and (key in ['video', 'video_latent', 'noise_latent', 'pos_ids']): if len(value.shape) > 5: value = value.squeeze(0) - B, C, T, H, W = value.shape + if len(value.shape) == 5: + B, C, T, H, W = value.shape + data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() + else: + B, S, D = value.shape + data[key] = value.view(B, cp_size, S // cp_size, D)[:, cp_rank, ...].contiguous() # TODO: sequence packing - data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous() loss_mask = data["loss_mask"] data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[ :, cp_rank, ... @@ -142,8 +147,16 @@ class DiTConfig(TransformerConfig, io.IOMixin): data_step_fn = dit_data_step forward_step_fn = dit_forward_step + replicated_t_embedder = True + + seq_length: int = 2048 + + qkv_format: str = 'sbhd' + attn_mask_type: AttnMaskType = AttnMaskType.no_mask + @override def configure_model(self, tokenizer=None) -> DiTCrossAttentionModel: + """Configure DiT Model from MCore.""" vp_size = self.virtual_pipeline_model_parallel_size if vp_size: p_size = self.pipeline_model_parallel_size @@ -168,11 +181,14 @@ def configure_model(self, tokenizer=None) -> DiTCrossAttentionModel: ) def configure_vae(self): + """Dynamically import video tokenizer.""" return dynamic_import(self.vae_module)(self.vae_path) @dataclass class DiTBConfig(DiTConfig): + """DiT-B""" + num_layers: int = 12 hidden_size: int = 768 num_attention_heads: int = 12 @@ -180,6 +196,8 @@ class DiTBConfig(DiTConfig): @dataclass class DiTLConfig(DiTConfig): + """DiT-L""" + num_layers: int = 24 hidden_size: int = 1024 num_attention_heads: int = 16 @@ -187,6 +205,8 @@ class DiTLConfig(DiTConfig): @dataclass class DiTXLConfig(DiTConfig): + """DiT-XL""" + num_layers: int = 28 hidden_size: int = 1152 num_attention_heads: int = 16 @@ -194,6 +214,8 @@ class DiTXLConfig(DiTConfig): @dataclass class DiT7BConfig(DiTConfig): + """DiT-7B""" + num_layers: int = 32 hidden_size: int = 3072 num_attention_heads: int = 24 @@ -201,6 +223,8 @@ class DiT7BConfig(DiTConfig): @dataclass class DiTLlama30BConfig(DiTConfig): + """MovieGen 30B""" + num_layers: int = 48 hidden_size: int = 6144 ffn_hidden_size: int = 16384 @@ -228,13 +252,42 @@ class DiTLlama30BConfig(DiTConfig): @dataclass class DiTLlama5BConfig(DiTLlama30BConfig): + """MovieGen 5B""" + num_layers: int = 32 hidden_size: int = 3072 ffn_hidden_size: int = 8192 num_attention_heads: int = 24 +@dataclass +class DiTLlama1BConfig(DiTLlama30BConfig): + """MovieGen 1B""" + + num_layers: int = 16 + hidden_size: int = 2048 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 32 + + +@dataclass +class ECDiTLlama1BConfig(DiTLlama1BConfig): + "EC-DiT 1B" + moe_router_load_balancing_type: str = 'expert_choice' + moe_token_dispatcher_type: str = 'alltoall' + moe_grouped_gemm: bool = True + moe_expert_capacity_factor: float = 8 + moe_pad_expert_input_to_capacity: bool = True + moe_router_topk: int = 1 + num_moe_experts: int = 64 + ffn_hidden_size: int = 1024 + + class DiTModel(GPTModel): + """ + Diffusion Transformer Model + """ + def __init__( self, config: Optional[DiTConfig] = None, @@ -256,6 +309,9 @@ def __init__( self.vae = None + def load_state_dict(self, state_dict, strict=False): + self.module.load_state_dict(state_dict, strict=False) + def data_step(self, dataloader_iter) -> Dict[str, Any]: return self.config.data_step_fn(dataloader_iter) @@ -284,10 +340,12 @@ def on_validation_start(self): self.vae.to('cuda') def on_validation_end(self): + """Move video tokenizer to CPU after validation.""" if self.vae is not None: self.vae.to('cpu') def validation_step(self, batch, batch_idx=None) -> torch.Tensor: + """Generated validation video sample and logs to wandb.""" # In mcore the loss-function is part of the forward-pass (when labels are provided) state_shape = batch['video'].shape sample = self.diffusion_pipeline.generate_samples_from_batch( @@ -304,7 +362,7 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor: seq_len_q = batch['seq_len_q'][0] sample = rearrange( - sample[:, :seq_len_q], + sample[0, None, :seq_len_q], 'B (T H W) (ph pw pt C) -> B C (T pt) (H ph) (W pw)', ph=self.config.patch_spatial, pw=self.config.patch_spatial, @@ -318,13 +376,7 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor: video = (video * 255).to(torch.uint8).cpu().numpy().astype(np.uint8) - T = video.shape[2] - if T == 1: - image = rearrange(video, 'b c t h w -> (b t h) w c') - result = image - else: - # result = wandb.Video(video, fps=float(batch['fps'])) # (batch, time, channel, height width) - result = video + result = rearrange(video, 'b c t h w -> (b t) c h w') # wandb is on the last rank for megatron, first rank for nemo wandb_rank = 0 @@ -340,11 +392,12 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor: if gather_list is not None: videos = [] for video in gather_list: - if len(video.shape) == 3: - videos.append(wandb.Image(video)) - else: - videos.append(wandb.Video(video, fps=30)) - wandb.log({'prediction': videos}, step=self.global_step) + try: + videos.append(wandb.Video(video, fps=24, format='mp4')) + except Exception as e: + warnings.warn(f'Error saving video as mp4: {e}') + videos.append(wandb.Video(video, fps=24)) + wandb.log({'prediction': videos}) return None @@ -375,6 +428,10 @@ def on_validation_model_zero_grad(self) -> None: class DummyLossReduction(MegatronLossReduction): + """ + Diffusion Loss Reduction + """ + def __init__(self, validation_step: bool = False, val_drop_last: bool = True) -> None: super().__init__() self.validation_step = validation_step diff --git a/nemo/collections/diffusion/sampler/edm/edm_pipeline.py b/nemo/collections/diffusion/sampler/edm/edm_pipeline.py index 6e1be1f6f2a6..16d3177088a9 100644 --- a/nemo/collections/diffusion/sampler/edm/edm_pipeline.py +++ b/nemo/collections/diffusion/sampler/edm/edm_pipeline.py @@ -427,8 +427,13 @@ def get_data_and_condition(self, data_batch: dict[str, Tensor], dropout_rate=0.2 latent_state = raw_state # Condition - data_batch['crossattn_emb'] = self.random_dropout_input( + condition = {} # Create a new dictionary for condition + # Copy all keys from data_batch except 'video' + for key, value in data_batch.items(): + if key not in ['video', 't5_text_embeddings']: + condition[key] = value + condition['crossattn_emb'] = self.random_dropout_input( data_batch['t5_text_embeddings'], dropout_rate=dropout_rate ) - return raw_state, latent_state, data_batch + return raw_state, latent_state, condition diff --git a/nemo/collections/diffusion/train.py b/nemo/collections/diffusion/train.py index 43a0a5dcb536..5428e0eeefa2 100644 --- a/nemo/collections/diffusion/train.py +++ b/nemo/collections/diffusion/train.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,29 +19,38 @@ import torch from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer.enums import AttnMaskType from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm from nemo.collections.diffusion.data.diffusion_energon_datamodule import DiffusionDataModule +from nemo.collections.diffusion.data.diffusion_fake_datamodule import VideoLatentFakeDataModule from nemo.collections.diffusion.data.diffusion_taskencoder import BasicDiffusionTaskEncoder from nemo.collections.diffusion.models.model import ( DiT7BConfig, DiTConfig, DiTLConfig, + DiTLlama1BConfig, DiTLlama5BConfig, DiTLlama30BConfig, DiTModel, DiTXLConfig, + ECDiTLlama1BConfig, ) +from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule from nemo.lightning.pytorch.callbacks import ModelCheckpoint, PreemptionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform +from nemo.lightning.pytorch.callbacks.nsys import NsysCallback from nemo.lightning.pytorch.strategies.utils import RestoreConfig +from nemo.utils.exp_manager import TimingCallback @run.cli.factory @run.autoconvert def multimodal_datamodule() -> pl.LightningDataModule: + """Multimodal Datamodule Initialization""" data_module = DiffusionDataModule( seq_length=2048, task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048), @@ -51,9 +60,39 @@ def multimodal_datamodule() -> pl.LightningDataModule: return data_module +@run.cli.factory +@run.autoconvert +def simple_datamodule() -> pl.LightningDataModule: + """Simple Datamodule Initialization""" + data_module = SimpleMultiModalDataModule( + seq_length=2048, + micro_batch_size=1, + global_batch_size=32, + num_workers=16, + tokenizer=None, + image_processor=None, + task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048), + ) + return data_module + + +@run.cli.factory +@run.autoconvert +def multimodal_fake_datamodule() -> pl.LightningDataModule: + """Multimodal Mock Datamodule Initialization""" + data_module = VideoLatentFakeDataModule( + seq_length=None, # Set None to dectect the sequence length automatically. + task_encoder=run.Config(BasicDiffusionTaskEncoder, seq_length=2048), + micro_batch_size=1, + global_batch_size=32, + ) + return data_module + + @run.cli.factory @run.autoconvert def peft(args) -> ModelTransform: + """Parameter Efficient Fine Tuning""" return llm.peft.LoRA( target_modules=['linear_qkv', 'linear_proj'], # , 'linear_fc1', 'linear_fc2'], dim=args.lora_dim, @@ -62,6 +101,7 @@ def peft(args) -> ModelTransform: @run.cli.factory(target=llm.train) def pretrain() -> run.Partial: + """Base Pretraining Config""" return run.Partial( llm.train, model=run.Config( @@ -85,6 +125,8 @@ def pretrain() -> run.Partial: DistributedDataParallelConfig, check_for_nan_in_grad=True, grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, ), ), plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), @@ -96,12 +138,18 @@ def pretrain() -> run.Partial: callbacks=[ run.Config( ModelCheckpoint, - monitor='reduced_train_loss', - filename='{epoch}-{step}', + monitor='global_step', + filename='{global_step}', every_n_train_steps=1000, - save_top_k=-1, + save_top_k=3, + mode='max', ), run.Config(PreemptionCallback), + run.Config(TimingCallback), + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ), ], ), log=nl.NeMoLogger(wandb=(WandbLogger() if "WANDB_API_KEY" in os.environ else None)), @@ -129,6 +177,7 @@ def pretrain() -> run.Partial: @run.cli.factory(target=llm.train) def pretrain_xl() -> run.Partial: + """DiT-XL Pretraining Recipe""" recipe = pretrain() recipe.model.config = run.Config(DiTXLConfig) return recipe @@ -136,13 +185,89 @@ def pretrain_xl() -> run.Partial: @run.cli.factory(target=llm.train) def pretrain_l() -> run.Partial: + """DiT-L Pretraining Recipe""" recipe = pretrain() recipe.model.config = run.Config(DiTLConfig) return recipe +@run.cli.factory(target=llm.train) +def train_mock() -> run.Partial: + """DiT Mock Pretraining Recipe""" + recipe = pretrain() + recipe.model.config = run.Config(DiTLlama5BConfig, max_frames=1) + recipe.data = multimodal_fake_datamodule() + recipe.model.config.num_layers = 16 + recipe.data.seq_length = 73728 + recipe.data.task_encoder.seq_length = 73728 + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.strategy.context_parallel_size = 2 + recipe.data.micro_batch_size = 1 + recipe.data.global_batch_size = 1 + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/train_mock' + + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + + return recipe + + +@run.cli.factory(target=llm.train) +def mock_ditllama5b_8k() -> run.Partial: + recipe = pretrain() + recipe.model.config = run.Config(DiTLlama5BConfig, max_frames=1) + recipe.data = multimodal_fake_datamodule() + recipe.data.seq_length = recipe.data.task_encoder.seq_length = 8192 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.strategy.context_parallel_size = 1 + recipe.data.micro_batch_size = 1 + recipe.data.global_batch_size = 32 + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/mock_ditllama5b_8k' + recipe.model.config.attn_mask_type = AttnMaskType.no_mask + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + recipe.trainer.max_steps = 15 + recipe.trainer.callbacks.pop(0) + recipe.trainer.enable_checkpointing = False + recipe.trainer.callbacks.append( + run.Config( + NsysCallback, + start_step=10, + end_step=11, + ) + ) + recipe.resume = None + return recipe + + +@run.cli.factory(target=llm.train) +def mock_dit7b_8k() -> run.Partial: + recipe = mock_ditllama5b_8k() + recipe.model.config = run.Config(DiT7BConfig, max_frames=1) + recipe.data.model_config = recipe.model.config + recipe.model.config.attn_mask_type = AttnMaskType.no_mask + recipe.model.config.use_cpu_initialization = True + recipe.log.log_dir = 'nemo_experiments/mock_dit7b_8k' + return recipe + + @run.cli.factory(target=llm.train) def pretrain_7b() -> run.Partial: + """DiT-7B Pretraining Recipe""" recipe = pretrain() recipe.model.config = run.Config(DiT7BConfig) recipe.data.global_batch_size = 4608 @@ -161,8 +286,59 @@ def pretrain_7b() -> run.Partial: return recipe +@run.cli.factory(target=llm.train) +def pretrain_7b_pack() -> run.Partial: + """DiT-7B Pretraining Recipe with Packing""" + recipe = pretrain_7b() + recipe.data.global_batch_size = 4608 // 9 + recipe.data.micro_batch_size = 1 + recipe.data.num_workers = 15 + recipe.data.use_train_split_for_val = True + recipe.data.seq_length = 256 * 9 + recipe.data.packing_buffer_size = 1000 + recipe.data.task_encoder.seq_length = None + recipe.data.task_encoder.max_seq_length = recipe.data.seq_length + recipe.model.config.qkv_format = 'thd' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_7b_256p_joint() -> run.Partial: + """DiT-7B Pretraining Recipe 256p Stage 1""" + recipe = pretrain_7b() + recipe.data.global_batch_size = 256 # 768 + recipe.data.micro_batch_size = 1 + recipe.data.seq_length = 8192 + recipe.data.task_encoder.seq_length = 8192 + recipe.model.config.seq_length = 8192 + + recipe.optim.config.lr = 6e-5 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + # recipe.resume.restore_config = run.Config(RestoreConfig, path='', load_optim_state=True) + recipe.log.log_dir = 'nemo_experiments/pretrain_7b_256p_joint' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_7b_256p_joint_pack() -> run.Partial: + """DiT-7B Pretraining Recipe 256p Stage 1 with Packing""" + recipe = pretrain_7b_256p_joint() + recipe.data.global_batch_size = 128 + recipe.data.micro_batch_size = 1 + recipe.data.num_workers = 10 + recipe.data.seq_length = recipe.model.config.seq_length = recipe.data.task_encoder.max_seq_length = 10240 + recipe.data.task_encoder.seq_length = None + recipe.data.packing_buffer_size = 1000 + recipe.data.virtual_epoch_length = 0 + recipe.model.config.qkv_format = 'thd' + return recipe + + @run.cli.factory(target=llm.train) def pretrain_ditllama5b() -> run.Partial: + """MovieGen 5B Training""" recipe = pretrain_7b() recipe.data.micro_batch_size = 12 recipe.model.config = run.Config(DiTLlama5BConfig) @@ -172,30 +348,200 @@ def pretrain_ditllama5b() -> run.Partial: @run.cli.factory(target=llm.train) def pretrain_ditllama30b() -> run.Partial: + """MovieGen 30B Stage 1 Training""" recipe = pretrain_ditllama5b() recipe.model.config = run.Config(DiTLlama30BConfig) recipe.data.global_batch_size = 9216 recipe.data.micro_batch_size = 6 - recipe.log.log_dir = 'nemo_experiments/ditllama30b' + recipe.data.task_encoder.aethetic_score = 4.0 + recipe.data.seq_length = 256 + recipe.data.task_encoder.seq_length = 256 + recipe.data.virtual_epoch_length = 0 + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage1_mock' + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama30b_stage2_mock() -> run.Partial: + """MovieGen 30B Stage 2 Training""" + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama30BConfig) + recipe.data = multimodal_fake_datamodule() + recipe.data.model_config = recipe.model.config + recipe.data.seq_length = 8192 + recipe.data.task_encoder.seq_length = 8192 + recipe.data.global_batch_size = 256 + recipe.data.micro_batch_size = 1 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 4 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage2_mock' + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama30b_stage3_mock() -> run.Partial: + """MovieGen 30B Stage 3 Training""" + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama30BConfig) + recipe.data = multimodal_fake_datamodule() + recipe.data.model_config = recipe.model.config + recipe.data.seq_length = 73728 + recipe.data.task_encoder.seq_length = 73728 + recipe.data.global_batch_size = 256 + recipe.data.micro_batch_size = 1 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 8 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock' + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama5b_stage3_mock_with_pp() -> run.Partial: + """MovieGen 30B Stage 3 Training""" + recipe = pretrain_ditllama5b() + recipe.data = multimodal_fake_datamodule() + recipe.data.model_config = recipe.model.config + recipe.data.seq_length = 8192 + recipe.data.task_encoder.seq_length = 8192 + recipe.data.global_batch_size = 1 + recipe.data.micro_batch_size = 1 + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.pipeline_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 2 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage5_mock_with_pp' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama30b_stage3_mock_with_pp() -> run.Partial: + """MovieGen 30B Stage 3 Training with Pipeline Parallelism""" + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama30BConfig) + recipe.data = multimodal_fake_datamodule() + recipe.data.model_config = recipe.model.config + recipe.data.seq_length = 73728 + recipe.data.task_encoder.seq_length = 73728 + recipe.data.global_batch_size = 256 + recipe.data.micro_batch_size = 1 + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.context_parallel_size = 8 + recipe.trainer.strategy.sequence_parallel = True + recipe.trainer.limit_val_batches = 0 + recipe.trainer.val_check_interval = 1.0 + recipe.data.model_config = recipe.model.config + recipe.log.log_dir = 'nemo_experiments/ditllama30b_stage3_mock_with_pp' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama1b() -> run.Partial: + """MovieGen 1B Stage 1 Training""" + recipe = pretrain_ditllama5b() + recipe.model.config = run.Config(DiTLlama1BConfig) + recipe.data.task_encoder.aethetic_score = 4.0 + recipe.data.seq_length = 256 + recipe.data.task_encoder.seq_length = 256 + recipe.model.config.seq_length = 256 + recipe.data.global_batch_size = 1536 + recipe.data.micro_batch_size = 96 + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.log.log_dir = 'nemo_experiments/ditllama1b' + recipe.trainer.val_check_interval = 3000 + recipe.trainer.callbacks[0].every_n_train_steps = 3000 + recipe.trainer.callbacks[0].monitor = 'global_step' + recipe.trainer.callbacks[0].save_top_k = 3 + recipe.trainer.callbacks[0].mode = 'max' + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ditllama3b() -> run.Partial: + """MovieGen 3B Stage 1 Training""" + recipe = pretrain_ditllama1b() + recipe.data.micro_batch_size = 48 + recipe.model.config = run.Config( + DiTLlama1BConfig, + hidden_size=3072, + num_layers=28, + num_attention_heads=24, + ffn_hidden_size=8192, + ) + recipe.log.log_dir = 'nemo_experiments/ditllama3b' + + return recipe + + +@run.cli.factory(target=llm.train) +def pretrain_ecditllama1b() -> run.Partial: + """EC-DiT 1B Training""" + recipe = pretrain_ditllama1b() + recipe.data.task_encoder.aethetic_score = 5.0 + recipe.data.micro_batch_size = 72 + recipe.data.global_batch_size = 2304 + recipe.model.config = run.Config(ECDiTLlama1BConfig) + recipe.log.log_dir = 'nemo_experiments/ecditllama1b' + recipe.trainer.val_check_interval = 3000 + + recipe.trainer.strategy.ddp.with_megatron_fsdp_code_path = True + recipe.trainer.strategy.ddp.data_parallel_sharding_strategy = 'MODEL_AND_OPTIMIZER_STATES' + recipe.trainer.strategy.ddp.overlap_param_gather = True + recipe.trainer.strategy.ddp.overlap_grad_reduce = True + recipe.model.config.use_cpu_initialization = True + return recipe @run.cli.factory(target=llm.train) def dreambooth() -> run.Partial: + """Dreambooth Fine Tuning""" recipe = pretrain() recipe.optim.config.lr = 1e-6 recipe.data = multimodal_datamodule() recipe.model.config = run.Config(DiTConfig) - recipe.trainer.max_steps = 1000 recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.sequence_parallel = True - recipe.resume.restore_config = run.Config(RestoreConfig) recipe.resume.resume_if_exists = False - return recipe if __name__ == "__main__": + OOM_DEBUG = False + if OOM_DEBUG: + torch.cuda.memory._record_memory_history( + True, + # Keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + # Record stack information for the trace events + trace_alloc_record_context=True, + ) run.cli.main(llm.train, default_factory=dreambooth) From f311b2ea33236b5c1d781ba5abc1ea325809180d Mon Sep 17 00:00:00 2001 From: gvenkatakris Date: Tue, 12 Nov 2024 23:15:13 -0800 Subject: [PATCH 15/24] Update pruning and distillation tutorial notebooks (#11091) * Update pruning and distillation tutorial notebooks Signed-off-by: Gomathy Venkata Krishnan * Update README Signed-off-by: Gomathy Venkata Krishnan * Update batch size in width pruning script Signed-off-by: Gomathy Venkata Krishnan * Update README Signed-off-by: Gomathy Venkata Krishnan --------- Signed-off-by: Gomathy Venkata Krishnan --- tutorials/llm/llama-3/README.rst | 4 +- .../01_data_preparation.ipynb | 102 +++ .../02_teacher_finetuning.ipynb | 153 +++++ .../03_a_depth_pruning.ipynb | 77 +++ .../03_b_width_pruning.ipynb | 92 +++ ...04_a_distilling_depth_pruned_student.ipynb | 136 ++++ ...04_b_distilling_width_pruned_student.ipynb | 138 ++++ .../05_display_results.ipynb | 168 +++++ .../llama-3/pruning-distillation/README.rst | 53 +- .../pruning-distillation/introduction.ipynb | 190 ++++++ .../llama3-pruning-distillation-nemofw.ipynb | 587 ------------------ 11 files changed, 1100 insertions(+), 600 deletions(-) create mode 100644 tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb create mode 100644 tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb create mode 100644 tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb create mode 100644 tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb create mode 100644 tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb create mode 100644 tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb create mode 100644 tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb create mode 100644 tutorials/llm/llama-3/pruning-distillation/introduction.ipynb delete mode 100644 tutorials/llm/llama-3/pruning-distillation/llama3-pruning-distillation-nemofw.ipynb diff --git a/tutorials/llm/llama-3/README.rst b/tutorials/llm/llama-3/README.rst index bb6171e6f582..3bb1a0896b82 100755 --- a/tutorials/llm/llama-3/README.rst +++ b/tutorials/llm/llama-3/README.rst @@ -17,6 +17,6 @@ This repository contains jupyter notebook tutorials using NeMo Framework for Lla * - `Llama 3.1 Law-Domain LoRA Fine-Tuning and Deployment with NeMo Framework and NVIDIA NIM <./sdg-law-title-generation>`_ - `Law StackExchange `_ - Perform LoRA PEFT on Llama 3.1 8B Instruct using a synthetically augmented version of Law StackExchange with NeMo Framework, followed by deployment with NVIDIA NIM. As a pre-requisite, follow the tutorial for `data curation using NeMo Curator `__. - * - `Llama 3.1 WikiText Pruning and Distillation with NeMo Framework <./pruning-distillation>`_ + * - `Llama 3.1 Pruning and Distillation with NeMo Framework <./pruning-distillation>`_ - `WikiText-103-v1 `_ - - Perform pruning and distillation on Llama 3.1 8B Instruct using the WikiText-103-v1 dataset with NeMo Framework. + - Perform pruning and distillation on Llama 3.1 8B using the WikiText-103-v1 dataset with NeMo Framework. diff --git a/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb new file mode 100644 index 000000000000..1f84dd2719e6 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb @@ -0,0 +1,102 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ab9e2e97-7f10-4353-859e-693842bde465", + "metadata": {}, + "source": [ + "### Step 1: Prepare the dataset\n", + "\n", + "The dataset has to be preprocessed using the [preprocess_data_for_megatron.py](https://github.com/NVIDIA/NeMo/blob/main/scripts/nlp_language_modeling/preprocess_data_for_megatron.py) script included in the NeMo Framework. This step will also tokenize data using the `meta-llama/Meta-Llama-3.1-8B` tokenizer model to convert the data into a memory map format.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your train, test and validation data files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6505c00b-9eb4-4087-9e49-423f6228e690", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-train.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_train \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb1aa80f-70bc-4dff-8b08-3bff48d9a1c3", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-test.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_test \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42bec54a-94f6-4c87-8e14-2726ef6c2625", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input=\"./wikitext-data/wikitext-val.jsonl\" \\\n", + "--tokenizer-library='huggingface' \\\n", + "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", + "--output-prefix=wikitext_tokenized_val \\\n", + "--append-eod \\\n", + "--workers=32" + ] + }, + { + "cell_type": "markdown", + "id": "5d77ee8a-e0dc-44f7-b5e8-3b6025d979d7", + "metadata": {}, + "source": [ + "After running the above scripts, you will see the preprocesed `wikitext_tokenized_{train/val/test}_text_document.{idx/bin}`files. These output files will be used in the next step." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb new file mode 100644 index 000000000000..8d08793bbe9a --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb @@ -0,0 +1,153 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "84b146ba-08b6-4adb-a858-8e4294c5e781", + "metadata": {}, + "source": [ + "\n", + "### Step 2: Finetune the teacher on the dataset\n", + "\n", + "NeMo framework includes a standard python script [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py) for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", + "\n", + "We finetune the unpruned model on our dataset to correct the distribution shift across the original dataset the model was trained on. Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), experiments showed that, without correcting for the distribution shift, the teacher provides suboptimal guidance on the dataset when being distilled.\n", + "\n", + "For this demonstration, this training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher .nemo model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12007ac8-2fd5-4de8-8964-97821c2198c0", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Set path(s) if different:\n", + "\n", + "MODEL=\"/workspace/llama-3_1-8b-nemo_v1.0/llama3_1_8b.nemo\"\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_ft\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \\\n", + " --config-path /opt/NeMo/examples/nlp/language_modeling/conf/ \\\n", + " --config-name megatron_llama_distill.yaml \\\n", + " \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${MODEL} \\\n", + " +model.dist_ckpt_load_strictness=log_all \\\n", + " \\\n", + " ~model.tokenizer \\\n", + " +model.tokenizer='{library: huggingface, type: meta-llama/Meta-Llama-3.1-8B, use_fast: True}' \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.encoder_seq_length=8192 \\\n", + " model.num_layers=32 \\\n", + " model.hidden_size=4096 \\\n", + " model.ffn_hidden_size=14336 \\\n", + " model.num_attention_heads=32 \\\n", + " model.hidden_dropout=0.0 \\\n", + " model.attention_dropout=0.0 \\\n", + " model.apply_query_key_layer_scaling=True \\\n", + " model.normalization='rmsnorm' \\\n", + " model.bias=False \\\n", + " model.activation='fast-swiglu' \\\n", + " model.position_embedding_type='rope' \\\n", + " model.share_embeddings_and_output_weights=False \\\n", + " model.num_query_groups=8 \\\n", + " ++model.scale_positional_embedding=True \\\n", + " ++model.rotary_base=500000.0 \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS}" + ] + }, + { + "cell_type": "markdown", + "id": "3040a993-8423-475f-8bc6-d1dd1ce16a83", + "metadata": {}, + "source": [ + "This will create a finetuned teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", + "> `NOTE:`This script takes at least 20 minutes to run (depending on GPU) and will generate the finetuned teacher model." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb new file mode 100644 index 000000000000..a195c2f3a405 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb @@ -0,0 +1,77 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", + "metadata": {}, + "source": [ + "### Step 3: Prune the finetuned-teacher model to create a student\n", + "In this step, we will explore two methods to prune the finetuned teacher model. Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore.\n", + "\n", + "In the first method, depth-pruning, we trim the layers of the model." + ] + }, + { + "cell_type": "markdown", + "id": "72fa494e-6268-4044-a1d6-c0518d450cfd", + "metadata": {}, + "source": [ + "#### Step 3.a.: Using depth-pruning \n", + "To depth-prune, we will trim the last 16 layers in the finetined teacher model. For depth-pruning, we would be using the [megatron_gpt_drop_layers](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_drop_layers.py) script. \n", + "\n", + "Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), removing contiguous layers from the second last block (layers 16 to 31 continuously) yields the best overall results. \n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60cae073-a192-4d47-b220-b09736d39a93", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!python -m torch.distributed.launch --nproc_per_node=8 \\\n", + " /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_drop_layers.py \\\n", + " --path_to_nemo \"./distill_trainings/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\" \\\n", + " --path_to_save \"/workspace/4b_depth_pruned_model.nemo\" \\\n", + " --tensor_model_parallel_size 8 \\\n", + " --pipeline_model_parallel_size 1 \\\n", + " --gpus_per_node 8 \\\n", + " --drop_layers 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31" + ] + }, + { + "cell_type": "markdown", + "id": "375f298a-0363-4f44-b40c-2c8e9bab7d76", + "metadata": {}, + "source": [ + "Running this script will save the depth-pruned model `4b_depth_pruned_model.nemo` to your workspace." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb new file mode 100644 index 000000000000..7d91d36cbb32 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb @@ -0,0 +1,92 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", + "metadata": {}, + "source": [ + "### Step 3: Prune the finetuned-teacher model to create a student\n", + "In the second method, we will width-prune. In width-pruning, we trim the neurons, attention heads and embedding channels. \n", + "\n", + "Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore." + ] + }, + { + "cell_type": "markdown", + "id": "9207ed14-2f37-4712-88f3-543a128663ac", + "metadata": { + "tags": [] + }, + "source": [ + "#### Step 3.b.: Using width-pruning\n", + "To width-prune the model, we do the following:\n", + "- prune (trim) the MLP intermediate dimension from 14336 to 9216.\n", + "- prune the hidden size from 4096 to 3072.\n", + "- and retrain the attention headcount and number of layers\n", + "\n", + "For width-pruning we will use the [megatron_gpt_prune.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_prune.py) script in the NeMo Framework. To see the detailed list of parameters for width-pruning, you can view the [megatron_gpt_prune.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml) file.\n", + "\n", + "We use the above parameters to get a competitive model for this demonstration. You can use other strategies or parameters from the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) or the [tech report](https://arxiv.org/pdf/2408.11796) for your experiments. \n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model.\n", + "\n", + "> `TIP:` You can increase the ``batch_size`` (upto 1024) to speed up the width-pruning script execution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "571d1483-dd4c-403e-b321-293342e7a62a", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!torchrun --nproc-per-node=8 /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_prune.py \\\n", + " model.restore_from_path=\"./distill_trainings/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\" \\\n", + " model.tensor_model_parallel_size=1 \\\n", + " model.pipeline_model_parallel_size=8 \\\n", + " +model.dist_ckpt_load_strictness=log_all \\\n", + " inference.batch_size=64 \\\n", + " trainer.num_nodes=1 \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=8 \\\n", + " prune.ffn_hidden_size=9216 \\\n", + " prune.num_attention_heads=null \\\n", + " prune.num_query_groups=null \\\n", + " prune.hidden_size=3072 \\\n", + " export.save_path=\"/workspace/4b_width_pruned_model.nemo\"" + ] + }, + { + "cell_type": "markdown", + "id": "e9fb0977-5c02-4ecc-b602-54d74b2e2184", + "metadata": {}, + "source": [ + "Running this script will save the width-pruned model `4b_width_pruned_model.nemo` to your workspace." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb new file mode 100644 index 000000000000..ccbe1cbf394b --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb @@ -0,0 +1,136 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "09d30e35-8e9d-4d2e-bd14-738c627a3963", + "metadata": {}, + "source": [ + "### Step 4: Distill knowledge from teacher into student\n", + "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). In this notebook, we will explore distillation with the depth-pruned model as the `STUDENT` model. \n", + "\n", + "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + ] + }, + { + "cell_type": "markdown", + "id": "c33cf641-0d27-417f-b3ee-c06701698184", + "metadata": {}, + "source": [ + "#### Step 4.a.: Using depth-pruned student\n", + "While distilling knowledge from the teacher to depth-pruned model, the `STUDENT` model would be `4b_depth_pruned_model.nemo` as produced by the [depth-pruning](./03_a_depth_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d23a01e-4912-47cb-bf21-b4fd72007ec1", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_distill_depth_pruned_student\"\n", + "\n", + "TEACHER=\"${EXPERIMENT_DIR}/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\"\n", + "STUDENT=\"/workspace/4b_depth_pruned_model.nemo\"\n", + "\n", + "FINAL_MODEL_PATH=\"${EXPERIMENT_DIR}/${EXPERIMENT_NAME}/checkpoints/depth_pruned_distilled_4b_model.nemo\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_distillation.py \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${STUDENT} \\\n", + " model.kd_teacher_restore_from_path=${TEACHER} \\\n", + " model.nemo_path=${FINAL_MODEL_PATH} \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS}" + ] + }, + { + "cell_type": "markdown", + "id": "42d910d9-14dd-44ba-bf2c-0064737c70fa", + "metadata": {}, + "source": [ + "This will create the final distilled model named `depth_pruned_distilled_4b_model.nemo` in `./distill_trainings/megatron_llama_distill_depth_pruned_student/checkpoints`.\n", + "> `NOTE:`This script takes at least 35 minutes to run (depends on GPU) and generate the final distilled model." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb new file mode 100644 index 000000000000..48e81c96cdcf --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb @@ -0,0 +1,138 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d5062f23-c604-479b-9a4e-69989598b131", + "metadata": {}, + "source": [ + "### Step 4: Distill knowledge from teacher into student\n", + "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", + "In this notebook, we will explore distillation with the width-pruned model as the `STUDENT` model.\n", + "\n", + "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + ] + }, + { + "cell_type": "markdown", + "id": "be7de691-dd1d-4719-9872-98501a22e3c9", + "metadata": {}, + "source": [ + "#### Step 4.b.: Using width-pruned student\n", + "While distilling knowledge from the teacher to width-pruned model, the `STUDENT` model would be `4b_width_pruned_model.nemo` as produced by the [width-pruning](./03_b_width_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", + "\n", + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0070b526-771a-4a8d-b0ba-ab218b382bd9", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash \n", + "\n", + "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", + "\n", + "# Can change these to accommodate resources:\n", + "\n", + "TENSOR_PARALLEL_SIZE=8\n", + "NODES=1\n", + "MICRO_BATCH_SIZE=4\n", + "\n", + "# Don't change the following:\n", + "\n", + "EXPERIMENT_DIR=\"distill_trainings\"\n", + "EXPERIMENT_NAME=\"megatron_llama_distill_width_pruned_student\"\n", + "\n", + "TEACHER=\"${EXPERIMENT_DIR}/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\"\n", + "STUDENT=\"/workspace/4b_width_pruned_model.nemo\"\n", + "\n", + "FINAL_MODEL_PATH=\"${EXPERIMENT_DIR}/${EXPERIMENT_NAME}/checkpoints/width_pruned_distilled_4b_model.nemo\"\n", + "\n", + "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", + "DATA_VAL='wikitext_tokenized_test_text_document'\n", + "DATA_TEST='wikitext_tokenized_val_text_document'\n", + "\n", + "STEPS=30\n", + "GLOBAL_BATCH_SIZE=128\n", + "\n", + "LOG_INTERVAL=1\n", + "VAL_INTERVAL=10\n", + "NUM_VAL_BATCHES=5\n", + "\n", + "LR=1e-4\n", + "MIN_LR=1e-5\n", + "WARMUP_STEPS=2\n", + "\n", + "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", + "\n", + "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_distillation.py \\\n", + " name=${EXPERIMENT_NAME} \\\n", + " \\\n", + " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", + " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", + " \\\n", + " trainer.max_steps=${STEPS} \\\n", + " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", + " trainer.val_check_interval=${VAL_INTERVAL} \\\n", + " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", + " +trainer.num_sanity_val_steps=0 \\\n", + " \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", + " trainer.num_nodes=${NODES} \\\n", + " \\\n", + " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", + " \\\n", + " model.restore_from_path=${STUDENT} \\\n", + " model.kd_teacher_restore_from_path=${TEACHER} \\\n", + " model.nemo_path=${FINAL_MODEL_PATH} \\\n", + " \\\n", + " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", + " model.sequence_parallel=True \\\n", + " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", + " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", + " \\\n", + " model.optim.name=distributed_fused_adam \\\n", + " model.optim.lr=${LR} \\\n", + " model.optim.sched.min_lr=${MIN_LR} \\\n", + " model.optim.sched.warmup_steps=${WARMUP_STEPS} \\\n", + " +model.dist_ckpt_load_strictness=log_all" + ] + }, + { + "cell_type": "markdown", + "id": "d9dbc377-e19a-49e0-b245-fa828cca415a", + "metadata": {}, + "source": [ + "This will create the final width-pruned distilled model named `width_pruned_distilled_4b_model.nemo` in `./distill_trainings/megatron_llama_distill_width_pruned_student/checkpoints`.\n", + "> `NOTE:`This script takes at least 20 minutes to run (depends on GPU) and generate the final distilled model." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb new file mode 100644 index 000000000000..0264cc288957 --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb @@ -0,0 +1,168 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6c91263b-b312-4ab2-b13f-0ee4b6e8bd0f", + "metadata": {}, + "source": [ + "### Step 5: Display the validation loss\n", + "\n", + "Now that the results are in, let's visualize the validation loss of the two distilled models using the `tensorboard` library. \n", + "> `NOTE:` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." + ] + }, + { + "cell_type": "markdown", + "id": "b5822d62-8131-4046-8c22-0bf0fce81df7", + "metadata": {}, + "source": [ + "#### Validation Loss using depth-pruned model as student in distillation script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script when we distill the knowledge from the finetuned teacher model to the depth-pruned student." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a665fe1-df45-4126-8694-f182af113133", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir \"distill_trainings/megatron_llama_distill_depth_pruned_student/\" --port=6007" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "db6fcf26-8ae8-40e1-875a-0a10bf85be81", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Validation Loss over 30 Training Steps with Depth-Pruned model as Student
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Image, display, HTML\n", + "title = \"Validation Loss over 30 Training Steps with Depth-Pruned model as Student\"\n", + "display(HTML(f\"
{title}
\"))\n", + "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_depth_pruned_student_distillation.png\", width=400))" + ] + }, + { + "cell_type": "markdown", + "id": "f10041ae-6533-47de-9f76-f97d4469c27a", + "metadata": {}, + "source": [ + "#### Validation Loss using width-pruned model as student in distillation script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script when we distill the knowledge from the finetuned teacher model to the width-pruned student." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b0c3118-4987-4df3-88bd-fcffdb521c5d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir \"distill_trainings/megatron_llama_distill_width_pruned_student/\" --port=6008" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ecd79583-f662-40c6-a690-9f4bb847de4e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
Validation Loss over 30 Training Steps with Width-Pruned model as Student
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Image, display, HTML\n", + "title = \"Validation Loss over 30 Training Steps with Width-Pruned model as Student\"\n", + "display(HTML(f\"
{title}
\"))\n", + "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png\", width=400))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ab6ed6f-8bc3-4188-919f-7cee842635ed", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/README.rst b/tutorials/llm/llama-3/pruning-distillation/README.rst index 9d4207a5c968..34febcffa366 100644 --- a/tutorials/llm/llama-3/pruning-distillation/README.rst +++ b/tutorials/llm/llama-3/pruning-distillation/README.rst @@ -1,18 +1,26 @@ -Llama 3.1 WikiText Pruning and Distillation with NeMo Framework +Llama 3.1 Pruning and Distillation with NeMo Framework ======================================================================================= `Llama 3.1 `_ are open-source large language models by Meta that deliver state-of-the-art performance on popular industry benchmarks. They have been pretrained on over 15 trillion tokens, and support a 128K token context length. They are available in three sizes, 8B, 70B, and 405B, and each size has two variants—base pretrained and instruction tuned. `NVIDIA NeMo Framework `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 to fit your use case. +`NVIDIA TensorRT Model Optimizer `_ is a library (referred to as **Model Optimizer**, or **ModelOpt**) comprising state-of-the-art model optimization techniques including `quantization `_, `sparsity `_, `distillation `_, and `pruning `_ to compress models. + `LLM Pruning and Distillation in Practice: The Minitron Approach `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 as described in the `tech report `_. +`How to Prune and Distill Llama-3.1 8B to an NVIDIA Llama-3.1-Minitron 4B Model `_ provides practical and effective structured compression best practices for LLMs that combine depth, width, attention, and MLP pruning with knowledge distillation-based retraining. These strategies are presented in the `Compact Language Models via Pruning and Knowledge Distillation `_ paper. + +`Mistral-NeMo-Minitron 8B Model Delivers Unparalleled Accuracy `_ introduces the Mistral-NeMo-Minitron 8B, a state-of-the-art 8 billion parameter language model created by pruning and distilling the larger Mistral NeMo 12B model. + Objectives ---------- -This tutorial shows how to perform depth-pruning, teacher finetuning and distillation on **Llama 3.1 8B Instruct** using the `WikiText-103-v1 `_ dataset with NeMo Framework. The `WikiText-103-v1 `_ language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. For this demonstration, we will perform a light finetuning procedure on the ``Meta Llama 3.1 8B Instruct`` teacher model to generate a finetuned teacher model ``megatron_llama_ft.nemo`` needed for optimal distillation. This finetuned teacher model is then depth-pruned to create a trimmed model ``4b_trimmed_model.nemo``. These models will serve as a starting point for distillation to create a final distilled 4B model. +This tutorial shows how to perform depth-pruning, teacher finetuning and distillation on **Llama 3.1 8B** using the `WikiText-103-v1 `_ dataset with NeMo Framework. The `WikiText-103-v1 `_ language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. For this demonstration, we will perform teacher correction by running a light finetuning procedure on the ``Meta Llama 3.1 8B`` teacher model to generate a finetuned teacher model ``megatron_llama_ft.nemo`` needed for optimal distillation. This finetuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will be exploring both pruning techniques which will yield ``4b_depth_pruned_model.nemo`` and ``4b_width_pruned_model.nemo`` respectively. These models will serve as a starting point for distillation to create the final distilled 4B models. We are using models utilizing the ``meta-llama/Meta-Llama-3.1-8B`` tokenizer for this demonstration. +``NOTE:`` A subset of functions is being demonstrated in the notebooks. Some features like Neural Architecture Search (NAS) are unavailable but will be supported in future releases. + Requirements ------------- @@ -31,14 +39,16 @@ Create a pruned and distilled model with NeMo Framework For pruning and distilling the model, you will use the NeMo Framework which is available as a `docker container `_. +``NOTE:`` These notebooks use `NVIDIA TensorRT Model Optimizer `_ under the hood for pruning and distillation. + -1. Download the `Llama 3.1 8B Instruct .nemo `_ from NVIDIA NGC using the `NGC CLI `_. Generate the ``NGC_API_KEY`` following these `instructions `_. The following command saves the ``.nemo`` format model in a folder named ``llama-3_1-8b-instruct-nemo_v1.0`` in the current directory. You can specify another path using the ``-d`` option in the CLI tool. +1. Download the `Llama 3.1 8B .nemo `_ from NVIDIA NGC using the `NGC CLI `_. Generate the ``NGC_API_KEY`` following these `instructions `_. The following command saves the ``.nemo`` format model in a folder named ``llama-3_1-8b-nemo_v1.0`` in the current directory. You can specify another path using the ``-d`` option in the CLI tool. .. code:: bash - ngc registry model download-version "nvidia/nemo/llama-3_1-8b-instruct-nemo:1.0" + ngc registry model download-version "nvidia/nemo/llama-3_1-8b-nemo:1.0" -2. Run the container using the following command. It is assumed that you have the dataset, notebook(s), and the ``llama-3.1-8b-instruct`` model available in the current directory. If not, mount the appropriate folder to ``/workspace``. +2. Run the container using the following command. It is assumed that you have the dataset, notebook(s), and the ``llama3_1_8b.nemo`` model available in the current directory. If not, mount the appropriate folder to ``/workspace``. .. code:: bash @@ -63,17 +73,38 @@ For pruning and distilling the model, you will use the NeMo Framework which is a jupyter lab --ip 0.0.0.0 --port=8888 --allow-root -4. Then, navigate to `this notebook <./llama3-pruning-distillation-nemofw.ipynb>`_. +4. Then, navigate to `this notebook <./introduction.ipynb>`_ to get started. +This directory contains a list of notebooks which will go over all the steps to create a distilled 4B model. + +:: + + <$pruning_distillation> + └── introduction.ipynb + └── 01_data_preparation.ipynb + └── 02_teacher_finetuning.ipynb + └── 03_a_depth_pruning.ipynb + └── 03_b_width_pruning.ipynb + └── 04_a_distilling_depth_pruned_student.ipynb + └── 04_b_distilling_width_pruned_student.ipynb + └── 05_display_results.ipynb + Results ------------------------------------------------------------------------------ -``NOTE:`` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. +``NOTE:`` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation scripts. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. + +Here are the validation loss plots over 30 steps of running the training step in the distillation script (at the end of the `notebook <./05_display_results.ipynb>`_). -Here is the validation loss over 30 steps of running the training step in the distillation script (at the end of the `notebook <./llama3-pruning-distillation-nemofw.ipynb>`_). +.. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_depth_pruned_student_distillation.png + :width: 400px + :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the depth-pruned model as the student + :align: center -.. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_distillation.png + Figure 1: Validation Loss Plot when using the depth-pruned model as the student + +.. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png :width: 400px - :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script + :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the width-pruned model as the student :align: center - Figure 1: Validation Loss Plot \ No newline at end of file + Figure 2: Validation Loss Plot when using the width-pruned model as the student \ No newline at end of file diff --git a/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb new file mode 100644 index 000000000000..1a3efc9f5f1e --- /dev/null +++ b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb @@ -0,0 +1,190 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "411e6711-60fc-4488-8aa1-c6463cac8695", + "metadata": { + "tags": [] + }, + "source": [ + "# Pruning and Distillation of Llama 3.1 model with NeMo Framework" + ] + }, + { + "cell_type": "markdown", + "id": "03fd1cf4-c67a-4b8d-a5e5-46531be0f991", + "metadata": {}, + "source": [ + "This demonstration showcases performing pruning and distillation on **Llama 3.1-8B** with the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset using NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset is a collection of over 100 million tokens extracted from the set of verified 'Good' and 'Featured' articles on Wikipedia. \n", + "\n", + "For this demonstration, we will perform a light finetuning procedure on the `Meta Llama 3.1 8B` teacher model to generate a finetuned teacher model. This finetuned teacher model will then be trimmed. There are two methods to prune a model: depth-pruning and width-pruning. This workflow will showcase both methods which will yield `4b_depth_pruned_model.nemo` and `4b_width_pruned_model.nemo` respectively, that will serve as a starting point for distillation to the final 4B models. \n", + "\n", + "> We are using models utilizing the `meta-llama/Meta-Llama-3.1-8B` tokenizer for this demonstration.\n", + "\n", + "> `NOTE:` Ensure that you run this notebook inside the [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) which has all the required dependencies. \n", + "\n", + "**Instructions are available in the associated tutorial README to download the model and the container.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a5026ce-39f1-43e3-93af-4c4f1e9da1f2", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install --upgrade ipywidgets notebook\n", + "!pip install datasets" + ] + }, + { + "cell_type": "markdown", + "id": "afe59b07-bb48-4913-90cc-bb416b48196c", + "metadata": { + "tags": [] + }, + "source": [ + "---\n", + "## Prerequisites\n", + "Ensure you have the following -\n", + "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo FW container." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9d48b81-e978-4894-8ba4-4f183f698bb1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!ls /workspace/llama-3_1-8b-nemo_v1.0/llama3_1_8b.nemo" + ] + }, + { + "cell_type": "markdown", + "id": "7129d44e-0536-4e62-bdbc-0f1ad44dc84a", + "metadata": {}, + "source": [ + "2. **Set the Hugging Face Access Token**: You can obtain this from your [Hugging Face account](https://huggingface.co/docs/hub/en/security-tokens). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "481417ed-1456-4962-8f67-4350bde1aabd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "login(token=\"\")" + ] + }, + { + "cell_type": "markdown", + "id": "245eda8d-c999-431e-9ebc-5c92c4f21f3b", + "metadata": {}, + "source": [ + "3. **Obtain the dataset**: Generate the `wikitext-{train/val/test}.jsonl` splits after loading the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eaef2c7d-41f7-41ad-a76a-2d714e9c35de", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "# Split into train, test and val files\n", + "\n", + "import json\n", + "import os\n", + "from datasets import load_dataset\n", + "\n", + "# Load the WikiText-103 dataset\n", + "dataset = load_dataset(\"wikitext\", \"wikitext-103-v1\")\n", + "\n", + "# Define the destination folder\n", + "data_folder = 'wikitext-data'\n", + "os.makedirs(data_folder, exist_ok=True)\n", + "\n", + "# Define file paths and destination paths\n", + "file_paths = {\n", + " 'train': os.path.join(data_folder, 'wikitext-train.jsonl'),\n", + " 'validation': os.path.join(data_folder, 'wikitext-val.jsonl'),\n", + " 'test': os.path.join(data_folder, 'wikitext-test.jsonl')\n", + "}\n", + "\n", + "# Function to save dataset split to a JSONL file\n", + "def save_to_jsonl(file_path, data):\n", + " with open(file_path, 'w') as file:\n", + " for item in data:\n", + " file.write(json.dumps(item) + '\\n')\n", + "\n", + "# Define splits\n", + "splits = [\"train\", \"validation\", \"test\"]\n", + "\n", + "# Save splits to JSONL files and calculate their sizes\n", + "for split in splits:\n", + " if split in dataset:\n", + " save_to_jsonl(file_paths[split], dataset[split])\n", + " else:\n", + " print(f\"Split {split} not found in the dataset.\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "2d0cc359-0598-40aa-af80-9503ecd4dac1", + "metadata": { + "tags": [] + }, + "source": [ + "---\n", + "## Step-by-step instructions\n", + "\n", + "This workflow is structured into seven notebooks:\n", + "1. [Prepare the dataset](./01_data_preparation.ipynb)\n", + "2. [Finetune the teacher on the dataset](./02_teacher_finetuning.ipynb)\n", + "3. Prune the finetuned-teacher model to create a student \n", + " - 3.a. [Using depth-pruning](./03_a_depth_pruning.ipynb)\n", + " - 3.b. [Using width-pruning](./03_b_width_pruning.ipynb)\n", + "4. Distill knowledge from teacher into student\n", + " - 4.a. [Using depth-pruned student](./04_a_distilling_depth_pruned_student.ipynb)\n", + " - 4.b. [Using width-pruned student](./04_b_distilling_width_pruned_student.ipynb)\n", + "5. [Display the validation loss](./05_display_results.ipynb)\n", + "\n", + "> `NOTE:` We are exploring two methods to prune the finetuned teacher model: [depth-pruning](./03_a_depth_pruning.ipynb) and [width-pruning](./03_b_width_pruning.ipynb). Per the [tech report](https://arxiv.org/pdf/2408.11796), we can observe that width-pruning generally outperforms depth-pruning so users can choose to perform either [depth-pruning](./03_a_depth_pruning.ipynb) or [width-pruning](./03_b_width_pruning.ipynb) or both methods." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/llama3-pruning-distillation-nemofw.ipynb b/tutorials/llm/llama-3/pruning-distillation/llama3-pruning-distillation-nemofw.ipynb deleted file mode 100644 index 8b31ad4de018..000000000000 --- a/tutorials/llm/llama-3/pruning-distillation/llama3-pruning-distillation-nemofw.ipynb +++ /dev/null @@ -1,587 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "363a6974-810c-41c5-84da-4751a92fb72b", - "metadata": { - "tags": [] - }, - "source": [ - "# Pruning and Distillation of Llama 3.1 model with NeMo Framework" - ] - }, - { - "cell_type": "markdown", - "id": "c6d4ed6d-8ecd-4647-bd0a-e48fec64c199", - "metadata": {}, - "source": [ - "This notebook showcases performing pruning and distillation on **Llama 3.1-8B-Instruct** with the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset using NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. \n", - "\n", - "For this demonstration, we will perform a light finetuning procedure on the `Meta Llama 3.1 8B Instruct` teacher model to generate a finetuned teacher model. This finetuned teacher model will then be trimmed to create a depth-pruned model `4b_trimmed_model.nemo` that will serve as a starting point for distillation to a final 4B model. \n", - "\n", - "> We are using models utilizing the `meta-llama/Meta-Llama-3.1-8B` tokenizer for this demonstration.\n", - "\n", - "> `NOTE:` Ensure that you run this notebook inside the [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) which has all the required dependencies. \n", - "\n", - "**Instructions are available in the associated tutorial README to download the model and the container.**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1d0dc714-5bbf-4266-805a-9841ff486c05", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!pip install --upgrade ipywidgets notebook\n", - "!pip install datasets" - ] - }, - { - "cell_type": "markdown", - "id": "2658505d-7990-40a5-a269-866ddd8a0181", - "metadata": { - "tags": [] - }, - "source": [ - "---\n", - "## Prerequisites\n", - "Ensure you have the following -\n", - "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B Instruct .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo FW container." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a30cfe8a-87a8-4511-be5f-e20d7fe558d4", - "metadata": {}, - "outputs": [], - "source": [ - "!ls /workspace/llama-3_1-8b-instruct-nemo_v1.0" - ] - }, - { - "cell_type": "markdown", - "id": "251a670e-9636-4807-bc98-a91c6137454d", - "metadata": {}, - "source": [ - "2. **Set the Hugging Face Access Token**: You can obtain this from your [Hugging Face account](https://huggingface.co/docs/hub/en/security-tokens). " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47d7887d-b582-4a1e-81cd-fdc1be8d9afb", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from huggingface_hub import login\n", - "login(token=\"\")" - ] - }, - { - "cell_type": "markdown", - "id": "b5384e9a-6c40-4454-abe8-413ad9d5db96", - "metadata": {}, - "source": [ - "3. **Obtain the dataset**: Generate the `wikitext-{train/val/test}.jsonl` splits after loading the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b420bd44-3628-45e2-92e7-df38f72a658a", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "# Split into train, test and val files\n", - "\n", - "import json\n", - "import os\n", - "from datasets import load_dataset\n", - "\n", - "# Load the WikiText-103 dataset\n", - "dataset = load_dataset(\"wikitext\", \"wikitext-103-v1\")\n", - "\n", - "# Define the destination folder\n", - "data_folder = 'wikitext-data'\n", - "os.makedirs(data_folder, exist_ok=True)\n", - "\n", - "# Define file paths and destination paths\n", - "file_paths = {\n", - " 'train': os.path.join(data_folder, 'wikitext-train.jsonl'),\n", - " 'validation': os.path.join(data_folder, 'wikitext-val.jsonl'),\n", - " 'test': os.path.join(data_folder, 'wikitext-test.jsonl')\n", - "}\n", - "\n", - "# Function to save dataset split to a JSONL file\n", - "def save_to_jsonl(file_path, data):\n", - " with open(file_path, 'w') as file:\n", - " for item in data:\n", - " file.write(json.dumps(item) + '\\n')\n", - "\n", - "# Define splits\n", - "splits = [\"train\", \"validation\", \"test\"]\n", - "\n", - "# Save splits to JSONL files and calculate their sizes\n", - "for split in splits:\n", - " if split in dataset:\n", - " save_to_jsonl(file_paths[split], dataset[split])\n", - " else:\n", - " print(f\"Split {split} not found in the dataset.\")\n" - ] - }, - { - "cell_type": "markdown", - "id": "0185a0a9-904d-46de-a450-db4c84c4cde4", - "metadata": { - "tags": [] - }, - "source": [ - "---\n", - "## Step-by-step instructions\n", - "\n", - "This notebook is structured into five steps:\n", - "1. Prepare the dataset\n", - "2. Finetune the teacher on the dataset\n", - "3. Prune the finetuned-teacher model to create a student\n", - "3. Distill knowledge from teacher into student\n", - "4. Display the validation loss" - ] - }, - { - "cell_type": "markdown", - "id": "cf1d41ff-2cba-4efc-84e3-7d713df0cdb8", - "metadata": {}, - "source": [ - "### Step 1: Prepare the dataset\n", - "\n", - "The dataset has to be preprocessed using the [preprocess_data_for_megatron.py](https://github.com/NVIDIA/NeMo/blob/main/scripts/nlp_language_modeling/preprocess_data_for_megatron.py) script included in the NeMo Framework. This step will also tokenize data using the `meta-llama/Meta-Llama-3.1-8B` tokenizer model to convert the data into a memory map format.\n", - "\n", - "> `NOTE:` In the block of code below, pass the paths to your train, test and validation data files." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2c49c1b8-2447-426c-9f24-bf5956aa2941", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", - "--input=\"./wikitext-data/wikitext-train.jsonl\" \\\n", - "--tokenizer-library='huggingface' \\\n", - "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", - "--output-prefix=wikitext_tokenized_train \\\n", - "--append-eod \\\n", - "--workers=32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "72d14fd7-702f-4b74-a6e5-af3a60eef3a9", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", - "--input=\"./wikitext-data/wikitext-test.jsonl\" \\\n", - "--tokenizer-library='huggingface' \\\n", - "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", - "--output-prefix=wikitext_tokenized_test \\\n", - "--append-eod \\\n", - "--workers=32" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1338a1ce-f0e2-4151-ad3d-d34db75ea1bd", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!python /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", - "--input=\"./wikitext-data/wikitext-val.jsonl\" \\\n", - "--tokenizer-library='huggingface' \\\n", - "--tokenizer-type='meta-llama/Meta-Llama-3.1-8B' \\\n", - "--output-prefix=wikitext_tokenized_val \\\n", - "--append-eod \\\n", - "--workers=32" - ] - }, - { - "cell_type": "markdown", - "id": "eb80e212-c343-4e51-a92d-184db43df011", - "metadata": {}, - "source": [ - "After running the above scripts, you will see the preprocesed `wikitext_tokenized_{train/val/test}_text_document.{idx/bin}`files. These output files will be used in the next step." - ] - }, - { - "cell_type": "markdown", - "id": "e9f30c0a-4315-4017-b014-add4291a3fde", - "metadata": {}, - "source": [ - "\n", - "### Step 2: Finetune the teacher on the dataset\n", - "\n", - "NeMo framework includes a standard python script [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py) for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", - "\n", - "For this demonstration, this training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", - "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher .nemo model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c31fd642-0304-43ed-9211-041dc36f22c3", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "%%bash \n", - "\n", - "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", - "\n", - "\n", - "# Set path(s) if different:\n", - "\n", - "MODEL=\"/workspace/llama-3_1-8b-instruct-nemo_v1.0/llama3_1_8b_instruct.nemo\"\n", - "\n", - "# Can change these to accommodate resources:\n", - "\n", - "TENSOR_PARALLEL_SIZE=8\n", - "NODES=1\n", - "MICRO_BATCH_SIZE=4\n", - "\n", - "# Don't change the following:\n", - "\n", - "EXPERIMENT_DIR=\"distill_trainings\"\n", - "EXPERIMENT_NAME=\"megatron_llama_ft\"\n", - "\n", - "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", - "DATA_VAL='wikitext_tokenized_test_text_document'\n", - "DATA_TEST='wikitext_tokenized_val_text_document'\n", - "\n", - "STEPS=30\n", - "GLOBAL_BATCH_SIZE=128\n", - "\n", - "LOG_INTERVAL=1\n", - "VAL_INTERVAL=10\n", - "NUM_VAL_BATCHES=5\n", - "\n", - "LR=1e-4\n", - "MIN_LR=1e-5\n", - "WARMUP_STEPS=2\n", - "\n", - "\n", - "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", - "\n", - "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \\\n", - " --config-path /opt/NeMo/examples/nlp/language_modeling/conf/ \\\n", - " --config-name megatron_llama_distill.yaml \\\n", - " \\\n", - " name=${EXPERIMENT_NAME} \\\n", - " \\\n", - " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", - " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", - " exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \\\n", - " \\\n", - " trainer.max_steps=${STEPS} \\\n", - " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", - " trainer.val_check_interval=${VAL_INTERVAL} \\\n", - " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", - " +trainer.num_sanity_val_steps=0 \\\n", - " \\\n", - " trainer.precision=bf16 \\\n", - " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", - " trainer.num_nodes=${NODES} \\\n", - " \\\n", - " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", - " \\\n", - " model.restore_from_path=${MODEL} \\\n", - " \\\n", - " ~model.tokenizer \\\n", - " +model.tokenizer='{library: huggingface, type: meta-llama/Meta-Llama-3.1-8B, use_fast: True}' \\\n", - " \\\n", - " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", - " model.sequence_parallel=True \\\n", - " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", - " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", - " \\\n", - " model.encoder_seq_length=8192 \\\n", - " model.num_layers=32 \\\n", - " model.hidden_size=4096 \\\n", - " model.ffn_hidden_size=14336 \\\n", - " model.num_attention_heads=32 \\\n", - " model.hidden_dropout=0.0 \\\n", - " model.attention_dropout=0.0 \\\n", - " model.apply_query_key_layer_scaling=True \\\n", - " model.normalization='rmsnorm' \\\n", - " model.bias=False \\\n", - " model.activation='fast-swiglu' \\\n", - " model.position_embedding_type='rope' \\\n", - " model.share_embeddings_and_output_weights=False \\\n", - " model.num_query_groups=8 \\\n", - " ++model.scale_positional_embedding=True \\\n", - " ++model.rotary_base=500000.0 \\\n", - " \\\n", - " model.optim.name=distributed_fused_adam \\\n", - " model.optim.lr=${LR} \\\n", - " model.optim.sched.min_lr=${MIN_LR} \\\n", - " model.optim.sched.warmup_steps=${WARMUP_STEPS}" - ] - }, - { - "cell_type": "markdown", - "id": "8aaf604a-efc0-4908-9055-5cf3bb0a05ae", - "metadata": {}, - "source": [ - "This will create a finetuned teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", - "> `NOTE:`This script takes at least 20 minutes to run (depending on GPU) and will generate the finetuned teacher model." - ] - }, - { - "cell_type": "markdown", - "id": "2709ccc0-bbb8-44ba-b00d-15b1dc5d60a7", - "metadata": {}, - "source": [ - "### Step 3: Prune the finetuned-teacher model to create a student\n", - "\n", - "The next step is to trim the last 16 layers in the finetined teacher model. In this notebook, we are using depth-pruning and would be using the [megatron_gpt_drop_layers](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_drop_layers.py) script. \n", - "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9715a1b-7a23-437f-b5e1-feec8e6c68e0", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "!python -m torch.distributed.launch --nproc_per_node=8 \\\n", - " /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_drop_layers.py \\\n", - " --path_to_nemo \"./distill_trainings/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\" \\\n", - " --path_to_save \"/workspace/4b_trimmed_model.nemo\" \\\n", - " --tensor_model_parallel_size 8 \\\n", - " --pipeline_model_parallel_size 1 \\\n", - " --gpus_per_node 8 \\\n", - " --drop_layers 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31" - ] - }, - { - "cell_type": "markdown", - "id": "1e9553db-9478-4074-9de1-1fa01a0e835c", - "metadata": {}, - "source": [ - "Running this script will save the depth-pruned model `4b_trimmed_model.nemo` to your workspace." - ] - }, - { - "cell_type": "markdown", - "id": "b8ada696-5d77-4113-9d15-a603113fdd58", - "metadata": {}, - "source": [ - "\n", - "### Step 4: Distill knowledge from teacher into student\n", - "\n", - "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", - "\n", - "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model `4b_trimmed_model.nemo`. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", - "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61c0c69d-9401-4355-8725-78aa72eee8da", - "metadata": { - "scrolled": true, - "tags": [] - }, - "outputs": [], - "source": [ - "%%bash \n", - "\n", - "export CUDA_DEVICE_MAX_CONNECTIONS=1\n", - "\n", - "\n", - "# Can change these to accommodate resources:\n", - "\n", - "TENSOR_PARALLEL_SIZE=8\n", - "NODES=1\n", - "MICRO_BATCH_SIZE=4\n", - "\n", - "# Don't change the following:\n", - "\n", - "EXPERIMENT_DIR=\"distill_trainings\"\n", - "EXPERIMENT_NAME=\"megatron_llama_distill\"\n", - "\n", - "TEACHER=\"${EXPERIMENT_DIR}/megatron_llama_ft/checkpoints/megatron_llama_ft.nemo\"\n", - "STUDENT=\"/workspace/4b_trimmed_model.nemo\"\n", - "\n", - "FINAL_MODEL_PATH=\"${EXPERIMENT_DIR}/${EXPERIMENT_NAME}/checkpoints/distilled_4b_model.nemo\"\n", - "\n", - "DATA_TRAIN='wikitext_tokenized_train_text_document'\n", - "DATA_VAL='wikitext_tokenized_test_text_document'\n", - "DATA_TEST='wikitext_tokenized_val_text_document'\n", - "\n", - "STEPS=30\n", - "GLOBAL_BATCH_SIZE=128\n", - "\n", - "LOG_INTERVAL=1\n", - "VAL_INTERVAL=10\n", - "NUM_VAL_BATCHES=5\n", - "\n", - "LR=1e-4\n", - "MIN_LR=1e-5\n", - "WARMUP_STEPS=2\n", - "\n", - "\n", - "cmd=\"torchrun --nproc-per-node=${TENSOR_PARALLEL_SIZE}\"\n", - "\n", - "${cmd} /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_distillation.py \\\n", - " name=${EXPERIMENT_NAME} \\\n", - " \\\n", - " exp_manager.exp_dir=${EXPERIMENT_DIR} \\\n", - " exp_manager.checkpoint_callback_params.save_top_k=1 \\\n", - " \\\n", - " trainer.max_steps=${STEPS} \\\n", - " trainer.log_every_n_steps=${LOG_INTERVAL} \\\n", - " trainer.val_check_interval=${VAL_INTERVAL} \\\n", - " trainer.limit_val_batches=${NUM_VAL_BATCHES} \\\n", - " +trainer.num_sanity_val_steps=0 \\\n", - " \\\n", - " trainer.precision=bf16 \\\n", - " trainer.devices=${TENSOR_PARALLEL_SIZE} \\\n", - " trainer.num_nodes=${NODES} \\\n", - " \\\n", - " \"model.data.data_prefix={train:[1.0,$DATA_TRAIN],validation:[$DATA_VAL],test:[$DATA_TEST]}\" \\\n", - " \\\n", - " model.restore_from_path=${STUDENT} \\\n", - " model.kd_teacher_restore_from_path=${TEACHER} \\\n", - " model.nemo_path=${FINAL_MODEL_PATH} \\\n", - " \\\n", - " model.tensor_model_parallel_size=${TENSOR_PARALLEL_SIZE} \\\n", - " model.sequence_parallel=True \\\n", - " model.micro_batch_size=${MICRO_BATCH_SIZE} \\\n", - " model.global_batch_size=${GLOBAL_BATCH_SIZE} \\\n", - " \\\n", - " model.optim.name=distributed_fused_adam \\\n", - " model.optim.lr=${LR} \\\n", - " model.optim.sched.min_lr=${MIN_LR} \\\n", - " model.optim.sched.warmup_steps=${WARMUP_STEPS}\n" - ] - }, - { - "cell_type": "markdown", - "id": "fe7034ba-8c69-4edb-8c0f-84fdca43c152", - "metadata": {}, - "source": [ - "This will create the final distilled model named `distilled_4b_model.nemo` in `./distill_trainings/megatron_llama_distill/checkpoints`.\n", - "> `NOTE:`This script takes at least 35 minutes to run and generate the final distilled model." - ] - }, - { - "cell_type": "markdown", - "id": "c9a66d44-5028-47f9-9df3-9f07692e9461", - "metadata": {}, - "source": [ - "### Step 5: Display the validation loss\n", - "\n", - "Now that the results are in, let's visualize the validation loss of the distilled model using the `tensorboard` library. \n", - "> `NOTE:` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "be4da14c-c03f-4c28-accd-8f676dbef8a9", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext tensorboard\n", - "%tensorboard --logdir \"distill_trainings/megatron_llama_distill/\" --port=6007" - ] - }, - { - "cell_type": "markdown", - "id": "08c63b80-0f24-4dde-b5d6-11db444726ed", - "metadata": {}, - "source": [ - "Here is an image of the validation loss over 30 steps of running the training step in the distillation script." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "648424fc-6a51-43ca-8f19-6ad05f949054", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/html": [ - "" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from IPython.display import Image, display\n", - "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_distillation.png\", width=400))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From a2572a7bd57e6d6d7b95090a792455638740203f Mon Sep 17 00:00:00 2001 From: lilithgrigoryan <38436437+lilithgrigoryan@users.noreply.github.com> Date: Wed, 13 Nov 2024 15:24:51 +0400 Subject: [PATCH 16/24] Beam search algorithm implementation for TDT models (#10903) * initial commit Signed-off-by: lilithgrigoryan * add: default beam search implementation Signed-off-by: lilithgrigoryan * fix: changed to removing duplicate hypothesis in separate function Signed-off-by: lilithgrigoryan * fix: changed to cartesian product in choosing best hyp Signed-off-by: lilithgrigoryan * fix: minor fixes in comments Signed-off-by: lilithgrigoryan * add: maes decoding strategy Signed-off-by: lilithgrigoryan * add: durations filtering in maes, lm fusion in progress Signed-off-by: lilithgrigoryan * fix: refactored, added comments, command line args, finalized Signed-off-by: lilithgrigoryan * fix: removed prints Signed-off-by: lilithgrigoryan * add: docs Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix: minor fix Signed-off-by: lilithgrigoryan * fix: rm beam_size=1 exception, rm duplicates check, fix error handling Signed-off-by: lilithgrigoryan * fix: error handling Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix: removed evaluations file Signed-off-by: lilithgrigoryan * rn: blank scoring Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * rm: blank scoring and duration beam size Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix: removed durations_beam_size from default beam search Signed-off-by: lilithgrigoryan * add: logaddexp Signed-off-by: lilithgrigoryan * rm: prefix search Signed-off-by: lilithgrigoryan * rn: nested loop over extensions Signed-off-by: lilithgrigoryan * fix: bug with caching Signed-off-by: lilithgrigoryan * rm: topk on durations Signed-off-by: lilithgrigoryan * add: restored prefix search Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * fix: fixed comments Signed-off-by: lilithgrigoryan * refactored duplicate merging Signed-off-by: lilithgrigoryan * changes batch scoring Signed-off-by: lilithgrigoryan * refactored rnnt batch scoring Signed-off-by: lilithgrigoryan * alsd first working Signed-off-by: lilithgrigoryan * refactored Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * remove stacking operations Signed-off-by: lilithgrigoryan * fixes im base class Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * remove potentially uninitialized local variable Signed-off-by: lilithgrigoryan * default beam search minor fixes Signed-off-by: lilithgrigoryan * add test, fix maes timesteps Signed-off-by: lilithgrigoryan * rm file Signed-off-by: lilithgrigoryan * rm file Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * fix comments Signed-off-by: lilithgrigoryan * add ngram lm test Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix maes_num_steps=1 Signed-off-by: lilithgrigoryan * fix kenlm model path Signed-off-by: lilithgrigoryan * fix kenlm model full path Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * made requested changes Signed-off-by: lilithgrigoryan * merge after isort Signed-off-by: lilithgrigoryan * add prints to test Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * add Kenlm to asr requirements Signed-off-by: lilithgrigoryan * remove prints in tests Signed-off-by: lilithgrigoryan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add kenlm to test requirements Signed-off-by: lilithgrigoryan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rm kenlm from link, add package-name Signed-off-by: lilithgrigoryan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rm second kenlm installation Signed-off-by: lilithgrigoryan * rm kenlm from dependencies make test optional Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix in test Signed-off-by: lilithgrigoryan * fix in test Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix comments Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * add comments Signed-off-by: lilithgrigoryan * add comments Signed-off-by: lilithgrigoryan * splitted docstrings Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * add comments Signed-off-by: lilithgrigoryan * splitted docstrings Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * add comments Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fixes to python3 type annotations Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * merging Signed-off-by: lilithgrigoryan * merging Signed-off-by: lilithgrigoryan * fix in return type Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix test Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * rm time_idx Signed-off-by: lilithgrigoryan * fix comments to python3 style Signed-off-by: lilithgrigoryan --------- Signed-off-by: lilithgrigoryan Signed-off-by: lilithgrigoryan Co-authored-by: lilithgrigoryan Co-authored-by: lilithgrigoryan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/asr/api.rst | 15 + .../parts/submodules/rnnt_beam_decoding.py | 56 +- .../asr/parts/submodules/rnnt_decoding.py | 227 +++-- .../parts/submodules/rnnt_greedy_decoding.py | 54 +- .../asr/parts/submodules/tdt_beam_decoding.py | 800 ++++++++++++++++++ .../collections/asr/parts/utils/rnnt_utils.py | 12 +- .../asr/decoding/test_rnnt_decoding.py | 88 +- 7 files changed, 1155 insertions(+), 97 deletions(-) create mode 100644 nemo/collections/asr/parts/submodules/tdt_beam_decoding.py diff --git a/docs/source/asr/api.rst b/docs/source/asr/api.rst index c99d92c0371a..a35ea49ea2c4 100644 --- a/docs/source/asr/api.rst +++ b/docs/source/asr/api.rst @@ -276,6 +276,21 @@ RNNT Decoding :show-inheritance: :members: +TDT Decoding +~~~~~~~~~~~~~ + +.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyTDTInfer + :show-inheritance: + :members: + +.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedTDTInfer + :show-inheritance: + :members: + +.. autoclass:: nemo.collections.asr.parts.submodules.tdt_beam_decoding.BeamTDTInfer + :show-inheritance: + :members: + Hypotheses ~~~~~~~~~~ diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index c01f2363db75..e0bd47bb8ce0 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -55,6 +55,20 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]: + """ + Packs a list of hypotheses into a tensor and prepares decoder states. + + This function takes a list of token sequences (hypotheses) and converts + it into a tensor format. If any decoder states are on the GPU, they + are moved to the CPU. Additionally, the function removes any timesteps + with a value of -1 from the sequences. + + Args: + hypotheses (list): A list of token sequences representing hypotheses. + + Returns: + list: A list of packed hypotheses in tensor format. + """ for idx, hyp in enumerate(hypotheses): # type: rnnt_utils.Hypothesis hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long) @@ -69,6 +83,18 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]: def _states_to_device(dec_state, device='cpu'): + """ + Transfers decoder states to the specified device. + + This function moves the provided decoder states to the specified device (e.g., 'cpu' or 'cuda'). + + Args: + dec_state (Tensor): The decoder states to be transferred. + device (str): The target device to which the decoder states should be moved. Defaults to 'cpu'. + + Returns: + Tensor: The decoder states on the specified device. + """ if torch.is_tensor(dec_state): dec_state = dec_state.to(device) @@ -106,7 +132,8 @@ class BeamRNNTInfer(Typing): however the time required for the search also grows steadily. `tsd` - time synchronous decoding. Please refer to the paper: - [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + [Alignment-Length Synchronous Decoding for RNN Transducer] + (https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions. @@ -114,7 +141,8 @@ class BeamRNNTInfer(Typing): good results. This also requires greater memory to execute. `alsd` - alignment-length synchronous decoding. Please refer to the paper: - [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + [Alignment-Length Synchronous Decoding for RNN Transducer] + (https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth @@ -127,7 +155,8 @@ class BeamRNNTInfer(Typing): For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD. `maes` = modified adaptive expansion searcn. Please refer to the paper: - [Accelerating RNN Transducer Inference via Adaptive Expansion Search](https://ieeexplore.ieee.org/document/9250505) + [Accelerating RNN Transducer Inference via Adaptive Expansion Search] + (https://ieeexplore.ieee.org/document/9250505) Modified Adaptive Synchronous Decoding (mAES) execution time is adaptive w.r.t the number of expansions (for tokens) required per timestep. The number of expansions can usually @@ -169,10 +198,10 @@ class BeamRNNTInfer(Typing): and affects the speed of inference since large values will perform large beam search in the next step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. + The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob + is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which + can be potential candidates for expansion apart from the "most likely" candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally @@ -182,7 +211,7 @@ class BeamRNNTInfer(Typing): preserve_alignments: Bool flag which preserves the history of alignments generated during beam decoding (sample). When set to true, the Hypothesis will contain - the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1). + the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1) The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. @@ -1456,8 +1485,11 @@ def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tu return lm_score, next_state def set_decoding_type(self, decoding_type: str): - - # Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need + """ + Sets decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need + Args: + decoding_type: decoding type + """ # TOKEN_OFFSET for BPE-based models if decoding_type == 'subword': from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET @@ -1467,6 +1499,10 @@ def set_decoding_type(self, decoding_type: str): @dataclass class BeamRNNTInferConfig: + """ + Beam RNNT Inference config. + """ + beam_size: int search_type: str = 'default' score_norm: bool = True diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index da280a0c6b3c..d3a63467c485 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -23,7 +23,7 @@ import torch from omegaconf import OmegaConf -from nemo.collections.asr.parts.submodules import rnnt_beam_decoding, rnnt_greedy_decoding +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding, rnnt_greedy_decoding, tdt_beam_decoding from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, ConfidenceMixin from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer @@ -67,15 +67,15 @@ class AbstractRNNTDecoding(ConfidenceMixin): rnnt_timestamp_type: A str value, which represents the types of timestamps that should be calculated. Can take the following values - "char" for character/subword time stamps, "word" for word level - time stamps, "segment" for segment level time stamps and "all" (default), for character, - word and segment level time stamps. + time stamps, "segment" for segment level time stamps and "all" (default), for character, word and + segment level time stamps. word_seperator: Str token representing the seperator between words. segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary - for forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming + the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -106,8 +106,8 @@ class AbstractRNNTDecoding(ConfidenceMixin): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated - and attached to the regular frame confidence, + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -179,23 +179,23 @@ class AbstractRNNTDecoding(ConfidenceMixin): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to - keep this as 1 in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep + this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the - next step. + and affects the speed of inference since large values will perform large beam search in the next + step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. The default (2.3) is selected from the paper. It performs a comparison - (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set - and max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin - of additional tokens which can be potential candidates for expansion apart from the "most likely" + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and + max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of + additional tokens which can be potential candidates for expansion apart from the "most likely" candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions - (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). - This is a hyper parameter to be experimentally tuned on a validation set. + (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is + a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -234,8 +234,10 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models") if self.big_blank_durations is not None and self.big_blank_durations != []: raise ValueError("duration and big_blank_durations can't both be not None") - if self.cfg.strategy not in ['greedy', 'greedy_batch']: - raise ValueError("currently only greedy and greedy_batch inference is supported for TDT models") + if self.cfg.strategy not in ['greedy', 'greedy_batch', 'beam', 'maes']: + raise ValueError( + "currently only greedy, greedy_batch, beam and maes inference is supported for TDT models" + ) if ( self.big_blank_durations is not None and self.big_blank_durations != [] @@ -386,20 +388,32 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'beam': - - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( - decoder_model=decoder, - joint_model=joint, - beam_size=self.cfg.beam.beam_size, - return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), - search_type='default', - score_norm=self.cfg.beam.get('score_norm', True), - softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), - preserve_alignments=self.preserve_alignments, - ) + if self.big_blank_durations is None or self.big_blank_durations == []: + if not self._is_tdt: + self.decoding = rnnt_beam_decoding.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='default', + score_norm=self.cfg.beam.get('score_norm', True), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ) + else: + self.decoding = tdt_beam_decoding.BeamTDTInfer( + decoder_model=decoder, + joint_model=joint, + durations=self.durations, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='default', + score_norm=self.cfg.beam.get('score_norm', True), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ) elif self.cfg.strategy == 'tsd': - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -413,7 +427,6 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'alsd': - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -427,26 +440,44 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'maes': - - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( - decoder_model=decoder, - joint_model=joint, - beam_size=self.cfg.beam.beam_size, - return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), - search_type='maes', - score_norm=self.cfg.beam.get('score_norm', True), - maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), - maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), - maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), - maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), - softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), - preserve_alignments=self.preserve_alignments, - ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), - ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0), - hat_subtract_ilm=self.cfg.beam.get('hat_subtract_ilm', False), - hat_ilm_weight=self.cfg.beam.get('hat_ilm_weight', 0.0), - ) - + if self.big_blank_durations is None or self.big_blank_durations == []: + if not self._is_tdt: + self.decoding = rnnt_beam_decoding.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='maes', + score_norm=self.cfg.beam.get('score_norm', True), + maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), + maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), + maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), + maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0), + hat_subtract_ilm=self.cfg.beam.get('hat_subtract_ilm', False), + hat_ilm_weight=self.cfg.beam.get('hat_ilm_weight', 0.0), + ) + else: + self.decoding = tdt_beam_decoding.BeamTDTInfer( + decoder_model=decoder, + joint_model=joint, + durations=self.durations, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='maes', + score_norm=self.cfg.beam.get('score_norm', True), + maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), + maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), + maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), + maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.3), + ) else: raise ValueError( @@ -728,6 +759,15 @@ def decode_ids_to_langs(self, tokens: List[int]) -> List[str]: raise NotImplementedError() def update_joint_fused_batch_size(self): + """ " + Updates the fused batch size for the joint module if applicable. + + If `joint_fused_batch_size` is set, verifies that the joint module has + the required `set_fused_batch_size` and `set_fuse_loss_wer` functions. + If present, updates the batch size; otherwise, logs a warning. + + If `joint_fused_batch_size` is <= 0, disables fused batch processing. + """ if self.joint_fused_batch_size is None: # do nothing and let the Joint itself handle setting up of the fused batch return @@ -754,6 +794,21 @@ def update_joint_fused_batch_size(self): self.decoding.joint.set_fuse_loss_wer(False) def compute_rnnt_timestamps(self, hypothesis: Hypothesis, timestamp_type: str = "all"): + """ + Computes character, word, and segment timestamps for an RNN-T hypothesis. + + This function generates timestamps for characters, words, and segments within + a hypothesis sequence. The type of timestamps computed depends on `timestamp_type`, + which can be 'char', 'word', 'segment', or 'all'. + + Args: + hypothesis (Hypothesis): Hypothesis. + timestamp_type (str): Type of timestamps to compute. Options are 'char', 'word', 'segment', or 'all'. + Defaults to 'all'. + + Returns: + Hypothesis: The updated hypothesis with computed timestamps for characters, words, and/or segments. + """ assert timestamp_type in ['char', 'word', 'segment', 'all'] # Unpack the temporary storage @@ -890,7 +945,7 @@ def _compute_offsets( # Construct the start and end indices brackets end_indices = np.asarray(token_repetitions).cumsum() - start_indices = np.concatenate(([int(start_index)], end_indices[:-1])) + start_indices = np.concatenate(([start_index], end_indices[:-1])) # Process the TxU dangling alignment tensor, containing pairs of (logits, label) alignment_labels = [al_logits_labels for al_logits_labels in hypothesis.text[1]] @@ -953,8 +1008,8 @@ def _refine_timestamps_tdt( # Check if token is a punctuation mark # If so, set its start and end offset as start and end of the previous token - # This is done because there was observed a behaviour, when punctuation marks are predicted long - # after preceding token (i.e. after silence) + # This is done because there was observed a behaviour, when punctuation marks are + # predicted long after preceding token (i.e. after silence) if offset['char'][0] in supported_punctuation and i > 0: encoded_char_offsets[i]['start_offset'] = offset['start_offset'] = char_offsets[i - 1]['end_offset'] encoded_char_offsets[i]['end_offset'] = offset['end_offset'] = offset['start_offset'] @@ -1114,7 +1169,8 @@ def _get_segment_offsets( offsets: A list of dictionaries, each containing "word", "start_offset" and "end_offset". segments_delimiter_tokens: List containing tokens representing the seperator(s) between segments. supported_punctuation: Set containing punctuation marks in the vocabulary. - segment_gap_threshold: Number of frames between 2 consecutive words necessary to form segments out of plain text. + segment_gap_threshold: Number of frames between 2 consecutive words necessary to form segments out of plain + text. Returns: A list of dictionaries containing the segment offsets. Each item contains "segment", "start_offset" and "end_offset". @@ -1242,9 +1298,10 @@ class RNNTDecoding(AbstractRNNTDecoding): exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word - confidence. Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated - and attached to the regular frame confidence, + confidence. + Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1331,7 +1388,7 @@ class RNNTDecoding(AbstractRNNTDecoding): and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to - keep this as 1 in order to reduce expensive beam search cost later. int >= 0. + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, @@ -1339,8 +1396,7 @@ class RNNTDecoding(AbstractRNNTDecoding): next step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the - expansions. - The default (2.3) is selected from the paper. It performs a comparison + expansions. The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for expansion apart from the "most likely" @@ -1382,7 +1438,9 @@ def __init__( supported_punctuation=supported_punctuation, ) - if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer): + if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer) or isinstance( + self.decoding, tdt_beam_decoding.BeamTDTInfer + ): self.decoding.set_decoding_type('char') def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]: @@ -1498,8 +1556,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for - forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming + the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -1530,8 +1588,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be - calculated and attached to the regular frame confidence, + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1602,7 +1660,7 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): at increased cost to execution time. alsd_max_target_len: optional int or float, determines the potential maximum target sequence - length. If an integer is provided, it can decode sequences of that particular maximum length. + length.If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). @@ -1622,16 +1680,15 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): and affects the speed of inference since large values will perform large beam search in the next step. - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when - computing the expansions. The default (2.3) is selected from the paper. It performs a - comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the - Vocab set and max_log_prob is the "most" likely token to be predicted. Gamma therefore - provides a margin of additional tokens which can be potential candidates for expansion - apart from the "most likely" candidate. Lower values will reduce the number of expansions - (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher - values will increase the number of expansions (by reducing pruning-by-value, thereby - reducing speed but potentially improving accuracy). This is a hyper parameter to be - experimentally tuned on a validation set. + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the + expansions. The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and + max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of + additional tokens which can be potential candidates for expansion apart from the "most likely" + candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, + thereby improving speed but hurting accuracy). Higher values will increase the number of + expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving + accuracy). This is a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -1658,7 +1715,9 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): supported_punctuation=supported_punctuation, ) - if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer): + if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer) or isinstance( + self.decoding, tdt_beam_decoding.BeamTDTInfer + ): self.decoding.set_decoding_type('subword') def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]: @@ -1759,8 +1818,8 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp hypotheses[ind].langs_chars = self.decode_ids_to_langs(prediction) else: logging.warning( - "Ignoring request for lang output in hypotheses since the model does not use an aggregate\ - tokenizer" + "Ignoring request for lang output in hypotheses since the model does not use an aggregate \ + tokenizer" ) return hypotheses @@ -1768,6 +1827,10 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp @dataclass class RNNTDecodingConfig: + """ + RNNT Decoding config + """ + model_type: str = "rnnt" # one of "rnnt", "multiblank" or "tdt" strategy: str = "greedy_batch" @@ -1825,4 +1888,8 @@ class RNNTDecodingConfig: @dataclass class RNNTBPEDecodingConfig(RNNTDecodingConfig): + """ + RNNT BPE Decoding Config + """ + pass diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index f9cf368fe405..bd169d0d224e 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -49,7 +49,20 @@ def pack_hypotheses( hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor, ) -> List[rnnt_utils.Hypothesis]: + """ + Packs a list of hypotheses into a tensor and prepares decoder states. + + This function takes a list of token sequences (hypotheses) and converts + it into a tensor format. If any decoder states are on the GPU, they + are moved to the CPU. Additionally, the function removes any timesteps + with a value of -1 from the sequences. + + Args: + hypotheses (list): A list of token sequences representing hypotheses. + Returns: + list: A list of packed hypotheses in tensor format. + """ if hasattr(logitlen, 'cpu'): logitlen_cpu = logitlen.to('cpu') else: @@ -578,7 +591,8 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): (evaluating Joint multiple times in inner loop); It uses a minimal possible amount of calls to prediction network (with maximum possible batch size), which makes it especially useful for scaling the prediction network. - use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference) + use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding + (currently recommended only for inference) """ def __init__( @@ -1169,6 +1183,10 @@ def _greedy_decode_masked( class ExportedModelGreedyBatchedRNNTInfer: + """ + Exported Model Greedy Batched RNNT Infer class + """ + def __init__(self, encoder_model: str, decoder_joint_model: str, max_symbols_per_step: Optional[int] = None): self.encoder_model_path = encoder_model self.decoder_joint_model_path = decoder_joint_model @@ -1344,9 +1362,25 @@ def _setup_blank_index(self): raise NotImplementedError() def run_encoder(self, audio_signal, length): + """ + Runs encoder network: + + Args: + audio_signal: audio signal + length: audio length + """ raise NotImplementedError() def run_decoder_joint(self, enc_logits, targets, target_length, *states): + """ + Runs decoder joint networks. + + Args: + enc_logits: encoder logits + targets: targets + target_length: target length + states: states + """ raise NotImplementedError() def _get_initial_states(self, batchsize): @@ -1354,6 +1388,10 @@ def _get_initial_states(self, batchsize): class ONNXGreedyBatchedRNNTInfer(ExportedModelGreedyBatchedRNNTInfer): + """ + ONNX Greedy Batched RNNT Infer class + """ + def __init__(self, encoder_model: str, decoder_joint_model: str, max_symbols_per_step: Optional[int] = 10): super().__init__( encoder_model=encoder_model, @@ -1433,7 +1471,8 @@ def _setup_blank_index(self): self._blank_index = log_probs.shape[-1] - 1 # last token of vocab size is blank token logging.info( - f"Enc-Dec-Joint step was evaluated, blank token id = {self._blank_index}; vocab size = {log_probs.shape[-1]}" + f"Enc-Dec-Joint step was evaluated, \ + blank token id = {self._blank_index}; vocab size = {log_probs.shape[-1]}" ) def run_encoder(self, audio_signal, length): @@ -1512,6 +1551,10 @@ def _get_initial_states(self, batchsize): class TorchscriptGreedyBatchedRNNTInfer(ExportedModelGreedyBatchedRNNTInfer): + """ + Torchscript Greedy Batched RNNT Infer + """ + def __init__( self, encoder_model: str, @@ -2336,6 +2379,8 @@ def _greedy_decode_masked( @dataclass class GreedyRNNTInferConfig: + """Greedy RNNT Infer Config""" + max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False @@ -2354,6 +2399,8 @@ def __post_init__(self): @dataclass class GreedyBatchedRNNTInferConfig: + """Greedy Batched RNNT Infer Config""" + max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False @@ -2708,7 +2755,8 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): - 'lin' for using the linear mapping. - 'exp' for using exponential mapping with linear shift. - use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference) + use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding + (currently recommended only for inference) """ def __init__( diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py new file mode 100644 index 000000000000..908fc1c13d19 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -0,0 +1,800 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import numpy as np +import torch +from tqdm import tqdm + +from nemo.collections.asr.modules import rnnt_abstract +from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import pack_hypotheses +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses, is_prefix +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType +from nemo.utils import logging + +try: + import kenlm + + KENLM_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + KENLM_AVAILABLE = False + + +class BeamTDTInfer(Typing): + """ + Beam search implementation for Token-andDuration Transducer (TDT) models. + + Sequence level beam decoding or batched-beam decoding, performed auto-repressively + depending on the search type chosen. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + durations: list of duration values from TDT model. + + beam_size: number of beams for beam search. Must be a positive integer >= 1. + If beam size is 1, defaults to stateful greedy search. + For accurate greedy results, please use GreedyRNNTInfer or GreedyBatchedRNNTInfer. + + search_type: str representing the type of beam search to perform. + Must be one of ['beam', 'maes']. + + Algorithm used: + + `default` - basic beam search strategy. Larger beams generally result in better decoding, + however the time required for the search also grows steadily. + + `maes` = modified adaptive expansion search. Please refer to the paper: + [Accelerating RNN Transducer Inference via Adaptive Expansion Search] + (https://ieeexplore.ieee.org/document/9250505) + + Modified Adaptive Synchronous Decoding (mAES) execution time is adaptive w.r.t the + number of expansions (for tokens) required per timestep. The number of expansions can usually + be constrained to 1 or 2, and in most cases 2 is sufficient. + + This beam search technique can possibly obtain superior WER while sacrificing some evaluation time. + + score_norm: bool, whether to normalize the scores of the log probabilities. + + return_best_hypothesis: bool, decides whether to return a single hypothesis (the best out of N), + or return all N hypothesis (sorted with best score first). The container class changes based + this flag - + When set to True (default), returns a single Hypothesis. + When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis. + + # The following arguments are specific to the chosen `search_type` + + # mAES flags + maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient. int > 1. + + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 + in order to reduce expensive beam search cost later. int >= 0. + + maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. + Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, + and affects the speed of inference since large values will perform large beam search in the next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. + The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob + is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which + can be potential candidates for expansion apart from the "most likely" candidate. + Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed + but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, + thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally + tuned on a validation set. + + softmax_temperature: Scales the logits of the joint prior to computing log_softmax. + + preserve_alignments: Bool flag which preserves the history of alignments generated during + beam decoding (sample). When set to true, the Hypothesis will contain + the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1) + + The length of the list corresponds to the Acoustic Length (T). + Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. + U is the number of target tokens for the current timestep Ti. + + NOTE: `preserve_alignments` is an invalid argument for any `search_type` + other than basic beam search. + + ngram_lm_model: str + The path to the N-gram LM. + ngram_lm_alpha: float + Alpha weight of N-gram LM. + """ + + @property + def input_types(self): + """Returns definitions of module input ports.""" + return { + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "partial_hypotheses": [NeuralType(elements_type=HypothesisType(), optional=True)], # must always be last + } + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return {"predictions": [NeuralType(elements_type=HypothesisType())]} + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + durations: list, + beam_size: int, + search_type: str = 'default', + score_norm: bool = True, + return_best_hypothesis: bool = True, + maes_num_steps: int = 2, + maes_prefix_alpha: int = 1, + maes_expansion_gamma: float = 2.3, + maes_expansion_beta: int = 2, + softmax_temperature: float = 1.0, + preserve_alignments: bool = False, + ngram_lm_model: Optional[str] = None, + ngram_lm_alpha: float = 0.3, + ): + self.joint = joint_model + self.decoder = decoder_model + self.durations = durations + + self.token_offset = 0 + self.search_type = search_type + self.blank = decoder_model.blank_idx + self.vocab_size = decoder_model.vocab_size + self.return_best_hypothesis = return_best_hypothesis + + self.beam_size = beam_size + self.score_norm = score_norm + self.max_candidates = beam_size + self.softmax_temperature = softmax_temperature + self.preserve_alignments = preserve_alignments + + if preserve_alignments: + raise ValueError("Alignment preservation has not been implemented.") + if beam_size < 1: + raise ValueError("Beam search size cannot be less than 1!") + + if self.preserve_alignments: + raise NotImplementedError("Preserving alignments is not implemented.") + + if search_type == "default": + if self.beam_size == 1: + logging.info( + """If beam size is 1, defaults to stateful greedy search. + For accurate greedy results, please use GreedyTDTInfer or GreedyBatchedTDTInfer.""" + ) + self.search_algorithm = self.default_beam_search + elif search_type == "tsd": + raise NotImplementedError("`tsd` (Time Synchronous Decoding) has not been implemented.") + elif search_type == "alsd": + raise NotImplementedError("`alsd` (Alignment Length Synchronous Decoding) has not been implemented.") + elif search_type == "nsc": + raise NotImplementedError("`nsc` (Constrained Beam Search) has not been implemented.") + elif search_type == "maes": + self.search_algorithm = self.modified_adaptive_expansion_search + else: + raise NotImplementedError( + f"The search type ({search_type}) supplied is not supported!\n" f"Please use one of : (default, maes)" + ) + + if self.search_type == 'maes': + self.maes_num_steps = int(maes_num_steps) + self.maes_prefix_alpha = int(maes_prefix_alpha) + self.maes_expansion_beta = int(maes_expansion_beta) + self.maes_expansion_gamma = float(maes_expansion_gamma) + + self.max_candidates += maes_expansion_beta + + if self.maes_prefix_alpha < 0: + raise ValueError("`maes_prefix_alpha` must be a positive integer.") + + if self.vocab_size < beam_size + maes_expansion_beta: + raise ValueError( + f"beam_size ({beam_size}) + expansion_beta ({maes_expansion_beta}) " + f"should be smaller or equal to vocabulary size ({self.vocab_size})." + ) + + if self.maes_num_steps < 1: + raise ValueError("`maes_num_steps` must be greater than 0.") + + try: + self.zero_duration_idx = self.durations.index(0) + except ValueError: + self.zero_duration_idx = None + self.min_non_zero_duration_idx = int( + np.argmin(np.ma.masked_where(np.array(self.durations) == 0, self.durations)) + ) + + if ngram_lm_model: + if search_type != "maes": + raise ValueError("For decoding with language model `maes` decoding strategy must be chosen.") + + if KENLM_AVAILABLE: + self.ngram_lm = kenlm.Model(ngram_lm_model) + self.ngram_lm_alpha = ngram_lm_alpha + else: + raise ImportError( + "KenLM package (https://github.com/kpu/kenlm) is not installed. " "Use ngram_lm_model=None." + ) + else: + self.ngram_lm = None + + @typecheck() + def __call__( + self, + encoder_output: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: tuple[list[Hypothesis | NBestHypotheses],] = None, + ) -> tuple[list[Hypothesis | NBestHypotheses],]: + """Perform general beam search. + + Args: + encoder_output: encoder outputs (batch, features, timesteps). + encoded_lengths: lengths of the encoder outputs. + + Returns: + Either a list containing a single Hypothesis (when `return_best_hypothesis=True`, + otherwise a list containing a single NBestHypotheses, which itself contains a list of + Hypothesis. This list is sorted such that the best hypothesis is the first element. + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.inference_mode(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + + self.decoder.eval() + self.joint.eval() + + hypotheses = [] + with tqdm( + range(encoder_output.size(0)), + desc='Beam search progress:', + total=encoder_output.size(0), + unit='sample', + ) as idx_gen: + + _p = next(self.joint.parameters()) + dtype = _p.dtype + + # Decode every sample in the batch independently. + for batch_idx in idx_gen: + inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] + logitlen = encoded_lengths[batch_idx] + + if inseq.dtype != dtype: + inseq = inseq.to(dtype=dtype) + + # Extract partial hypothesis if exists + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + + # Execute the specific search strategy + nbest_hyps = self.search_algorithm( + inseq, logitlen, partial_hypotheses=partial_hypothesis + ) # sorted list of hypothesis + + # Prepare the list of hypotheses + nbest_hyps = pack_hypotheses(nbest_hyps) + + # Pack the result + if self.return_best_hypothesis: + best_hypothesis: Hypothesis = nbest_hyps[0] + else: + best_hypothesis: NBestHypotheses = NBestHypotheses(nbest_hyps) + hypotheses.append(best_hypothesis) + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (hypotheses,) + + def default_beam_search( + self, + encoder_outputs: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: Optional[Hypothesis] = None, + ) -> List[Hypothesis]: + """Default Beam search implementation for TDT models. + + Args: + encoder_outputs: encoder outputs (batch, features, timesteps). + encoded_lengths: lengths of the encoder outputs. + partial_hypotheses: partial hypoteses. + + Returns: + nbest_hyps: N-best decoding results + """ + if partial_hypotheses is not None: + raise NotImplementedError("Support for `partial_hypotheses` is not implemented.") + + beam = min(self.beam_size, self.vocab_size) + beam_k = min(beam, (self.vocab_size - 1)) + durations_beam_k = min(beam, len(self.durations)) + + # Initialize zero vector states. + decoder_state = self.decoder.initialize_state(encoder_outputs) + # Cache decoder results to avoid duplicate computations. + cache = {} + + # Initialize hypothesis array with blank hypothesis. + start_hyp = Hypothesis( + score=0.0, y_sequence=[self.blank], dec_state=decoder_state, timestep=[-1], length=0, last_frame=0 + ) + kept_hyps = [start_hyp] + + for time_idx in range(int(encoded_lengths)): + # Retrieve hypotheses for current and future frames + hyps = [hyp for hyp in kept_hyps if hyp.last_frame == time_idx] # hypotheses for current frame + kept_hyps = [hyp for hyp in kept_hyps if hyp.last_frame > time_idx] # hypothesis for future frames + + # Loop over hypotheses of current frame + while len(hyps) > 0: + max_hyp = max(hyps, key=lambda x: x.score) + hyps.remove(max_hyp) + + # Update decoder state and get probability distribution over vocabulary and durations. + encoder_output = encoder_outputs[:, time_idx : time_idx + 1, :] # [1, 1, D] + decoder_output, decoder_state, _ = self.decoder.score_hypothesis(max_hyp, cache) # [1, 1, D] + logits = ( + self.joint.joint(encoder_output, decoder_output) / self.softmax_temperature + ) # [1, 1, 1, V + NUM_DURATIONS + 1] + logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1) # [V + 1] + durations_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1) # [NUM_DURATIONS] + + # Proccess non-blank tokens + # Retrieve the top `beam_k` most probable tokens and the top `duration_beam_k` most probable durations. + # Then, select the top `beam_k` pairs of (token, duration) based on the highest combined probabilities. + # Note that indices are obtained in the flattened array. + logp_topks, logp_topk_idxs = logp[:-1].topk(beam_k, dim=-1) # topk of tokens without blank token + durations_logp_topks, durations_logp_topk_idxs = durations_logp.topk(durations_beam_k, dim=-1) + total_logp_topks, total_logp_topk_idxs = ( + torch.cartesian_prod(durations_logp_topks, logp_topks).sum(dim=-1).topk(beam_k, dim=-1) + ) + + # Loop over pairs of (token, duration) with highest combined log prob + for total_logp_topk, total_logp_topk_idx in zip(total_logp_topks, total_logp_topk_idxs): + # Restore indices from flattened array indices + token_idx = int(logp_topk_idxs[total_logp_topk_idx % beam_k]) + duration_idx = int(durations_logp_topk_idxs[total_logp_topk_idx // beam_k]) + + duration = self.durations[duration_idx] + # Construct hypothesis for non-blank token + new_hyp = Hypothesis( + score=float(max_hyp.score + total_logp_topk), # update score + y_sequence=max_hyp.y_sequence + [token_idx], # update hypothesis sequence + dec_state=decoder_state, # update decoder state + timestep=max_hyp.timestep + [time_idx + duration], # update timesteps + length=encoded_lengths, + last_frame=max_hyp.last_frame + duration, + ) # update frame idx where last token appeared + + # Update current frame hypotheses if duration is zero and future frame hypotheses otherwise + if duration == 0: + hyps.append(new_hyp) + else: + kept_hyps.append(new_hyp) + + # Update future frames with blank tokens + # Note: blank token can have only non-zero duration + for duration_idx in durations_logp_topk_idxs: + duration_idx = int(duration_idx) + # If zero is the only duration in topk, switch to closest non-zero duration to continue + if duration_idx == self.zero_duration_idx: + if durations_logp_topk_idxs.shape[0] == 1: + duration_idx = self.min_non_zero_duration_idx + else: + continue + + duration = self.durations[duration_idx] + new_hyp = Hypothesis( + score=float(max_hyp.score + logp[self.blank] + durations_logp[duration_idx]), # update score + y_sequence=max_hyp.y_sequence[:], # no need to update sequence + dec_state=max_hyp.dec_state, # no need to update decoder state + timestep=max_hyp.timestep[:], # no need to update timesteps + length=encoded_lengths, + last_frame=max_hyp.last_frame + duration, + ) # update frame idx where last token appeared + kept_hyps.append(new_hyp) + + # Merge duplicate hypotheses. + # If two consecutive blank tokens are predicted and their duration values sum up to the same number, + # it will produce two hypotheses with the same token sequence but different scores. + kept_hyps = self.merge_duplicate_hypotheses(kept_hyps) + + if len(hyps) > 0: + # Keep those hypothesis that have scores greater than next search generation + hyps_max = float(max(hyps, key=lambda x: x.score).score) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if hyp.score > hyps_max], + key=lambda x: x.score, + ) + # If enough hypotheses have scores greater than next search generation, + # stop beam search. + if len(kept_most_prob) >= beam: + kept_hyps = kept_most_prob + break + else: + # If there are no hypotheses in a current frame, + # keep only `beam` best hypotheses for the next search generation. + kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam] + return self.sort_nbest(kept_hyps) + + def modified_adaptive_expansion_search( + self, + encoder_outputs: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: Optional[Hypothesis] = None, + ) -> List[Hypothesis]: + """ + Modified Adaptive Exoansion Search algorithm for TDT models. + Based on/modified from https://ieeexplore.ieee.org/document/9250505. + Supports N-gram language model shallow fusion. + + Args: + encoder_outputs: encoder outputs (batch, features, timesteps). + encoded_lengths: lengths of the encoder outputs. + partial_hypotheses: partial hypotheses. + + Returns: + nbest_hyps: N-best decoding results + """ + if partial_hypotheses is not None: + raise NotImplementedError("Support for `partial_hypotheses` is not implemented.") + + beam = min(self.beam_size, self.vocab_size) + beam_state = self.decoder.initialize_state( + torch.zeros(1, device=encoder_outputs.device, dtype=encoder_outputs.dtype) + ) # [L, B, H], [L, B, H] for LSTMS + + # Initialize first hypothesis for the beam (blank). + start_hyp = Hypothesis( + y_sequence=[self.blank], + score=0.0, + dec_state=self.decoder.batch_select_state(beam_state, 0), + timestep=[-1], + length=0, + last_frame=0, + ) + init_tokens = [start_hyp] + + # Cache decoder results to avoid duplicate computations. + cache = {} + + # Decode a batch of beam states and scores + beam_decoder_output, beam_state = self.decoder.batch_score_hypothesis(init_tokens, cache) + state = beam_state[0] + + # Initialize first hypothesis for the beam (blank) for kept hypotheses + start_hyp_kept = Hypothesis( + y_sequence=[self.blank], + score=0.0, + dec_state=state, + dec_out=[beam_decoder_output[0]], + timestep=[-1], + length=0, + last_frame=0, + ) + + kept_hyps = [start_hyp_kept] + + # Setup ngram LM: + if self.ngram_lm: + init_lm_state = kenlm.State() + self.ngram_lm.BeginSentenceWrite(init_lm_state) + start_hyp_kept.ngram_lm_state = init_lm_state + + for time_idx in range(encoded_lengths): + # Select current iteration hypotheses + hyps = [x for x in kept_hyps if x.last_frame == time_idx] + kept_hyps = [x for x in kept_hyps if x.last_frame > time_idx] + + if len(hyps) == 0: + continue + + beam_encoder_output = encoder_outputs[:, time_idx : time_idx + 1] # [1, 1, D] + # Perform prefix search to update hypothesis scores. + if self.zero_duration_idx is not None: + hyps = self.prefix_search( + sorted(hyps, key=lambda x: len(x.y_sequence), reverse=True), + beam_encoder_output, + prefix_alpha=self.maes_prefix_alpha, + ) + + list_b = [] # List that contains the blank token emissions + list_nb = [] # List that contains the non-zero duration non-blank token emissions + # Repeat for number of mAES steps + for n in range(self.maes_num_steps): + # Pack the decoder logits for all current hypotheses + beam_decoder_output = torch.stack([h.dec_out[-1] for h in hyps]) # [H, 1, D] + + # Extract the log probabilities + beam_logits = self.joint.joint(beam_encoder_output, beam_decoder_output) / self.softmax_temperature + beam_logp = torch.log_softmax(beam_logits[:, 0, 0, : -len(self.durations)], dim=-1) + beam_duration_logp = torch.log_softmax(beam_logits[:, 0, 0, -len(self.durations) :], dim=-1) + + # Retrieve the top `max_candidades` most probable tokens. + # Then, select the top `max_candidates` pairs of (token, duration) + # based on the highest combined probabilities. + # Note that indices are obtained in flattened array. + beam_logp_topks, beam_idx_topks = beam_logp.topk(self.max_candidates, dim=-1) + beam_total_logp = (beam_duration_logp[:, :, None] + beam_logp_topks[:, None, :]).view( + len(hyps), -1 + ) # [B, MAX_CANDIDATES*DURATION_BEAM] + beam_total_logp_topks, beam_total_logp_topk_idxs = beam_total_logp.topk( + self.max_candidates, dim=-1 + ) # [B, MAX_CANDIDATES] + + # Prune hypothesis to obtain k expansions + beam_best_expansion_scores = beam_total_logp_topks.max(dim=-1, keepdim=True).values + beam_masks = beam_total_logp_topks >= beam_best_expansion_scores - self.maes_expansion_gamma + beam_kexpansions_idxs = [ + sum_logp_topk_idxs[mask] for sum_logp_topk_idxs, mask in zip(beam_total_logp_topk_idxs, beam_masks) + ] + + list_exp = [] # List that contains the hypothesis expansion + list_nb_exp = [] # List that contains the hypothesis expansion + for hyp_idx, hyp in enumerate(hyps): # For all hypothesis + for idx in beam_kexpansions_idxs[hyp_idx]: # For all expansions within this hypothesis + # Restore indices in logp and durations_logp arrays from flattened indices. + k = int(beam_idx_topks[hyp_idx][idx % self.max_candidates]) + duration = self.durations[int(idx // self.max_candidates)] + total_logp = float(beam_total_logp[hyp_idx][idx]) + + # Forcing blank token to have non-zero duration + if k == self.blank and duration == 0: + duration = self.durations[self.min_non_zero_duration_idx] + + new_hyp = Hypothesis( + score=hyp.score + total_logp, + y_sequence=hyp.y_sequence[:], + dec_out=hyp.dec_out[:], + dec_state=hyp.dec_state, + timestep=hyp.timestep[:], + length=time_idx, + last_frame=hyp.last_frame + duration, + ) + + if self.ngram_lm: + new_hyp.ngram_lm_state = hyp.ngram_lm_state + + # If the expansion was for blank + if k == self.blank: + list_b.append(new_hyp) + else: + new_hyp.y_sequence.append(k) + new_hyp.timestep.append(time_idx + duration) + + if self.ngram_lm: + lm_score, new_hyp.ngram_lm_state = self.compute_ngram_score(hyp.ngram_lm_state, int(k)) + new_hyp.score += self.ngram_lm_alpha * lm_score + + # If token duration is 0 adding to expansions list + if duration == 0: + list_exp.append(new_hyp) + else: + list_nb_exp.append(new_hyp) + + # Update states for hypothesis that do not end with blank + hyps_to_update = list_nb_exp + list_exp + if len(hyps_to_update) > 0: + # Decode a batch of beam states and scores + beam_decoder_output, beam_state = self.decoder.batch_score_hypothesis( + hyps_to_update, + cache, + ) + for hyp_idx, hyp in enumerate(hyps_to_update): + # Preserve the decoder logits for the current beam + hyp.dec_out.append(beam_decoder_output[hyp_idx]) + hyp.dec_state = beam_state[hyp_idx] + + # If there were no token expansions in any of the hypotheses, + # Early exit + list_nb += list_nb_exp + if not list_exp: + kept_hyps = kept_hyps + list_b + list_nb + kept_hyps = self.merge_duplicate_hypotheses(kept_hyps) + kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam] + + break + else: + # If this isn't the last mAES step + if n < (self.maes_num_steps - 1): + # Copy the expanded hypothesis for the next iteration + hyps = self.merge_duplicate_hypotheses(list_exp) + else: + # If this is the last mAES step add probabilities of the blank token to the end. + # Extract the log probabilities + beam_decoder_output = torch.stack([h.dec_out[-1] for h in list_exp]) # [H, 1, D] + beam_logits = ( + self.joint.joint(beam_encoder_output, beam_decoder_output) / self.softmax_temperature + ) + beam_logp = torch.log_softmax(beam_logits[:, 0, 0, : -len(self.durations)], dim=-1) + + # Get most probable durations + beam_duration_logp = torch.log_softmax(beam_logits[:, 0, 0, -len(self.durations) :], dim=-1) + _, beam_max_duration_idx = torch.max(beam_duration_logp, dim=-1) + + # For all expansions, add the score for the blank label + for hyp_idx, hyp in enumerate(list_exp): + # If zero duration was obtained, change to the closest non-zero duration + duration_idx = int(beam_max_duration_idx[hyp_idx]) + if duration_idx == self.zero_duration_idx: + duration_idx = self.min_non_zero_duration_idx + + total_logp = float( + beam_logp[hyp_idx, self.blank] + beam_duration_logp[hyp_idx, duration_idx] + ) + hyp.score += total_logp + hyp.last_frame += self.durations[duration_idx] + + # Finally, update the kept hypothesis of sorted top Beam candidates + kept_hyps = kept_hyps + list_b + list_exp + list_nb + kept_hyps = self.merge_duplicate_hypotheses(kept_hyps) + kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam] + + # Sort the hypothesis with best scores + return self.sort_nbest(kept_hyps) + + def merge_duplicate_hypotheses(self, hypotheses): + """ + Merges hypotheses with identical token sequences and lengths. + The combined hypothesis's probability is the sum of the probabilities of all duplicates. + Duplicate hypotheses occur when two consecutive blank tokens are predicted + and their duration values sum up to the same number. + + Args: + hypotheses: list of hypotheses. + + Returns: + hypotheses: list if hypotheses without duplicates. + """ + sorted_hyps = sorted(hypotheses, key=lambda x: x.score, reverse=True) + kept_hyps = {} + for hyp in sorted_hyps: + hyp_key = (tuple(hyp.y_sequence), int(hyp.last_frame)) + if hyp_key in kept_hyps: + kept_hyp = kept_hyps[hyp_key] + kept_hyp.score = float(torch.logaddexp(torch.tensor(kept_hyp.score), torch.tensor(hyp.score))) + else: + kept_hyps[hyp_key] = hyp + return list(kept_hyps.values()) + + def set_decoding_type(self, decoding_type: str): + """ + Sets decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need + Args: + decoding_type: decoding type + """ + # TOKEN_OFFSET for BPE-based models + if decoding_type == 'subword': + from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET + + self.token_offset = DEFAULT_TOKEN_OFFSET + + def prefix_search( + self, hypotheses: List[Hypothesis], encoder_output: torch.Tensor, prefix_alpha: int + ) -> List[Hypothesis]: + """ + Performs a prefix search and updates the scores of the hypotheses in place. + Based on https://arxiv.org/pdf/1211.3711.pdf. + + Args: + hypotheses: a list of hypotheses sorted by the length from the longest to the shortest. + encoder_output: encoder output. + prefix_alpha: maximum allowable length difference between hypothesis and a prefix. + + Returns: + hypotheses: list of hypotheses with updated scores. + """ + # Iterate over hypotheses. + for curr_idx, curr_hyp in enumerate(hypotheses[:-1]): + # For each hypothesis, iterate over the subsequent hypotheses. + # If a hypothesis is a prefix of the current one, update current score. + for pref_hyp in hypotheses[(curr_idx + 1) :]: + curr_hyp_length = len(curr_hyp.y_sequence) + pref_hyp_length = len(pref_hyp.y_sequence) + + if ( + is_prefix(curr_hyp.y_sequence, pref_hyp.y_sequence) + and (curr_hyp_length - pref_hyp_length) <= prefix_alpha + ): + # Compute the score of the first token + # that follows the prefix hypothesis tokens in current hypothesis. + # Use the decoder output, which is stored in the prefix hypothesis. + logits = self.joint.joint(encoder_output, pref_hyp.dec_out[-1]) / self.softmax_temperature + logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1) + duration_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1) + curr_score = pref_hyp.score + float( + logp[curr_hyp.y_sequence[pref_hyp_length]] + duration_logp[self.zero_duration_idx] + ) + + if self.ngram_lm: + lm_score, next_state = self.compute_ngram_score( + pref_hyp.ngram_lm_state, int(curr_hyp.y_sequence[pref_hyp_length]) + ) + curr_score += self.ngram_lm_alpha * lm_score + + for k in range(pref_hyp_length, (curr_hyp_length - 1)): + # Compute the score of the next token. + # Approximate decoder output with the one that is stored in current hypothesis. + logits = self.joint.joint(encoder_output, curr_hyp.dec_out[k]) / self.softmax_temperature + logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1) + duration_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1) + curr_score += float(logp[curr_hyp.y_sequence[k + 1]] + duration_logp[self.zero_duration_idx]) + + if self.ngram_lm: + lm_score, next_state = self.compute_ngram_score( + next_state, int(curr_hyp.y_sequence[k + 1]) + ) + curr_score += self.ngram_lm_alpha * lm_score + + # Update current hypothesis score + curr_hyp.score = np.logaddexp(curr_hyp.score, curr_score) + return hypotheses + + def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tuple[float, "kenlm.State"]: + """ + Computes the score for KenLM Ngram language model. + + Args: + current_lm_state: current state of the KenLM language model. + label: next label. + + Returns: + lm_score: score for `label`. + """ + if self.token_offset: + label = chr(label + self.token_offset) + else: + label = str(label) + + next_state = kenlm.State() + lm_score = self.ngram_lm.BaseScore(current_lm_state, label, next_state) + lm_score *= 1.0 / np.log10(np.e) + + return lm_score, next_state + + def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]: + """Sort hypotheses by score or score given sequence length. + + Args: + hyps: list of hypotheses + + Return: + hyps: sorted list of hypotheses + """ + if self.score_norm: + return sorted(hyps, key=lambda x: x.score / len(x.y_sequence), reverse=True) + else: + return sorted(hyps, key=lambda x: x.score, reverse=True) diff --git a/nemo/collections/asr/parts/utils/rnnt_utils.py b/nemo/collections/asr/parts/utils/rnnt_utils.py index 76e9da6087ed..8d2755fcc0ae 100644 --- a/nemo/collections/asr/parts/utils/rnnt_utils.py +++ b/nemo/collections/asr/parts/utils/rnnt_utils.py @@ -85,6 +85,8 @@ class Hypothesis: tokens: (Optional) A list of decoded tokens (can be characters or word-pieces. last_token (Optional): A token or batch of tokens which was predicted in the last step. + + last_frame (Optional): Index of the last decoding step hypothesis was updated including blank token prediction. """ score: float @@ -105,6 +107,7 @@ class Hypothesis: tokens: Optional[Union[List[int], torch.Tensor]] = None last_token: Optional[torch.Tensor] = None token_duration: Optional[List[int]] = None + last_frame: Optional[int] = None @property def non_blank_frame_confidence(self) -> List[float]: @@ -244,7 +247,8 @@ def __init__( Args: batch_size: batch size for hypotheses - init_length: initial estimate for the length of hypotheses (if the real length is higher, tensors will be reallocated) + init_length: initial estimate for the length of hypotheses (if the real length is higher, + tensors will be reallocated) device: device for storing hypotheses float_dtype: float type for scores """ @@ -274,6 +278,9 @@ def __init__( self._ones_batch = torch.ones_like(self._batch_indices) def clear_(self): + """ + Clears batched hypotheses state. + """ self.current_lengths.fill_(0) self.transcript.fill_(0) self.timesteps.fill_(0) @@ -497,6 +504,9 @@ def __init__( self._batch_indices = torch.arange(batch_size, device=device) def clear_(self): + """ + Clears batched hypotheses state. + """ self.current_lengths.fill_(0) self.timesteps.fill_(0) self.logits.fill_(0.0) diff --git a/tests/collections/asr/decoding/test_rnnt_decoding.py b/tests/collections/asr/decoding/test_rnnt_decoding.py index 82b5d00bede6..b5250ad5f144 100644 --- a/tests/collections/asr/decoding/test_rnnt_decoding.py +++ b/tests/collections/asr/decoding/test_rnnt_decoding.py @@ -22,8 +22,9 @@ from nemo.collections.asr.models import ASRModel from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint from nemo.collections.asr.parts.mixins import mixins -from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts.submodules import tdt_beam_decoding from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTBPEDecoding, RNNTDecoding, RNNTDecodingConfig from nemo.collections.asr.parts.utils import rnnt_utils from nemo.core.utils import numba_utils @@ -166,6 +167,39 @@ def check_subword_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTBPEDecodi assert len(hyp.timestep['segment']) == segments_count +def check_beam_decoding(test_data_dir, beam_config): + beam_size = beam_config.pop("beam_size", 1) + model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'nvidia/parakeet-tdt_ctc-110m') + + model_config = model.to_config_dict() + durations = list(model_config["model_defaults"]["tdt_durations"]) + + beam = tdt_beam_decoding.BeamTDTInfer( + model.decoder, + model.joint, + beam_size=beam_size, + return_best_hypothesis=False, + durations=durations, + **beam_config, + ) + + enc_out = encoded + enc_len = encoded_len + + with torch.no_grad(): + hyps: rnnt_utils.Hypothesis = beam(encoder_output=enc_out, encoded_lengths=enc_len)[0] + _, all_hyps = decode_text_from_nbest_hypotheses(hyps, model.decoding) + all_hyps = all_hyps[0] + + print("Beam search algorithm :", beam_config['search_type']) + for idx, hyp_ in enumerate(all_hyps): + print("Hyp index", idx + 1, "text :", hyp_.text) + + assert len(hyp_.timestep) > 0 + print("Timesteps", hyp_.timestep) + print() + + class TestRNNTDecoding: @pytest.mark.unit def test_constructor(self): @@ -312,10 +346,10 @@ def test_batched_greedy_decoding_preserve_alignments(self, test_data_dir, loop_l {"search_type": "maes", "maes_num_steps": 3, "maes_expansion_beta": 1, "beam_size": 2}, ], ) - def test_beam_decoding_preserve_alignments(self, test_data_dir, beam_config): + def test_rnnt_beam_decoding_preserve_alignments(self, test_data_dir, beam_config): beam_size = beam_config.pop("beam_size", 1) model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'stt_en_conformer_transducer_small') - beam = beam_decode.BeamRNNTInfer( + beam = rnnt_beam_decoding.BeamRNNTInfer( model.decoder, model.joint, beam_size=beam_size, @@ -442,3 +476,51 @@ def test_char_decoding_compute_timestamps(self, test_data_dir, decoding_strategy hyps, _ = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_hypotheses=True) check_char_timestamps(hyps[0], decoding) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + { + "search_type": "default", + "beam_size": 2, + }, + {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 2, "beam_size": 2}, + {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 1, "beam_size": 4}, + ], + ) + def test_tdt_beam_decoding(self, test_data_dir, beam_config): + check_beam_decoding(test_data_dir, beam_config) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + { + "search_type": "maes", + "maes_num_steps": 2, + "maes_expansion_beta": 1, + "beam_size": 4, + "ngram_lm_alpha": 0.3, + }, + ], + ) + def test_tdt_beam_decoding_with_kenlm(self, test_data_dir, beam_config): + # skipping if kenlm is not installed + pytest.importorskip("kenlm", reason="Skipping test because 'kenlm' is not installed.") + + kenlm_model_path = os.path.join( + test_data_dir, "asr", "kenlm_ngram_lm", "parakeet-tdt_ctc-110m-libri-1024.kenlm.tmp.arpa" + ) + beam_config["ngram_lm_model"] = kenlm_model_path + check_beam_decoding(test_data_dir, beam_config) From a9a959cf4d61677954940a765d2059e4835c5916 Mon Sep 17 00:00:00 2001 From: Huiying Date: Wed, 13 Nov 2024 12:06:56 -0800 Subject: [PATCH 17/24] update nemo1->2 conversion according to changes in main (#11253) * update nemo1->2 conversion according to changes in main Signed-off-by: Huiying Li * Apply isort and black reformatting Signed-off-by: HuiyingLi * format fix Signed-off-by: Huiying Li * add docstrings Signed-off-by: Huiying Li --------- Signed-off-by: Huiying Li Signed-off-by: HuiyingLi Co-authored-by: HuiyingLi --- .../convert_nemo1_to_nemo2.py | 43 +++++++++++++++---- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py b/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py index 1d69c1aec5eb..12e56e9f1793 100644 --- a/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py +++ b/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py @@ -23,7 +23,8 @@ --output_path=your_output_dir \ --model_id=meta-llama/Meta-Llama-3-8B -b. Convert a model weight directory. The checkpoint should be similar to `model_weights` subdir after extracting the .nemo file. +b. Convert a model weight directory. + The checkpoint should be similar to `model_weights` subdir after extracting the .nemo file. Please also provide tokenizer_library and tokenizer_path when loading from weight directory. python /opt/NeMo/scripts/checkpoint_converters/convert_nemo1_to_nemo2.py \ --input_path=nemotron3-8b-extracted/model_weights \ @@ -52,8 +53,8 @@ from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib -from nemo.lightning.ckpt_utils import ckpt_to_context_subdir, ckpt_to_weights_subdir -from nemo.lightning.io.pl import TrainerContext +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir from nemo.utils import logging MODEL_CONFIG_MAPPING = { @@ -66,22 +67,29 @@ "mistralai/Mixtral-8x22B-v0.1": (llm.MixtralModel, llm.MixtralConfig8x22B), "mistralai/Mistral-7B-v0.1": (llm.MistralModel, llm.MistralConfig7B), "nvidia/nemotron-3-8b-base-4k": (llm.NemotronModel, llm.Nemotron3Config8B), - "nemotron4-22b": (llm.NemotronModel, llm.Nemotron4Config22B), + "nemotron4-22b": (llm.NemotronModel, llm.Nemotron3Config22B), "nemotron4-15b": (llm.NemotronModel, llm.Nemotron4Config15B), "nemotron4-340b": (llm.NemotronModel, llm.Nemotron4Config340B), } def get_args(): + """ + Parse the command line arguments. + """ parser = ArgumentParser( - description="Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format. This script may download from Hugging Face, make sure you have access to gate repo and have logged into Hugging Face (e.g. huggingface-cli login)" + description="""Script to convert NeMo 1.0 checkpoints to NeMo 2.0 format. + This script may download from Hugging Face, make sure you have + access to gate repo and have logged into Hugging Face (e.g. huggingface-cli login)""" ) parser.add_argument( "--input_path", type=str, default=None, required=True, - help="Path to NeMo 1.0 checkpoints. Could be .nemo file, or `model_weights` directory after untar the .nemo. Please also provide tokenizer_library and tokenizer_path if you pass in `model_weights` directory.", + help="""Path to NeMo 1.0 checkpoints. Could be .nemo file, or `model_weights` directory a + fter untar the .nemo. Please also provide tokenizer_library and tokenizer_path if you pass + in `model_weights` directory.""", ) parser.add_argument( "--output_path", type=str, default=None, required=True, help="Path to output NeMo 2.0 directory." @@ -94,7 +102,8 @@ def get_args(): type=str, default=None, required=False, - help="Path to tokenizer. If not provided, will 1. try instantiate from nemo1 config 2. pull AutoTokenizer from Hugging Face according to model_id if 1 fails", + help="""Path to tokenizer. If not provided, will 1. try instantiate from nemo1 config + 2. pull AutoTokenizer from Hugging Face according to model_id if 1 fails""", ) parser.add_argument( "--tokenizer_library", @@ -108,6 +117,12 @@ def get_args(): def get_nemo2_model(model_id, tokenizer) -> llm.GPTModel: + """ + Get NeMo 2.0 model class from model_id and tokenizer. Use bf16 for NeMo 1.0 ckpts. + + Returns: + llm.GPTModel: NeMo 2.0 model instance + """ if model_id not in MODEL_CONFIG_MAPPING: valid_ids = "\n- ".join([""] + list(MODEL_CONFIG_MAPPING.keys())) @@ -118,6 +133,13 @@ def get_nemo2_model(model_id, tokenizer) -> llm.GPTModel: def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer: + """ + Get tokenizer from input .nemo file, or args.tokenizer_path, or Hugging Face. + Only SentencePiece and Hugging Face tokenizers are supported. + + Returns: + AutoTokenizer: tokenizer instance + """ if not input_path.is_dir(): # if .nemo tar with tempfile.TemporaryDirectory() as tmp_dir: # we want to clean up this tmp dir NLPSaveRestoreConnector._unpack_nemo_file(input_path, tmp_dir) @@ -134,7 +156,7 @@ def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer: tokenizer_lib = args.tokenizer_library or "sentencepiece" if args.tokenizer_library is None: logging.warning( - "You specified tokenizer_path but did not provide tokenizer_library, will default to sentencepiece" + "You specified tokenizer_path but did not provide tokenizer_library using default sentencepiece" ) tokenizer_model = args.tokenizer_path else: # no .nemo config, no tokenizer path specified, grab from HF, reload @@ -148,6 +170,9 @@ def get_tokenizer(input_path: Path, tokenizer_tmp_dir: Path) -> AutoTokenizer: def main() -> None: + """ + Main function to convert NeMo 1.0 checkpoint to NeMo 2.0 format. + """ tokenizer_tmp_dir = Path("/tmp/nemo_tokenizer") tokenizer_tmp_dir.mkdir(parents=True, exist_ok=True) tokenizer = get_tokenizer(Path(args.input_path), tokenizer_tmp_dir) @@ -196,7 +221,7 @@ def skip_fp8_load(x): logging.info(f"Saving checkpoint to {args.output_path}") model_ckpt['state_dict'] = {k.replace('model', 'module', 1): v for k, v in model_ckpt['state_dict'].items()} trainer.model.module.load_state_dict(model_ckpt['state_dict']) - trainer.save_checkpoint(ckpt_to_weights_subdir(args.output_path)) + trainer.save_checkpoint(ckpt_to_weights_subdir(args.output_path, is_saving=False)) if getattr(trainer.strategy, "async_save", False): trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True) From 3625d78ad53cd702e09a215dbe7c989a44fbdc61 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Wed, 13 Nov 2024 16:50:22 -0500 Subject: [PATCH 18/24] Add llama 3.1 recipes (#11273) * add llama 3.1 recipes Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * fix pylint Signed-off-by: Chen Cui * Fix llama3.1 wrong config in io.json --------- Signed-off-by: Chen Cui Signed-off-by: cuichenx Co-authored-by: cuichenx Co-authored-by: Ao Tang --- nemo/collections/llm/gpt/model/llama.py | 7 +- nemo/collections/llm/recipes/__init__.py | 4 + .../llm/recipes/finetune_default.py | 19 +- nemo/collections/llm/recipes/llama31_405b.py | 118 ++++- nemo/collections/llm/recipes/llama31_70b.py | 403 ++++++++++++++++++ nemo/collections/llm/recipes/llama31_8b.py | 385 +++++++++++++++++ nemo/collections/llm/recipes/llama3_70b.py | 7 +- nemo/collections/llm/recipes/llama3_8b.py | 8 +- tests/lightning/test_nemo_run.py | 5 + 9 files changed, 938 insertions(+), 18 deletions(-) create mode 100644 nemo/collections/llm/recipes/llama31_70b.py create mode 100644 nemo/collections/llm/recipes/llama31_8b.py diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index 7b235d59ee89..a9d18220bcaf 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -273,7 +273,12 @@ def make_vocab_size_divisible_by(vocab_size): base //= 2 return base - output = LlamaConfig( + if getattr(source, 'rope_scaling', None) is not None and source.rope_scaling.get('rope_type') == 'llama3': + # Apply Llama3.1 customize rope scaling + cls = Llama31Config + else: + cls = LlamaConfig + output = cls( num_layers=source.num_hidden_layers, hidden_size=source.hidden_size, ffn_hidden_size=source.intermediate_size, diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 9f53ec88bdc8..8f772e3da5b7 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -30,6 +30,8 @@ llama3_70b, llama3_70b_16k, llama3_70b_64k, + llama31_8b, + llama31_70b, llama31_405b, mamba2_1_3b, mamba2_2_7b, @@ -82,6 +84,8 @@ "llama3_70b", "llama3_70b_16k", "llama3_70b_64k", + "llama31_8b", + "llama31_70b", "llama31_405b", "mamba2_130m", "mamba2_370m", diff --git a/nemo/collections/llm/recipes/finetune_default.py b/nemo/collections/llm/recipes/finetune_default.py index 69266737edc9..a060046a8bdf 100644 --- a/nemo/collections/llm/recipes/finetune_default.py +++ b/nemo/collections/llm/recipes/finetune_default.py @@ -16,6 +16,7 @@ import nemo_run as run import pytorch_lightning as pl +import torch import nemo.lightning as nl from nemo.collections import llm @@ -82,7 +83,7 @@ def default_finetune_recipe( def default_finetune_trainer( tensor_parallelism=1, pipeline_parallelism=1, - pipeline_parallelism_type=None, + pipeline_parallelism_type=torch.bfloat16, virtual_pipeline_parallelism=None, context_parallelism=1, sequence_parallelism=False, @@ -93,6 +94,19 @@ def default_finetune_trainer( limit_val_batches=None, val_check_interval=30, ): + """ + Create a default fine-tuning trainer for any model. + + This function sets up a template for strategy and trainer. + + Args: + See docstrings of MegatronStrategy and Trainer. + + Returns: + run.Config: Config for a finetuning trainer. + + See usages of this in recipes for further details. + """ strategy = run.Config( nl.MegatronStrategy, tensor_model_parallel_size=tensor_parallelism, @@ -125,7 +139,8 @@ def default_finetune_trainer( def nemo_resume(model_id: str) -> run.Config[nl.AutoResume]: """ - Configure automatic resumption from a NeMo checkpoint converted from Huggingface for https://huggingface.co/{model_id}. + Configure automatic resumption from a NeMo checkpoint converted from Huggingface for + https://huggingface.co/{model_id}. This NeMo checkpoint should be converted from Huggingface beforehand, using nemo.collections.llm.import_ckpt. When converting the checkpoint, the NeMo checkpoint will be saved in NEMO_HOME (set to ~/.cache/nemo by default). diff --git a/nemo/collections/llm/recipes/llama31_405b.py b/nemo/collections/llm/recipes/llama31_405b.py index e753c48387c0..31c83713b6e7 100644 --- a/nemo/collections/llm/recipes/llama31_405b.py +++ b/nemo/collections/llm/recipes/llama31_405b.py @@ -24,6 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config405B, LlamaModel from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe @@ -33,6 +34,7 @@ from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, ) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback @@ -248,6 +250,9 @@ def finetune_recipe( num_nodes: int = 3, num_gpus_per_node: int = 8, peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, ) -> run.Partial: """ Create a fine-tuning recipe for Llama3.1 405B model. @@ -261,8 +266,11 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. Returns: run.Partial: Partial configuration for fine-tuning. @@ -279,22 +287,116 @@ def finetune_recipe( This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 405B model requires substantial computational resources. """ + if packed_sequence is None: + packed_sequence = performance_mode + + if seq_length is None: + seq_length = 2048 + + if num_nodes is None: + if peft_scheme is None or peft_scheme.lower() == 'none': + num_nodes = 12 + elif peft_scheme.lower() == 'lora': + num_nodes = 3 + recipe = default_finetune_recipe( - model(), "meta-llama/Meta-Llama-3.1-405B", dir, name, num_nodes, num_gpus_per_node + model(), "meta-llama/Llama-3.1-405B", dir, name, num_nodes, num_gpus_per_node, packed_sequence ) - if peft_scheme is None or peft_scheme.lower() == 'none': - assert num_nodes >= 4 recipe.trainer.strategy.tensor_model_parallel_size = 8 - recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 14 + recipe.data.global_batch_size = 6 recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.trainer.strategy.pipeline_model_parallel_size = 6 - recipe.trainer.strategy.virtual_pipeline_parallelism = 7 - recipe.data.global_batch_size = 128 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + recipe.data.global_batch_size = 6 recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + # Note: limited support. This is not necessarily the most optimized setting + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 14 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 6 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + return recipe diff --git a/nemo/collections/llm/recipes/llama31_70b.py b/nemo/collections/llm/recipes/llama31_70b.py new file mode 100644 index 000000000000..91e4e10c83e6 --- /dev/null +++ b/nemo/collections/llm/recipes/llama31_70b.py @@ -0,0 +1,403 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama31Config70B, LlamaModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, +) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama31_70b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.1 70B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.1 70B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama31_70b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama31Config70B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 4, + pipeline_parallelism: int = 4, + pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 5, + context_parallelism: int = 2, + sequence_parallelism: bool = True, + num_nodes: int = 4, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.1 70B model. + + This function sets up the distributed training strategy optimized for the large 70B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama31_70b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=4, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.1 70B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama31_70b + $ nemo llm pretrain --factory "llama31_70b(num_nodes=4, name='my_70b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama31_70b_pretrain", num_nodes=4) + >>> print(recipe) + + Note: + This recipe is optimized for the large 70B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + if performance_mode: + recipe = pretrain_performance_optimizations(recipe) + + return recipe + + +def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Llama3.1 70B model. + + This method enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=50, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + align_param_gather=True, + ) + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = None, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 70B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_70b + $ nemo llm finetune --factory "llama31_70b(num_nodes=4, name='my_70b_finetune')" + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_70b_finetune", num_nodes=4) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 70B model + requires substantial computational resources. + """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + if num_nodes is None: + if peft_scheme is None or peft_scheme.lower() == 'none': + num_nodes = 4 + elif peft_scheme.lower() == 'lora': + num_nodes = 1 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.1-70B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama31_8b.py b/nemo/collections/llm/recipes/llama31_8b.py new file mode 100644 index 000000000000..a4f0082e8535 --- /dev/null +++ b/nemo/collections/llm/recipes/llama31_8b.py @@ -0,0 +1,385 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama31Config8B, LlamaModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, +) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama31_8b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.1 8B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.1 8B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama31_8b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama31Config8B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 2, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.1 8B model. + + This function sets up the distributed training strategy optimized for the large 8B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama31_8b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.1 8B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama31_8b + $ nemo llm pretrain --factory "llama31_8b(num_nodes=4, name='my_8b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama31_8b_pretrain", num_nodes=4) + >>> print(recipe) + + Note: + This recipe is optimized for the large 8B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + if performance_mode: + recipe = pretrain_performance_optimizations(recipe) + + return recipe + + +def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Llama3.1 8B model. + + This method enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=50, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + align_param_gather=True, + ) + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 8B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_8b + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_8b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "meta-llama/Meta-Llama-3.1-8B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe.trainer.strategy.tensor_model_parallel_size = 1 + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ) + ) + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index e2156993647d..d43302a0a0ee 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -263,9 +263,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. performance_mode (bool): If true, enables optimizations for maximum performance. Returns: @@ -325,7 +326,7 @@ def finetune_recipe( recipe.model.config.seq_length = seq_length recipe.data.seq_length = seq_length if packed_sequence: - recipe.data.pad_to_max_length = True + recipe.data.dataset_kwargs = {'pad_to_max_length': True} recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) if performance_mode: diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 1030ad8799a1..4f6f6ce17443 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -25,7 +25,6 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe @@ -251,9 +250,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. performance_mode (bool): If true, enables optimizations for maximum performance. Returns: @@ -305,7 +305,7 @@ def finetune_recipe( recipe.model.config.seq_length = seq_length recipe.data.seq_length = seq_length if packed_sequence: - recipe.data.pad_to_max_length = True + recipe.data.dataset_kwargs = {'pad_to_max_length': True} recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) if performance_mode: diff --git a/tests/lightning/test_nemo_run.py b/tests/lightning/test_nemo_run.py index f91322116824..1371b9adaa8e 100644 --- a/tests/lightning/test_nemo_run.py +++ b/tests/lightning/test_nemo_run.py @@ -30,7 +30,12 @@ ("llama3_70b", "finetune_recipe", "llama3_70b_finetune"), ("llama3_70b_16k", "pretrain_recipe", "llama3_70b_16k_pretrain"), ("llama3_70b_64k", "pretrain_recipe", "llama3_70b_64k_pretrain"), + ("llama31_8b", "pretrain_recipe", "llama31_8b_pretrain"), + ("llama31_8b", "finetune_recipe", "llama31_8b_finetune"), + ("llama31_70b", "pretrain_recipe", "llama31_70b_pretrain"), + ("llama31_70b", "finetune_recipe", "llama31_70b_finetune"), ("llama31_405b", "pretrain_recipe", "llama31_405b_pretrain"), + ("llama31_405b", "finetune_recipe", "llama31_405b_finetune"), ("mistral_7b", "pretrain_recipe", "mistral_pretrain"), ("mistral_7b", "finetune_recipe", "mistral_finetune"), ("mixtral_8x7b", "pretrain_recipe", "mixtral_8x7b_pretrain"), From 071f8bc088c0390fbe4037f972f03341e9e191cd Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Wed, 13 Nov 2024 16:54:24 -0500 Subject: [PATCH 19/24] Fix Finetune Recipe (#11267) * Fix Starcoder_15 SFT recipe * Fix PP type SFT recipe * Fix PP type SFT recipe * Fix Gemma2b SFT TP=1 * Fix more sft recipe * Fix more sft recipe * Fix more sft recipe * Fix more sft recipe * Fix more sft recipe * Fix more sft recipe * Fix more sft recipe * Fix more sft recipe * Fix more sft recipe * remove pp dtype * remove pp dtype --- nemo/collections/llm/recipes/gemma_2b.py | 2 +- nemo/collections/llm/recipes/starcoder_15b.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/llm/recipes/gemma_2b.py b/nemo/collections/llm/recipes/gemma_2b.py index 3e54deb0bc1c..8b2111e9f7c4 100644 --- a/nemo/collections/llm/recipes/gemma_2b.py +++ b/nemo/collections/llm/recipes/gemma_2b.py @@ -282,7 +282,7 @@ def finetune_recipe( recipe.data.dataset_kwargs = {'add_bos': True} if peft_scheme is None or peft_scheme.lower() == 'none': - recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.context_parallel_size = 2 recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) diff --git a/nemo/collections/llm/recipes/starcoder_15b.py b/nemo/collections/llm/recipes/starcoder_15b.py index d4e76abe897e..cb0ba14df868 100644 --- a/nemo/collections/llm/recipes/starcoder_15b.py +++ b/nemo/collections/llm/recipes/starcoder_15b.py @@ -300,7 +300,7 @@ def finetune_recipe( """ recipe = default_finetune_recipe(model(), "bigcode/starcoder", dir, name, num_nodes, num_gpus_per_node) if peft_scheme is None or peft_scheme.lower() == 'none': - recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 8 recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) From 02f093230d41e9265ee6c8f257b4f3ec74595548 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 13 Nov 2024 14:21:40 -0800 Subject: [PATCH 20/24] Configure no restart validation loop in nl.Trainer (#11029) * Configure no restart validation loop in nl.Trainer Signed-off-by: Hemil Desai * fix Signed-off-by: Hemil Desai * Skip validation whenever restarting=True Signed-off-by: Hemil Desai * PR feedback Signed-off-by: Hemil Desai * Apply isort and black reformatting Signed-off-by: hemildesai --------- Signed-off-by: Hemil Desai Signed-off-by: hemildesai Co-authored-by: hemildesai --- nemo/collections/llm/api.py | 10 ++++++++- nemo/lightning/__init__.py | 3 ++- nemo/lightning/pytorch/trainer.py | 37 ++++++++++++++++++++++++++++++- 3 files changed, 47 insertions(+), 3 deletions(-) diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 13f25eb21087..fdceff5d959e 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -25,7 +25,14 @@ from typing_extensions import Annotated import nemo.lightning as nl -from nemo.lightning import AutoResume, NeMoLogger, OptimizerModule, Trainer, io +from nemo.lightning import ( + AutoResume, + NeMoLogger, + OptimizerModule, + Trainer, + configure_no_restart_validation_training_loop, + io, +) from nemo.lightning.base import NEMO_MODELS_CACHE from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform from nemo.utils import logging @@ -680,6 +687,7 @@ def _setup( tokenizer: Optional[TokenizerType], model_transform: Optional[Union[PEFT, ModelTransform, Callable]], ) -> Any: # Return type is Any because app_state's type is not specified + configure_no_restart_validation_training_loop(trainer) _log = log or NeMoLogger() if resume and isinstance(model_transform, PEFT) and _log.ckpt: logging.info("Disabling try_restore_best_ckpt restoration for adapters") diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index 2cc720e148d4..91d3b3f936d0 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -33,7 +33,7 @@ from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler from nemo.lightning.pytorch.strategies import FSDPStrategy, MegatronStrategy from nemo.lightning.pytorch.strategies.utils import RestoreConfig -from nemo.lightning.pytorch.trainer import Trainer +from nemo.lightning.pytorch.trainer import Trainer, configure_no_restart_validation_training_loop from nemo.lightning.resume import AutoResume @@ -66,6 +66,7 @@ def _is_slurm_interactive_mode(): "ModelCheckpoint", "OptimizerModule", "Trainer", + "configure_no_restart_validation_training_loop", "get_vocab_size", "teardown", ] diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py index 0d71c49bf198..c97c59ef524d 100644 --- a/nemo/lightning/pytorch/trainer.py +++ b/nemo/lightning/pytorch/trainer.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from copy import deepcopy import fiddle as fdl import pytorch_lightning as pl +from pytorch_lightning.loops import _TrainingEpochLoop +from pytorch_lightning.loops.fetchers import _DataFetcher from typing_extensions import Self from nemo.lightning.fabric.conversion import to_fabric @@ -23,8 +26,40 @@ from nemo.lightning.io.mixin import IOMixin, serialization, track_io -class Trainer(pl.Trainer, IOMixin): +class NoValOnRestartTrainingLoop(_TrainingEpochLoop): + """ + Extend the PTL Epoch loop to skip validation when restarting. + This happens when resuming a checkpoint that has already run validation, but loading restores + the training state before validation has run. + """ + + def _should_check_val_fx(self, data_fetcher) -> bool: + if self.skip_val_on_restart: + return False + return super()._should_check_val_fx(data_fetcher) + + def load_state_dict(self, state_dict: dict, prefix: str = "") -> None: + super().load_state_dict(state_dict, prefix) + + self.skip_val_on_restart = True + + def advance(self, data_fetcher: _DataFetcher) -> None: + super().advance(data_fetcher) + + self.skip_val_on_restart = False + +def configure_no_restart_validation_training_loop(trainer: pl.Trainer) -> None: + if not isinstance(trainer.fit_loop.epoch_loop, _TrainingEpochLoop): + warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning) + return + + ## Pass trainer object to avoid trainer getting overwritten as None + loop = NoValOnRestartTrainingLoop(trainer, trainer.min_steps, trainer.max_steps) + trainer.fit_loop.epoch_loop = loop + + +class Trainer(pl.Trainer, IOMixin): def add_io(self, obj): """Recurse to the leaves of a container and add io functionality to non-serializable leaves""" if isinstance(obj, (dict, list)): From af91d28e2680a9f36ce6bc0dd9cf3f809f770da8 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Wed, 13 Nov 2024 22:22:27 -0800 Subject: [PATCH 21/24] Handle _io_unflatten_object when _thread_local.output_dir is not available (#11199) Signed-off-by: Hemil Desai --- nemo/lightning/io/mixin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index 3613444b6330..33b7afdf1e76 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -608,7 +608,9 @@ def _io_flatten_object(instance): def _io_unflatten_object(values, metadata): - assert hasattr(_thread_local, "output_dir") + if not hasattr(_thread_local, "output_dir"): + return fdl.Config.__unflatten__(values, metadata) + output_dir = _thread_local.output_dir if len(values) == 1: From 8b0c31196046731f437f5e8fbb65ea79a15feafc Mon Sep 17 00:00:00 2001 From: Maanu Grover <109391026+maanug-nv@users.noreply.github.com> Date: Thu, 14 Nov 2024 00:33:14 -0800 Subject: [PATCH 22/24] change default ckpt name (#11277) Signed-off-by: Maanu Grover --- nemo/lightning/nemo_logger.py | 2 +- tests/collections/llm/bitexact/mixtral/run.sh | 2 +- tests/collections/llm/megatron_mixtral_pretraining.py | 2 +- tests/lightning/test_state_restoration.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index 8b10f9aca50a..a901a3a8842a 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -220,7 +220,7 @@ def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None): if callback.dirpath is None: callback.dirpath = Path(log_dir / "checkpoints") if callback.filename is None: - callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}" + callback.filename = f"{self.name}--{{{callback.monitor}:.4f}}-{{epoch}}-{{consumed_samples}}" ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + "-last" def _handle_task_config(self, task_config, log_dir): diff --git a/tests/collections/llm/bitexact/mixtral/run.sh b/tests/collections/llm/bitexact/mixtral/run.sh index 0fe9e331b18a..87bf7c382b99 100644 --- a/tests/collections/llm/bitexact/mixtral/run.sh +++ b/tests/collections/llm/bitexact/mixtral/run.sh @@ -43,4 +43,4 @@ python3 /workspace/tests/collections/llm/bitexact/mixtral/pretrain_mini_mixtral. # Compare outputs python3 /workspace/tests/collections/llm/bitexact/mixtral/compare_ckpts.py \ - "$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0/weights" "$MCORE_OUTPUT_PATH/iter_0000010/" + "$NEMO_OUTPUT_PATH/checkpoints/--None=0.0000-epoch=0-consumed_samples=20.0/weights" "$MCORE_OUTPUT_PATH/iter_0000010/" diff --git a/tests/collections/llm/megatron_mixtral_pretraining.py b/tests/collections/llm/megatron_mixtral_pretraining.py index b4c5b960e0a7..4123c7b37987 100644 --- a/tests/collections/llm/megatron_mixtral_pretraining.py +++ b/tests/collections/llm/megatron_mixtral_pretraining.py @@ -158,7 +158,7 @@ def main(args): ) # Confirm checkpoint directory structure - output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0/weights" + output_path = Path(args.experiment_dir) / "checkpoints/--None=0.0000-epoch=0-consumed_samples=8.0/weights" assert output_path.exists(), f"Expected {output_path} to exist" assert output_path.is_dir(), f"Expected {output_path} to be a directory" output_files = ['__0_0.distcp', '__0_1.distcp', 'common.pt', 'metadata.json', '.metadata'] diff --git a/tests/lightning/test_state_restoration.py b/tests/lightning/test_state_restoration.py index 44e0673a1a39..ccc0eed64d56 100644 --- a/tests/lightning/test_state_restoration.py +++ b/tests/lightning/test_state_restoration.py @@ -239,7 +239,7 @@ def run_resume_train(mbs, gbs, num_dev): resume=AutoResume( resume_if_exists=True, resume_ignore_no_checkpoint=False, - resume_from_path=f'{EXP_DIR}default/v1/checkpoints/default--None=0.0000-epoch=0/', + resume_from_path=f'{EXP_DIR}default/v1/checkpoints/default--None=0.0000-epoch=0-consumed_samples=20.0/', ), ) trainer._teardown() From bf7cc64913945ef5e4b02b1c526aa99a912a2153 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Date: Thu, 14 Nov 2024 02:45:17 -0800 Subject: [PATCH 23/24] Use MegatronDataSampler in HfDatasetDataModule (#11274) * Use MegatronDataSampler in HfDataset Signed-off-by: Alexandros Koumparoulis * Apply isort and black reformatting Signed-off-by: akoumpa --------- Signed-off-by: Alexandros Koumparoulis Signed-off-by: akoumpa Co-authored-by: akoumpa --- nemo/collections/llm/gpt/data/hf_dataset.py | 35 ++++++++++----------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 7e70a970913e..5c6b71c74797 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -15,6 +15,7 @@ import pytorch_lightning as pl import torch from torch.utils.data import DataLoader +from nemo.lightning.pytorch.plugins import MegatronDataSampler class HfDatasetDataModule(pl.LightningDataModule): @@ -24,6 +25,7 @@ def __init__( num_workers=2, pin_memory=True, persistent_workers=True, + seq_length=1024, micro_batch_size=2, global_batch_size=2, pad_token_id=0, @@ -37,6 +39,7 @@ def __init__( self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = persistent_workers + self.seq_length = seq_length self.micro_batch_size = micro_batch_size self.global_batch_size = global_batch_size self.pad_token_id = pad_token_id @@ -58,6 +61,7 @@ def pad_within_micro(batch, pad_token_id): max_len = max(map(len, batch)) return [item + [pad_token_id] * (max_len - len(item)) for item in batch] + keys = list(filter(lambda x: x in batch[0], ['tokens', 'labels', 'position_ids', 'loss_mask'])) return { key: batchify( torch.LongTensor( @@ -67,16 +71,26 @@ def pad_within_micro(batch, pad_token_id): ) ) ) - for key in ['tokens', 'labels'] + for key in keys } + def setup(self, stage: str): + if not self.use_mcore_sampler: + return + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + dataloader_type=self.mcore_dataloader_type, + ) + def train_dataloader(self, collate_fn=None): from nemo.lightning.data import add_megatron_sampler if collate_fn is None: collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) - dataloader = DataLoader( + return DataLoader( self.dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, @@ -84,20 +98,3 @@ def train_dataloader(self, collate_fn=None): collate_fn=collate_fn, batch_size=self.micro_batch_size, ) - if not self.use_mcore_sampler: - return dataloader - - rank = 0 - world_size = 1 - if torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - - return add_megatron_sampler( - dataloader, - self.micro_batch_size, - self.global_batch_size, - dataloader_type=self.mcore_dataloader_type, - rank=rank, - world_size=world_size, - ) From 062532770dbe790e73637dcd0926d964628cbaa5 Mon Sep 17 00:00:00 2001 From: Dong Hyuk Chang Date: Thu, 14 Nov 2024 10:16:41 -0500 Subject: [PATCH 24/24] Remove opencc upperbound (#10909) Signed-off-by: Dong Hyuk Chang --- requirements/requirements_nlp.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt index 16b6c574d2fa..6a86dacbfefb 100644 --- a/requirements/requirements_nlp.txt +++ b/requirements/requirements_nlp.txt @@ -14,7 +14,7 @@ matplotlib>=3.3.2 #megatron_core>0.6.0 # add back once mcore on pypi is compatible again nltk>=3.6.5 numpy<2 # tensorstore has an implicit compiled dependency on numpy<2 -opencc<1.1.7 +opencc pangu prettytable rapidfuzz