diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index 7b235d59ee89..a9d18220bcaf 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -273,7 +273,12 @@ def make_vocab_size_divisible_by(vocab_size): base //= 2 return base - output = LlamaConfig( + if getattr(source, 'rope_scaling', None) is not None and source.rope_scaling.get('rope_type') == 'llama3': + # Apply Llama3.1 custom rope scaling + cls = Llama31Config + else: + cls = LlamaConfig + output = cls( num_layers=source.num_hidden_layers, hidden_size=source.hidden_size, ffn_hidden_size=source.intermediate_size, diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 9f53ec88bdc8..8f772e3da5b7 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -30,6 +30,8 @@ llama3_70b, llama3_70b_16k, llama3_70b_64k, + llama31_8b, + llama31_70b, llama31_405b, mamba2_1_3b, mamba2_2_7b, @@ -82,6 +84,8 @@ "llama3_70b", "llama3_70b_16k", "llama3_70b_64k", + "llama31_8b", + "llama31_70b", "llama31_405b", "mamba2_130m", "mamba2_370m", diff --git a/nemo/collections/llm/recipes/finetune_default.py b/nemo/collections/llm/recipes/finetune_default.py index 69266737edc9..a060046a8bdf 100644 --- a/nemo/collections/llm/recipes/finetune_default.py +++ b/nemo/collections/llm/recipes/finetune_default.py @@ -16,6 +16,7 @@ import nemo_run as run import pytorch_lightning as pl +import torch import nemo.lightning as nl from nemo.collections import llm @@ -82,7 +83,7 @@ def default_finetune_recipe( def default_finetune_trainer( tensor_parallelism=1, pipeline_parallelism=1, - pipeline_parallelism_type=None, + pipeline_parallelism_type=torch.bfloat16, virtual_pipeline_parallelism=None, context_parallelism=1, sequence_parallelism=False, @@ -93,6 +94,19 @@ def default_finetune_trainer( limit_val_batches=None, val_check_interval=30, ): + """ + Create a default fine-tuning trainer for any model. + + This function sets up a template strategy and trainer that the fine-tuning recipes build upon. + + Args: + See docstrings of MegatronStrategy and Trainer. + + Returns: + run.Config: Config for a finetuning trainer. + + See usages of this in recipes for further details. + """ strategy = run.Config( nl.MegatronStrategy, tensor_model_parallel_size=tensor_parallelism, @@ -125,7 +139,8 @@ def default_finetune_trainer( def nemo_resume(model_id: str) -> run.Config[nl.AutoResume]: """ - Configure automatic resumption from a NeMo checkpoint converted from Huggingface for https://huggingface.co/{model_id}. + Configure automatic resumption from a NeMo checkpoint converted from Huggingface for + https://huggingface.co/{model_id}. This NeMo checkpoint should be converted from Huggingface beforehand, using nemo.collections.llm.import_ckpt. When converting the checkpoint, the NeMo checkpoint will be saved in NEMO_HOME (set to ~/.cache/nemo by default).
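For context, a minimal sketch (not part of the patch) of the config dispatch introduced in the llama.py hunk above. The SimpleNamespace stands in for a HuggingFace-style config object, and the two stub classes stand in for NeMo's LlamaConfig/Llama31Config, so the names here are illustrative only:

from types import SimpleNamespace

class LlamaConfig: ...                   # stub for NeMo's LlamaConfig
class Llama31Config(LlamaConfig): ...    # stub for the Llama3.1 config with llama3 rope scaling

def pick_config_cls(source):
    # Mirrors the new check: only HF configs reporting llama3-style rope scaling
    # are routed to the Llama3.1 config class; everything else keeps LlamaConfig.
    if getattr(source, 'rope_scaling', None) is not None and source.rope_scaling.get('rope_type') == 'llama3':
        return Llama31Config
    return LlamaConfig

assert pick_config_cls(SimpleNamespace(rope_scaling={'rope_type': 'llama3'})) is Llama31Config
assert pick_config_cls(SimpleNamespace(rope_scaling=None)) is LlamaConfig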
diff --git a/nemo/collections/llm/recipes/llama31_405b.py b/nemo/collections/llm/recipes/llama31_405b.py index e753c48387c0..31c83713b6e7 100644 --- a/nemo/collections/llm/recipes/llama31_405b.py +++ b/nemo/collections/llm/recipes/llama31_405b.py @@ -24,6 +24,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config405B, LlamaModel from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe @@ -33,6 +34,7 @@ from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, ) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback from nemo.utils.exp_manager import TimingCallback @@ -248,6 +250,9 @@ def finetune_recipe( num_nodes: int = 3, num_gpus_per_node: int = 8, peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, ) -> run.Partial: """ Create a fine-tuning recipe for Llama3.1 405B model. @@ -261,8 +266,11 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. Returns: run.Partial: Partial configuration for fine-tuning. @@ -279,22 +287,116 @@ def finetune_recipe( This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 405B model requires substantial computational resources. 
""" + if packed_sequence is None: + packed_sequence = performance_mode + + if seq_length is None: + seq_length = 2048 + + if num_nodes is None: + if peft_scheme is None or peft_scheme.lower() == 'none': + num_nodes = 12 + elif peft_scheme.lower() == 'lora': + num_nodes = 3 + recipe = default_finetune_recipe( - model(), "meta-llama/Meta-Llama-3.1-405B", dir, name, num_nodes, num_gpus_per_node + model(), "meta-llama/Llama-3.1-405B", dir, name, num_nodes, num_gpus_per_node, packed_sequence ) - if peft_scheme is None or peft_scheme.lower() == 'none': - assert num_nodes >= 4 recipe.trainer.strategy.tensor_model_parallel_size = 8 - recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 14 + recipe.data.global_batch_size = 6 recipe.optim.config.lr = 5e-6 elif peft_scheme.lower() == 'lora': recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.trainer.strategy.pipeline_model_parallel_size = 6 - recipe.trainer.strategy.virtual_pipeline_parallelism = 7 - recipe.data.global_batch_size = 128 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + recipe.data.global_batch_size = 6 recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + # Note: limited support. 
This is not necessarily the most optimized setting + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 14 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 6 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + return recipe diff --git a/nemo/collections/llm/recipes/llama31_70b.py b/nemo/collections/llm/recipes/llama31_70b.py new file mode 100644 index 000000000000..91e4e10c83e6 --- /dev/null +++ b/nemo/collections/llm/recipes/llama31_70b.py @@ -0,0 +1,403 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch +from megatron.core.distributed import DistributedDataParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama31Config70B, LlamaModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, +) +from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama31_70b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.1 70B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.1 70B model. 
+ + Examples: + CLI usage: + $ nemo llm pretrain model=llama31_70b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama31Config70B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 4, + pipeline_parallelism: int = 4, + pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 5, + context_parallelism: int = 2, + sequence_parallelism: bool = True, + num_nodes: int = 4, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.1 70B model. + + This function sets up the distributed training strategy optimized for the large 70B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama31_70b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=4, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.1 70B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. 
+ num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama31_70b + $ nemo llm pretrain --factory "llama31_70b(num_nodes=4, name='my_70b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama31_70b_pretrain", num_nodes=4) + >>> print(recipe) + + Note: + This recipe is optimized for the large 70B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + if performance_mode: + recipe = pretrain_performance_optimizations(recipe) + + return recipe + + +def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Llama3.1 70B model. + + This method enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=50, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + align_param_gather=True, + ) + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = None, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 70B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. 
+ name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_70b + $ nemo llm finetune --factory "llama31_70b(num_nodes=4, name='my_70b_finetune')" + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_70b_finetune", num_nodes=4) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. Be aware that fine-tuning a 70B model + requires substantial computational resources. + """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + if num_nodes is None: + if peft_scheme is None or peft_scheme.lower() == 'none': + num_nodes = 4 + elif peft_scheme.lower() == 'lora': + num_nodes = 1 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.1-70B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 16 + recipe.peft.alpha = 32 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.trainer.strategy.tensor_model_parallel_size = 8 + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. 
+ + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 4 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=22, + ) + ) + else: + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.trainer.strategy.pipeline_model_parallel_size = 4 + recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + + recipe.trainer.strategy.sequence_parallel = True + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama31_8b.py b/nemo/collections/llm/recipes/llama31_8b.py new file mode 100644 index 000000000000..a4f0082e8535 --- /dev/null +++ b/nemo/collections/llm/recipes/llama31_8b.py @@ -0,0 +1,385 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + + from typing import Callable, Optional + + import nemo_run as run + import pytorch_lightning as pl + import torch + from megatron.core.distributed import DistributedDataParallelConfig + from pytorch_lightning.callbacks.callback import Callback + + from nemo import lightning as nl + from nemo.collections.llm.api import finetune, pretrain + from nemo.collections.llm.gpt.data.mock import MockDataModule + from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs + from nemo.collections.llm.gpt.model.llama import Llama31Config8B, LlamaModel + from nemo.collections.llm.peft.lora import LoRA + from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe + from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger + from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing + from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed + from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import ( + userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + ) + from nemo.lightning.pytorch.callbacks import GarbageCollectionCallback + from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback + from nemo.utils.exp_manager import TimingCallback + + NAME = "llama31_8b" + + + @run.cli.factory(name=NAME) + def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.1 8B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.1 8B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama31_8b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama31Config8B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + + def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 2, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.1 8B model. + + This function sets up the distributed training strategy for the 8B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama31_8b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses modest parallelism defaults (TP=1, PP=1, CP=2) that are sufficient for the 8B model size.
+ """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.1 8B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama31_8b + $ nemo llm pretrain --factory "llama31_8b(num_nodes=4, name='my_8b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama31_8b_pretrain", num_nodes=4) + >>> print(recipe) + + Note: + This recipe is optimized for the large 8B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + if performance_mode: + recipe = pretrain_performance_optimizations(recipe) + + return recipe + + +def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Llama3.1 8B model. + + This method enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + recipe (run.Partial): Base pre-train recipe to which performance optimizations will be added + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Note: + Use this method with caution and only when you need maximum performance. 
+ It may not be suitable for all hardware configurations or use cases. + """ + + # 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically + # by MegatronCommOverlapCallback. They are added here for user's knowledge. + # overlap_param_gather_with_optimizer_step- Overlap param all-gather of first bucket with optimizer step. + # align_param_gather- If true, all PP stages launch param all-gathers simultaneously, else + # each PP stage launches independently as needed. + + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=True, + tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192, + defer_embedding_wgrad_compute=True, + wgrad_deferral_limit=50, + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + align_param_gather=True, + ) + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.1 8B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama31_8b + + Python API usage: + >>> recipe = finetune_recipe(name="llama31_8b_finetune", num_nodes=2) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. 
+ """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "meta-llama/Meta-Llama-3.1-8B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config(LoRA) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.peft.target_modules = ['linear_qkv'] + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + if performance_mode: + recipe = finetune_performance_optimizations(recipe, peft_scheme) + + return recipe + + +def finetune_performance_optimizations( + recipe: run.Partial, + peft_scheme: str, +) -> run.Partial: + """ + Modify the given recipe to optimize settings for performance. + + This method enables performance optimizations that may not be suitable for all use cases. + Intended to build upon the standard fine-tuning recipe. + + Args: + recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added + peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + + Returns: + run.Partial: Partial configuration for performance-optimized fine-tuning. + + Note: + Use this method with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe.trainer.strategy.tensor_model_parallel_size = 1 + + if not hasattr(recipe.trainer, "callbacks"): + recipe.trainer.callbacks = [] + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.plugins.grad_reduce_in_fp32 = False + recipe.trainer.strategy.ddp = run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=False, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ) + recipe.trainer.callbacks.append( + run.Config( + MegatronCommOverlapCallback, + tp_comm_overlap=False, + ) + ) + + recipe.trainer.callbacks.append(run.Config(TimingCallback)) + recipe.trainer.callbacks.append( + run.Config( + GarbageCollectionCallback, + 100, + 100, + ) + ) + + return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index e2156993647d..d43302a0a0ee 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -263,9 +263,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. 
- peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. performance_mode (bool): If true, enables optimizations for maximum performance. Returns: @@ -325,7 +326,7 @@ def finetune_recipe( recipe.model.config.seq_length = seq_length recipe.data.seq_length = seq_length if packed_sequence: - recipe.data.pad_to_max_length = True + recipe.data.dataset_kwargs = {'pad_to_max_length': True} recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) if performance_mode: diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 1030ad8799a1..4f6f6ce17443 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -25,7 +25,6 @@ from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe @@ -251,9 +250,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. performance_mode (bool): If true, enables optimizations for maximum performance. 
Returns: @@ -305,7 +305,7 @@ def finetune_recipe( recipe.model.config.seq_length = seq_length recipe.data.seq_length = seq_length if packed_sequence: - recipe.data.pad_to_max_length = True + recipe.data.dataset_kwargs = {'pad_to_max_length': True} recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) if performance_mode: diff --git a/tests/lightning/test_nemo_run.py b/tests/lightning/test_nemo_run.py index f91322116824..1371b9adaa8e 100644 --- a/tests/lightning/test_nemo_run.py +++ b/tests/lightning/test_nemo_run.py @@ -30,7 +30,12 @@ ("llama3_70b", "finetune_recipe", "llama3_70b_finetune"), ("llama3_70b_16k", "pretrain_recipe", "llama3_70b_16k_pretrain"), ("llama3_70b_64k", "pretrain_recipe", "llama3_70b_64k_pretrain"), + ("llama31_8b", "pretrain_recipe", "llama31_8b_pretrain"), + ("llama31_8b", "finetune_recipe", "llama31_8b_finetune"), + ("llama31_70b", "pretrain_recipe", "llama31_70b_pretrain"), + ("llama31_70b", "finetune_recipe", "llama31_70b_finetune"), ("llama31_405b", "pretrain_recipe", "llama31_405b_pretrain"), + ("llama31_405b", "finetune_recipe", "llama31_405b_finetune"), ("mistral_7b", "pretrain_recipe", "mistral_pretrain"), ("mistral_7b", "finetune_recipe", "mistral_finetune"), ("mixtral_8x7b", "pretrain_recipe", "mixtral_8x7b_pretrain"),
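For reference (not part of the patch), a usage sketch assembled from the docstrings above; run names, node counts, and the choice of recipes shown are illustrative:

from nemo.collections.llm.recipes import llama31_8b, llama31_70b

# Default LoRA fine-tuning recipe for the 8B model (unpacked data, seq_length=2048 by default).
recipe = llama31_8b.finetune_recipe(name="llama31_8b_finetune", num_nodes=1, peft_scheme="lora")

# performance_mode=True also enables packed sequences and applies finetune_performance_optimizations().
perf_recipe = llama31_70b.finetune_recipe(name="llama31_70b_finetune", performance_mode=True)
print(recipe)

# CLI equivalents documented in the recipes:
#   $ nemo llm finetune --factory llama31_8b
#   $ nemo llm pretrain --factory "llama31_70b(num_nodes=4, name='my_70b_pretrain')"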