From 1c265f9193ec017b92c47e880465482ca4189a8e Mon Sep 17 00:00:00 2001 From: Satya Ortiz-Gagne Date: Fri, 20 Sep 2024 11:24:57 -0400 Subject: [PATCH 1/7] Force exaclty one monitor tag --- config/base.yaml | 13 +++++-------- milabench/config.py | 9 +++++++++ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/config/base.yaml b/config/base.yaml index b06cbea58..d6a1dc5fb 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -27,7 +27,6 @@ _torchvision: --loader: pytorch --data: "{milabench_data}/FakeImageNet" - _torchvision_ddp: inherits: _defaults definition: ../benchmarks/torchvision_ddp @@ -119,7 +118,6 @@ _timm: --dataset: "FakeImageNet" --workers: "auto({n_worker}, 8)" - _accelerate_opt: inherits: _defaults tags: @@ -156,7 +154,6 @@ _accelerate_opt: use_deepspeed: true num_machines: 1 - fp16: inherits: _flops @@ -398,7 +395,6 @@ brax: --num-minibatches: 32 --num-envs: 8192 - _diffusion: inherits: _defaults definition: ../benchmarks/diffusion @@ -551,13 +547,13 @@ _llm: definition: ../benchmarks/llm install_group: torch - llm-lora-single: inherits: _llm tags: - monogpu plan: method: per_gpu + argv: "{milabench_code}/recipes/lora_finetune_single_device.py": true --config: "{milabench_code}/configs/llama3_8B_lora_single_device.yaml" @@ -619,7 +615,6 @@ llm-lora-ddp-nodes: requires_capabilities: - "len(nodes) >= ${num_machines}" - llm-lora-mp-gpus: inherits: _llm tags: @@ -793,7 +788,6 @@ _llava: method: per_gpu tags: - llm - - monogpu argv: --batch_size: 1 --num_workers: "auto({n_worker}, 4)" @@ -801,6 +795,8 @@ _llava: llava-single: inherits: _llava + tags: + - monogpu plan: method: per_gpu argv: @@ -828,7 +824,6 @@ _rlhf: plan: method: per_gpu tags: - - monogpu - rl - rlhf - llm @@ -842,6 +837,8 @@ _rlhf: rlhf-single: inherits: _rlhf + tags: + - monogpu plan: method: per_gpu diff --git a/milabench/config.py b/milabench/config.py index ebc041060..039a85cc4 100644 --- a/milabench/config.py +++ b/milabench/config.py @@ -11,6 +11,8 @@ config_global = contextvars.ContextVar("config", default=None) execution_count = (0, 0) +_MONITOR_TAGS = {"monogpu", "multigpu", "multinode"} + def set_run_count(total_run, total_bench): global execution_count @@ -80,6 +82,13 @@ def finalize_config(name, bench_config): pack = (XPath(bench_config["config_base"]) / pack).resolve() bench_config["definition"] = str(pack) + if not name.startswith("_") and name != "*": + _tags = set(bench_config["tags"]) + _monitor_tags = _tags & _MONITOR_TAGS + assert len(_monitor_tags) == 1, ( + f"Bench {name} should have exactly one monitor tag. 
Found {_monitor_tags}" + ) + bench_config["tag"] = [bench_config["name"]] bench_config = OmegaConf.to_object(OmegaConf.create(bench_config)) From fd99d8ab642519da5d4ff223c7a2de535f4d41bc Mon Sep 17 00:00:00 2001 From: Satya Ortiz-Gagne Date: Mon, 23 Sep 2024 13:48:48 -0400 Subject: [PATCH 2/7] Fix llm with torchtune v0.3 --- benchmarks/llm/configs/llama3_70B_full.yaml | 11 +- benchmarks/llm/configs/llama3_70B_lora.yaml | 10 +- benchmarks/llm/configs/llama3_8B_lora.yaml | 9 +- .../configs/llama3_8B_lora_single_device.yaml | 11 +- .../llm/configs/llama3_8B_qat_full.yaml | 9 +- .../llama3_8B_qlora_single_device.yaml | 14 +- .../llm/recipes/full_finetune_distributed.py | 447 +++++++++++------ .../recipes/full_finetune_single_device.py | 269 ++++++++--- benchmarks/llm/recipes/generate.py | 24 +- .../llm/recipes/lora_finetune_distributed.py | 391 +++++++++------ .../recipes/lora_finetune_single_device.py | 261 ++++++---- .../ppo_full_finetune_single_device.py | 134 +++--- benchmarks/llm/recipes/qat_distributed.py | 448 ++++++++++++------ benchmate/benchmate/monitor.py | 1 - 14 files changed, 1328 insertions(+), 711 deletions(-) diff --git a/benchmarks/llm/configs/llama3_70B_full.yaml b/benchmarks/llm/configs/llama3_70B_full.yaml index ae5cf2afb..2cfe8ec92 100644 --- a/benchmarks/llm/configs/llama3_70B_full.yaml +++ b/benchmarks/llm/configs/llama3_70B_full.yaml @@ -20,6 +20,7 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model + max_seq_len: null # Dataset dataset: @@ -33,7 +34,7 @@ model: safetensors: true checkpointer: - _component_: torchtune.utils.FullModelHFCheckpointer + _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ model-00001-of-00030.safetensors, @@ -85,7 +86,7 @@ optimizer: fused: True loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 @@ -95,7 +96,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -memory_efficient_fsdp_wrap: True +custom_sharded_layers: ['tok_embeddings', 'output'] fsdp_cpu_offload: True # Reduced precision @@ -103,8 +104,8 @@ dtype: bf16 # Logging metric_logger: - _component_: torchtune.utils.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} -output_dir: /tmp/alpaca-llama3-finetune +output_dir: /tmp/full-llama3_1-finetune log_every_n_steps: 1 log_peak_memory_stats: False diff --git a/benchmarks/llm/configs/llama3_70B_lora.yaml b/benchmarks/llm/configs/llama3_70B_lora.yaml index 7821e174a..9a8f1680a 100644 --- a/benchmarks/llm/configs/llama3_70B_lora.yaml +++ b/benchmarks/llm/configs/llama3_70B_lora.yaml @@ -16,14 +16,16 @@ model: apply_lora_to_output: False lora_rank: 16 lora_alpha: 32 + lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model + max_seq_len: null safetensors: true checkpointer: - _component_: torchtune.utils.FullModelHFCheckpointer + _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ model-00001-of-00030.safetensors, @@ -80,7 +82,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 @@ -88,9 +90,9 @@ max_steps_per_epoch: 
null gradient_accumulation_steps: 1 # Logging -output_dir: /tmp/lora_finetune_output +output_dir: /tmp/lora-llama3_1-finetune-output metric_logger: - _component_: torchtune.utils.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False diff --git a/benchmarks/llm/configs/llama3_8B_lora.yaml b/benchmarks/llm/configs/llama3_8B_lora.yaml index 7bae8d036..f499b712c 100644 --- a/benchmarks/llm/configs/llama3_8B_lora.yaml +++ b/benchmarks/llm/configs/llama3_8B_lora.yaml @@ -21,6 +21,7 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + max_seq_len: null # Model Arguments model: @@ -30,9 +31,10 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 checkpointer: - _component_: torchtune.utils.FullModelMetaCheckpointer + _component_: torchtune.training.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ checkpoint_files: [ consolidated.00.pth @@ -41,6 +43,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: @@ -59,7 +62,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 @@ -69,7 +72,7 @@ gradient_accumulation_steps: 32 # Logging output_dir: /tmp/lora_finetune_output metric_logger: - _component_: torchtune.utils.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False diff --git a/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml b/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml index b341f5afc..f5d8e3efa 100644 --- a/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml +++ b/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml @@ -24,14 +24,16 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + max_seq_len: null checkpointer: - _component_: torchtune.utils.FullModelMetaCheckpointer + _component_: torchtune.training.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ checkpoint_files: [ consolidated.00.pth @@ -40,6 +42,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: @@ -58,7 +61,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 @@ -69,7 +72,7 @@ compile: False # Logging output_dir: /tmp/lora_finetune_output metric_logger: - _component_: torchtune.utils.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False @@ -81,7 +84,7 @@ enable_activation_checkpointing: True # Profiler (disabled) profiler: - _component_: torchtune.utils.setup_torch_profiler + _component_: torchtune.training.setup_torch_profiler enabled: False #Output directory of trace artifacts diff --git a/benchmarks/llm/configs/llama3_8B_qat_full.yaml 
b/benchmarks/llm/configs/llama3_8B_qat_full.yaml index 23f60f779..c9d99f98a 100644 --- a/benchmarks/llm/configs/llama3_8B_qat_full.yaml +++ b/benchmarks/llm/configs/llama3_8B_qat_full.yaml @@ -17,6 +17,7 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + max_seq_len: null # Dataset dataset: @@ -29,7 +30,7 @@ model: _component_: torchtune.models.llama3_1.llama3_1_8b checkpointer: - _component_: torchtune.utils.FullModelMetaCheckpointer + _component_: torchtune.training.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ checkpoint_files: [ consolidated.00.pth @@ -45,7 +46,7 @@ epochs: 3 # QAT arguments quantizer: - _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer + _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer groupsize: 256 optimizer: @@ -54,7 +55,7 @@ optimizer: foreach: False loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 @@ -70,7 +71,7 @@ dtype: bf16 # Logging metric_logger: - _component_: torchtune.utils.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} output_dir: /tmp/alpaca-llama3-finetune log_every_n_steps: 1 diff --git a/benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml b/benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml index fc30f458f..1f50aa9d4 100644 --- a/benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml +++ b/benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml @@ -23,14 +23,16 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 + lora_dropout: 0.0 # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model + max_seq_len: null checkpointer: - _component_: torchtune.utils.FullModelMetaCheckpointer + _component_: torchtune.training.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ checkpoint_files: [ consolidated.00.pth @@ -39,6 +41,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: @@ -57,7 +60,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 @@ -68,7 +71,7 @@ compile: False # Logging output_dir: /tmp/qlora_finetune_output/ metric_logger: - _component_: torchtune.utils.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False @@ -80,7 +83,7 @@ enable_activation_checkpointing: True # Profiler (disabled) profiler: - _component_: torchtune.utils.setup_torch_profiler + _component_: torchtune.training.setup_torch_profiler enabled: False #Output directory of trace artifacts @@ -102,3 +105,6 @@ profiler: warmup_steps: 5 active_steps: 2 num_cycles: 1 + +# For colab use True +low_cpu_ram: False diff --git a/benchmarks/llm/recipes/full_finetune_distributed.py b/benchmarks/llm/recipes/full_finetune_distributed.py index 3a51842da..a46ff0a91 100755 --- a/benchmarks/llm/recipes/full_finetune_distributed.py +++ b/benchmarks/llm/recipes/full_finetune_distributed.py @@ -10,32 +10,26 @@ import time from functools import partial -from typing import Any, Dict, 
Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from warnings import warn import torch from omegaconf import DictConfig, ListConfig from torch import nn -from torch.distributed import init_process_group -from torch.distributed.fsdp import ( - CPUOffload, - FullOptimStateDictConfig, - FullStateDictConfig, - FullyShardedDataParallel as FSDP, - StateDictType, -) +from torch.distributed import destroy_process_group, init_process_group + from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler - -from torchtune import config, modules, utils +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.utils.activations import apply_selective_activation_checkpointing +from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training.activations import apply_selective_activation_checkpointing from tqdm import tqdm - log = utils.get_logger("DEBUG") @@ -45,8 +39,11 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): distributed training and can be run on a single node (1 to 8 GPUs). Features: - - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU - is not supported. + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep @@ -94,12 +91,12 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): Raises: ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. """ def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=cfg.device) - self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) if self._dtype == torch.float16: raise ValueError( @@ -122,7 +119,7 @@ def __init__(self, cfg: DictConfig) -> None: # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this - _, rank = utils.get_world_size_and_rank() + _, rank = training.get_world_size_and_rank() self._is_rank_zero = rank == 0 # Training cfg @@ -131,7 +128,7 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=cfg.seed) + self.seed = training.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch @@ -157,28 +154,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. 
""" try: - self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[utils.SEED_KEY]: + if self.seed != ckpt_dict[training.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" ) ) - self.seed = ckpt_dict[utils.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -194,8 +191,8 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: def setup(self, cfg: DictConfig) -> None: """ - Sets up the recipe state correctly. This includes setting recipe attributes based - on the ``resume_from_checkpoint`` flag. + Setup the recipe. This includes training state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, sampler, and dataloader. """ if self._is_rank_zero: self._metric_logger = config.instantiate(cfg.metric_logger) @@ -203,34 +200,41 @@ def setup(self, cfg: DictConfig) -> None: # log config with parameter override self._metric_logger.log_config(cfg) - ckpt_dict = self.load_checkpoint(cfg.checkpointer) + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - # ``_setup_model`` handles initialization and loading the state dict. This method - # should be called before ``_setup_optimizer`` since transforming the optimizer - # state dict requires the model + self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False), + custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), - model_state_dict=ckpt_dict[utils.MODEL_KEY], + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=checkpoint_dict[training.MODEL_KEY], ac_mode=cfg.get("ac_mode", None), ac_option=cfg.get("ac_option", None), ) - self._tokenizer = config.instantiate(cfg.tokenizer) - # _setup_optimizer should take in ckpt_dict only if training is resumed from - # checkpoint. 
Transforming the opt state dict is handled by this method self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=ckpt_dict[utils.OPT_KEY] + opt_state_dict=checkpoint_dict[training.OPT_KEY] if self._resume_from_checkpoint else None, ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + + if self._is_rank_zero: + log.info("Loss is initialized.") + # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized self._sampler, self._dataloader = self._setup_data( @@ -256,49 +260,109 @@ def setup(self, cfg: DictConfig) -> None: self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. 
code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, - memory_efficient_fsdp_wrap: bool, + custom_sharded_layers: Optional[List[str]], fsdp_cpu_offload: bool, + reshard_after_forward: bool, model_state_dict: Dict[str, Any], ac_mode: Optional[str] = None, ac_option: Optional[int] = None, ) -> nn.Module: """ Model initialization has some important considerations: - a. To minimize GPU peak memory, we load the model on CPU with the right - dtype. To ensure that we don't instantiate ``world_size`` number of models, - we initialize on meta_device for all ranks other than rank 0. - b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the - model weights from checkpoint. - c. While wrapping the model with FSDP, we set ``sync_module_states`` - to TRUE and broadcast module params and buffers from rank 0. - d. The ``device_id`` param ensures that the FSDP initialization happens on - the correct device. + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` """ - if self._is_rank_zero: - log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") - init_start = time.perf_counter() - - with utils.set_default_dtype(self._dtype): - model = config.instantiate(cfg_model) + if self._is_rank_zero: log.info( - f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." ) + init_start = time.perf_counter() - # Load both the model weights. 
This should happen only on Rank 0 - model.load_state_dict(model_state_dict) - - else: - # For non-zero ranks, load the model on meta device - with utils.set_default_dtype(self._dtype), torch.device("meta"): - model = config.instantiate(cfg_model) - - if self._dtype == torch.bfloat16: - model = model.to(torch.bfloat16) + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) # We currently have two versions of activation checkpointing in this recipe # for testing and BC purposes. ``enable_activation_checkpointing`` controls @@ -306,9 +370,6 @@ def _setup_model( # ac_mode and ac_option together control selective AC. This is only enabled # when these are set AND ``enable_activation_checkpointing`` is set to False # We'll clean this up as soon as testing of AC is complete - ac_mode = ac_mode - ac_option = ac_option - if (not enable_activation_checkpointing) and (ac_mode is not None): apply_selective_activation_checkpointing( model, @@ -316,43 +377,68 @@ def _setup_model( ac_option, ) - # Wrap the model with FSDP. This will ensure that the model is sharded - # across all available GPUs. - model = FSDP( - module=model, - auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy( - memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap, - modules_to_wrap={modules.TransformerDecoderLayer}, - ), - cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload), - sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, - device_id=self._device, - # this recipe does not currently support mixed precision training - mixed_precision=None, - # Ensure we broadcast params and buffers from rank 0 - sync_module_states=True, - # Initialize empty modules on all non-zero ranks - param_init_fn=( - lambda module: module.to_empty( - device=torch.device("cuda"), recurse=False - ) - if not self._is_rank_zero - else None - ), - ) - - # Ensure no params and buffers are on meta device - utils.validate_no_params_on_meta_device(model) - # original activation checkpointing (full) - flip the condition above if enable_activation_checkpointing and ac_mode is None: - utils.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerDecoderLayer} + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) + # For FSDP sharding, we can condition on either the module or its name + # Shard conditions should be callables taking name (relative to model root) + # and the module itself and returning a bool on whether to shard the given module + fsdp_shard_conditions = [] + + # Shard transformer decoder layers (or AC-wrapped versions) + # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) + # But directly using the name is more concise + def _is_layer_fqn(s: str) -> bool: + """ + Return True for layers.i and False for all other module names + Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot + """ + s_list = s.split(".") + return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1]) + + fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)] + + # If wrapping any layers separately, we can add another shard condition + # A layer will be sharded if any of the fsdp_shard_conditions are met + if custom_sharded_layers: + fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers] + + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + 
reshard_after_forward=reshard_after_forward, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + self._is_rank_zero, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + if self._is_rank_zero: - memory_stats = utils.get_memory_stats(device=self._device) - utils.log_memory_stats(memory_stats) + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) # synchronize before training begins torch.distributed.barrier() @@ -362,17 +448,13 @@ def _setup_model( def _setup_optimizer( self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None ) -> Optimizer: - """ - Set up the optimizer. This method also handles transforing the state dict - for FSDP. - """ optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) - if opt_state_dict: - opt_state_dict = FSDP.optim_state_dict_to_load( - self._model, optimizer, opt_state_dict + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, ) - optimizer.load_state_dict(opt_state_dict) if self._is_rank_zero: log.info("Optimizer is initialized.") @@ -389,7 +471,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = utils.get_world_size_and_rank() + world_size, rank = training.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -403,23 +485,21 @@ def _setup_data( packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( - ds, - num_replicas=world_size, - rank=rank, - shuffle=shuffle, - seed=0, + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 ) dataloader = DataLoader( dataset=ds, batch_size=batch_size, sampler=sampler, collate_fn=partial( - utils.padded_collate, + padded_collate_sft, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else None, + else partial( + padded_collate_packed, + ), ) if self._is_rank_zero: @@ -427,57 +507,74 @@ def _setup_data( return sampler, dataloader - def save_checkpoint(self, epoch: int) -> None: + def save_checkpoint( + self, + epoch: int, + ) -> None: """ - Save state dict to file. The recipe save_checkpoint method is responsible for - correctly creating the checkpoint dict and passing to the checkpointer. + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Model weights with key training.MODEL_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the model weights and recipe state in + different checkpoint files. To correctly resume training from an intermediate checkpoint, + the model weights and recipe state must be provided. 
""" + # final dict passed onto the checkpointer checkpoint_dict = {} + intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - with FSDP.state_dict_type( + cpu_state_dict = training.get_full_model_state_dict( self._model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), - ): - cpu_state_dict = self._model.state_dict() - opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) + self._is_rank_zero, + device=self._device, + ) + + if intermediate_checkpoint: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = None # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file if self._is_rank_zero: - checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict}) + checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) - # if training is in-progress, checkpoint the optimizer state as well - if epoch + 1 < self.total_epochs: + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. + if intermediate_checkpoint: checkpoint_dict.update( { - utils.OPT_KEY: opt_state_dict, - utils.SEED_KEY: self.seed, - utils.EPOCHS_KEY: self.epochs_run, - utils.TOTAL_EPOCHS_KEY: self.total_epochs, - utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) self._checkpointer.save_checkpoint( checkpoint_dict, epoch=epoch, - intermediate_checkpoint=(epoch + 1 < self.total_epochs), + intermediate_checkpoint=intermediate_checkpoint, ) def train(self) -> None: """ - The core training loop. Supports training on subsets of the dataset using the - ``max_steps_per_epoch``. + The core training loop. 
""" # clean up before training begins - utils.cleanup_before_training() + training.cleanup_before_training() - _, rank = utils.get_world_size_and_rank() + _, rank = training.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() @@ -487,6 +584,7 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 + self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -503,6 +601,15 @@ def train(self) -> None: ): break + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] # Get the attention mask and position ids from the dataset if they @@ -519,13 +626,25 @@ def train(self) -> None: ) logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) - running_loss += loss / self._gradient_accumulation_steps + + # free logits otherwise it peaks backward memory + del logits + + loss = loss / self._gradient_accumulation_steps + running_loss += loss loss.backward() # Step with optimizer @@ -540,7 +659,7 @@ def train(self) -> None: self.log_loss(loss_to_log) pbar.update(1) pbar.set_description( - f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -555,7 +674,9 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if self._log_peak_memory_stats: - log_dict.update(utils.get_memory_stats(device=self._device)) + log_dict.update( + training.get_memory_stats(device=self._device) + ) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -566,18 +687,38 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) + self._profiler.stop() + def cleanup(self) -> None: if self._is_rank_zero: self._metric_logger.close() - torch.distributed.destroy_process_group() + destroy_process_group() + def log_loss(self, loss): pass -def prepare_voir(recipe): +def prepare_voir(recipe:FullFinetuneRecipeDistributed): from benchmate.observer import BenchObserver from benchmate.monitor 
import bench_monitor @@ -602,7 +743,6 @@ def on_loss(loss): return observer, bench_monitor - @config.parse def recipe_main(cfg: DictConfig) -> None: """ @@ -612,22 +752,22 @@ def recipe_main(cfg: DictConfig) -> None: - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ - if not utils.is_distributed(): + if not training.is_distributed(): raise RuntimeError( "Distributed finetune recipe should be run via a distributed launcher." "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) - init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") if cfg.get("fsdp_cpu_offload", False): # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU - utils.set_torch_num_threads() + training.set_torch_num_threads() config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg) recipe = FullFinetuneRecipeDistributed(cfg=cfg) recipe.setup(cfg=cfg) + from voir.phase import StopProgram try: _, monitor = prepare_voir(recipe) @@ -635,6 +775,7 @@ def recipe_main(cfg: DictConfig) -> None: recipe.train() except StopProgram: print("early stopping") + recipe.cleanup() diff --git a/benchmarks/llm/recipes/full_finetune_single_device.py b/benchmarks/llm/recipes/full_finetune_single_device.py index 98322579f..f4d0df7cf 100755 --- a/benchmarks/llm/recipes/full_finetune_single_device.py +++ b/benchmarks/llm/recipes/full_finetune_single_device.py @@ -6,11 +6,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os import sys import time from functools import partial -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union from warnings import warn import torch @@ -20,9 +19,11 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, utils +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -84,6 +85,10 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config has example commands for how to kick-off training. @@ -98,7 +103,7 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) - self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor # enabled necessary features such as gradient scaling. 
if self._dtype == torch.float16: @@ -126,11 +131,12 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=cfg.seed) + self.seed = training.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 + self._clip_grad_norm = cfg.get("clip_grad_norm", None) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -152,28 +158,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. """ try: - self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[utils.SEED_KEY]: + if self.seed != ckpt_dict[training.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" ) ) - self.seed = ckpt_dict[utils.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -202,12 +208,12 @@ def setup(self, cfg: DictConfig) -> None: # ``_setup_model`` handles initialization and loading the state dict. 
This method # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model - self._model_compile = cfg.compile + self._compile = cfg.compile self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - compile_model=self._model_compile, - model_state_dict=ckpt_dict[utils.MODEL_KEY], + compile_model=self._compile, + model_state_dict=ckpt_dict[training.MODEL_KEY], ) self._tokenizer = config.instantiate(cfg.tokenizer) log.info("Tokenizer is initialized from file.") @@ -218,11 +224,20 @@ def setup(self, cfg: DictConfig) -> None: cfg_optimizer=cfg.optimizer, optimizer_in_bwd=cfg.optimizer_in_bwd, opt_state_dict=( - ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + ckpt_dict[training.OPT_KEY] if self._resume_from_checkpoint else None ), ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) + + if self._compile: + training.compile_loss(self._loss_fn) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be @@ -250,6 +265,82 @@ def setup(self, cfg: DictConfig) -> None: self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. 
code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + def _setup_model( self, cfg_model: DictConfig, @@ -260,28 +351,28 @@ def _setup_model( """ Set up the model including enabling activation checkpointing. """ - with utils.set_default_dtype(self._dtype), self._device: + with training.set_default_dtype(self._dtype), self._device: model = config.instantiate(cfg_model) + if compile_model: + training.compile_model(model) + if enable_activation_checkpointing: - utils.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerDecoderLayer} + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) model.load_state_dict(model_state_dict) # Validate model was loaded in with the expected dtype. - utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) + training.validate_expected_param_dtype( + model.named_parameters(), dtype=self._dtype + ) log.info(f"Model is initialized with precision {self._dtype}.") - # Compile model, if enabled. - if compile_model: - log.info("Compiling model with torch.compile...") - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - model.compile(backend=backend) if self._device.type == "cuda": - memory_stats = utils.get_memory_stats(device=self._device) - utils.log_memory_stats(memory_stats) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) return model @@ -301,9 +392,11 @@ def _setup_optimizer( for p in self._model.parameters() } # Register optimizer step hooks on the model to run optimizer in backward. - utils.register_optim_in_bwd_hooks(model=self._model, optim_dict=optim_dict) + training.register_optim_in_bwd_hooks( + model=self._model, optim_dict=optim_dict + ) # Create a wrapper for checkpoint save/load of optimizer states when running in backward. - self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( model=self._model, optim_dict=optim_dict ) # Load optimizer states. 
If optimizer states are being restored in an optimizer in backward @@ -340,13 +433,13 @@ def _setup_data( """ if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( @@ -361,12 +454,14 @@ def _setup_data( batch_size=batch_size, sampler=sampler, collate_fn=partial( - utils.padded_collate, + padded_collate_sft, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else None, + else partial( + padded_collate_packed, + ), ) log.info("Dataset and Sampler are initialized.") @@ -378,33 +473,60 @@ def save_checkpoint(self, epoch: int) -> None: Save state dict to file. The recipe save_checkpoint method is responsible for correctly creating the checkpoint dict and passing to the checkpointer. """ - ckpt_dict = {utils.MODEL_KEY: self._model.state_dict()} + ckpt_dict = {training.MODEL_KEY: self._model.state_dict()} # if training is in-progress, checkpoint the optimizer state as well if epoch + 1 < self.total_epochs: ckpt_dict.update( { - utils.SEED_KEY: self.seed, - utils.EPOCHS_KEY: self.epochs_run, - utils.TOTAL_EPOCHS_KEY: self.total_epochs, - utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) if not self._optimizer_in_bwd: - ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() + ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict() else: - ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() + ckpt_dict[training.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() self._checkpointer.save_checkpoint( ckpt_dict, epoch=epoch, intermediate_checkpoint=(epoch + 1 < self.total_epochs), ) + def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + logits = self._model(tokens, mask=mask, input_pos=input_pos) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits + + return loss + def train(self) -> None: """ The core training loop. Supports training on subsets of the dataset using the ``max_steps_per_epoch``. """ - if self._model_compile: + if self._compile: log.info( "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." 
) @@ -417,6 +539,7 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 + self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): # Update the sampler to ensure data is correctly shuffled across epochs @@ -432,35 +555,29 @@ def train(self) -> None: ): break - # Both are shape [b, s] - tokens, labels = batch["tokens"], batch["labels"] - # Get the attention mask and position ids from the dataset if they - # exist. Currently, only sample packing in PackedDataset returns these - mask = batch.get("mask", None) # shape [b, s, s] - input_pos = batch.get("input_pos", None) # shape [b, s] - - tokens = tokens.to(self._device) - num_tokens += tokens.numel() - labels = labels.to(self._device) - mask = mask.to(self._device) if mask is not None else None - input_pos = ( - input_pos.to(self._device) if input_pos is not None else None - ) + # Start tracking CUDA memory for active steps for just the first epoch + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() - logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) - # Compute loss - loss = self._loss_fn(logits, labels) + batch = {k: v.to(self._device) for k, v in batch.items()} + num_tokens += batch["tokens"].numel() + loss = self._loss_step(batch) loss = loss / self._gradient_accumulation_steps running_loss += loss loss.backward() # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) if not self._optimizer_in_bwd: self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) @@ -470,7 +587,7 @@ def train(self) -> None: loss_to_log = running_loss.item() pbar.update(1) pbar.set_description( - f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -488,7 +605,11 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if self._device.type == "cuda" and self._log_peak_memory_stats: - log_dict.update(utils.get_memory_stats(device=self._device)) + log_dict.update( + training.get_memory_stats(device=self._device) + ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -499,9 +620,27 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() + # Stop tracking CUDA memory now that active steps are complete + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler + # Note we are stepping each batch, which might not include optimizer step in the trace + # if the schedule cycle doesn't align with gradient accumulation. 
+ self._profiler.step() + self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) + self._profiler.stop() + def cleanup(self) -> None: self._metric_logger.close() diff --git a/benchmarks/llm/recipes/generate.py b/benchmarks/llm/recipes/generate.py index 883a75444..7334d81b0 100755 --- a/benchmarks/llm/recipes/generate.py +++ b/benchmarks/llm/recipes/generate.py @@ -14,7 +14,7 @@ from omegaconf import DictConfig from torch import nn -from torchtune import config, utils +from torchtune import config, generation, training, utils from torchtune.config._utils import _get_component_from_path from torchtune.data import ChatFormat, InstructTemplate, Message @@ -38,11 +38,11 @@ class InferenceRecipe: def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) - self._dtype = utils.get_dtype(dtype=cfg.dtype) + self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device) self._quantizer = config.instantiate(cfg.quantizer) - self._quantization_mode = utils.get_quantizer_mode(self._quantizer) + self._quantization_mode = training.get_quantizer_mode(self._quantizer) - utils.set_seed(seed=cfg.seed) + training.set_seed(seed=cfg.seed) def setup(self, cfg: DictConfig) -> None: checkpointer = config.instantiate(cfg.checkpointer) @@ -56,7 +56,7 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( model_cfg=cfg.model, - model_state_dict=ckpt_dict[utils.MODEL_KEY], + model_state_dict=ckpt_dict[training.MODEL_KEY], enable_kv_cache=cfg.enable_kv_cache, ) self._tokenizer = config.instantiate(cfg.tokenizer) @@ -67,7 +67,7 @@ def _setup_model( model_state_dict: Dict[str, Any], enable_kv_cache: bool = True, ) -> nn.Module: - with utils.set_default_dtype(self._dtype), self._device: + with training.set_default_dtype(self._dtype), self._device: model = config.instantiate(model_cfg) if self._quantization_mode is not None: @@ -77,7 +77,9 @@ def _setup_model( model.load_state_dict(model_state_dict) # Validate model was loaded in with the expected dtype. 
- utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) + training.validate_expected_param_dtype( + model.named_parameters(), dtype=self._dtype + ) logger.info(f"Model is initialized with precision {self._dtype}.") # Ensure the cache is setup on the right device @@ -147,31 +149,29 @@ def generate(self, cfg: DictConfig): if self._quantization_mode is not None: logger.info("Starting compilation to improve generation performance ...") custom_generate_next_token = torch.compile( - utils.generate_next_token, mode="max-autotune", fullgraph=True + generation.generate_next_token, mode="max-autotune", fullgraph=True ) t0 = time.perf_counter() - _ = utils.generate( + _ = generation.generate( model=self._model, prompt=prompt, max_generated_tokens=2, temperature=cfg.temperature, top_k=cfg.top_k, stop_tokens=self._tokenizer.stop_tokens, - pad_id=self._tokenizer.pad_id, custom_generate_next_token=custom_generate_next_token, ) t = time.perf_counter() - t0 logger.info(f"Warmup run for quantized model takes: {t:.02f} sec") t0 = time.perf_counter() - generated_tokens = utils.generate( + generated_tokens = generation.generate( model=self._model, prompt=prompt, max_generated_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k, stop_tokens=self._tokenizer.stop_tokens, - pad_id=self._tokenizer.pad_id, custom_generate_next_token=custom_generate_next_token, ) t = time.perf_counter() - t0 diff --git a/benchmarks/llm/recipes/lora_finetune_distributed.py b/benchmarks/llm/recipes/lora_finetune_distributed.py index 18b736fbf..fdea3871c 100755 --- a/benchmarks/llm/recipes/lora_finetune_distributed.py +++ b/benchmarks/llm/recipes/lora_finetune_distributed.py @@ -6,7 +6,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os import sys import time @@ -19,25 +18,24 @@ from torch import nn from torch.distributed import destroy_process_group, init_process_group -from torch.distributed.fsdp import ( - FullOptimStateDictConfig, - FullStateDictConfig, - FullyShardedDataParallel as FSDP, - StateDictType, -) + from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, utils +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset -from torchtune.modules.peft.peft_utils import ( +from torchtune.modules.peft import ( + DoRALinear, get_adapter_params, get_lora_module_names, get_merged_lora_ckpt, + load_dora_magnitudes, + LoRALinear, set_trainable_params, - validate_state_dict_for_lora, + validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.utils import DummyProfiler, PROFILER_KEY +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -50,8 +48,11 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): distributed training and can be run on a single node (1 to 8 GPUs). Features: - - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Traning on CPU is not - supported. + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. 
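The CPU-offload and resharding options this docstring goes on to describe mostly change how many parameters each rank keeps resident between the forward and backward passes. A deliberately rough back-of-envelope sketch of just that term, ignoring gradients, optimizer state and activations (sizes are toy numbers):

    def params_held_between_fwd_and_bwd(n_params: float, world_size: int,
                                        reshard_after_forward: bool) -> float:
        # reshard_after_forward=True  ~ FULL_SHARD: only the local shard stays resident
        # reshard_after_forward=False ~ SHARD_GRAD_OP: gathered params stay until backward
        return n_params / world_size if reshard_after_forward else n_params

    bytes_per_param = 2  # bf16
    for reshard in (True, False):
        held_gb = params_held_between_fwd_and_bwd(8e9, 8, reshard) * bytes_per_param / 1e9
        print(f"reshard_after_forward={reshard}: ~{held_gb:.1f} GB of parameters per rank")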
Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep @@ -109,14 +110,14 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) - self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) if self._dtype == torch.float16: raise ValueError( "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) - _, rank = utils.get_world_size_and_rank() + _, rank = training.get_world_size_and_rank() # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this @@ -132,12 +133,13 @@ def __init__(self, cfg: DictConfig) -> None: # These attributes constitute the recipe state and are updated by ``load_checkpoint`` # when ``resume_from_checkpoint`` is ``True`` - self.seed = utils.set_seed(seed=cfg.seed) + self.seed = training.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps @@ -157,7 +159,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: # and recipe state to be present. The keys should match up with what ``save_checkpoint`` # used to create these intermediate checkpoints if self._resume_from_checkpoint: - if utils.ADAPTER_KEY not in checkpoint_dict: + if training.ADAPTER_KEY not in checkpoint_dict: raise ValueError( "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." ) @@ -171,28 +173,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. 
""" try: - self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[utils.SEED_KEY]: + if self.seed != ckpt_dict[training.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" ) ) - self.seed = ckpt_dict[utils.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -218,13 +220,16 @@ def setup(self, cfg: DictConfig) -> None: self._metric_logger.log_config(cfg) checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + base_model_state_dict=checkpoint_dict[training.MODEL_KEY], lora_weights_state_dict=( - checkpoint_dict[utils.ADAPTER_KEY] + checkpoint_dict[training.ADAPTER_KEY] if self._resume_from_checkpoint else None ), @@ -233,13 +238,25 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=checkpoint_dict[utils.OPT_KEY] - if self._resume_from_checkpoint - else None, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) + if self._compile: + training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._is_rank_zero: + log.info("Loss is initialized.") + # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after all of these are setup self._sampler, self._dataloader = self._setup_data( @@ -277,14 +294,20 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False` self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + def _setup_profiler( - self, cfg_profiler: DictConfig + self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: """ Parses the `profiler` section of top-level `cfg` and sets up profiler Args: - cfg_profiler: DictConfig - 
`profiler` section of the top-level `cfg` (the main config passed to `recipe.main`) + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. Returns: profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods @@ -323,38 +346,42 @@ def _setup_profiler( # Check that component is included and set correctly if cfg_profiler.get("_component_", None) is None: - cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" else: assert ( cfg_profiler.get("_component_") - == "torchtune.utils.setup_torch_profiler" - ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" profiler, profiler_cfg = config.instantiate(cfg_profiler) if self._is_rank_zero: log.info(f" Profiler config after instantiation: {profiler_cfg}") + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + return profiler def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, ) -> nn.Module: """ Model initialization has some important considerations: - a. To minimize GPU peak memory, we load the model on CPU with the right - dtype. To ensure that we don't instantiate ``world_size`` number of models, - we initialize on meta_device for all ranks other than rank 0. - b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the - model weights from checkpoint. - c. While wrapping the model with FSDP, we set ``sync_module_states`` - to TRUE and broadcast module params and buffers from rank 0. - d. The ``device_id`` param ensures that the FSDP initialization happens on - the correct device. + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module` """ self._lora_rank = cfg_model.lora_rank @@ -364,87 +391,110 @@ def _setup_model( self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) if self._is_rank_zero: - log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") + log.info( + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." 
+ ) init_start = time.perf_counter() - with utils.set_default_dtype(self._dtype): - model = config.instantiate(cfg_model) + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) - log.info( - f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - # The model contains LoRA params which won't have any matching keys in - # the state dict. As a result, we need to load with strict=False. - # Before loading the state dict, ensure the state dict keys for the base - # model and adapters (if available) match the keys in the full LoRA model - # This is a good sanity check to prevent silent errors - validate_state_dict_for_lora( - lora_attn_modules=cfg_model.lora_attn_modules, - apply_lora_to_mlp=cfg_model.apply_lora_to_mlp, - apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False), - full_model_state_dict_keys=model.state_dict().keys(), - lora_state_dict_keys=( - lora_weights_state_dict.keys() - if lora_weights_state_dict is not None - else None - ), - base_model_state_dict_keys=base_model_state_dict.keys(), + # For FSDP sharding, we can condition on either the module or its name + # Shard conditions should be callables taking name (relative to model root) + # and the module itself and returning a bool on whether to shard the given module + + # Shard transformer decoder layers (or AC-wrapped versions) + # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) + # But directly using the name is more concise + def _is_layer_name(name: str, module: nn.Module) -> bool: + """ + Return True for layers.i and False for all other module names + Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot + """ + name_list = name.split(".") + return ( + len(name_list) == 2 + and name_list[0] == "layers" + and str.isdigit(name_list[1]) ) - # Load both the base model weights and (if available) the adapter weights. 
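The `_is_layer_name` shard condition above decides sharding purely from a module's name relative to the model root; a few quick checks of which names it would accept (module argument omitted, names illustrative):

    def is_layer_name(name: str) -> bool:
        parts = name.split(".")
        return len(parts) == 2 and parts[0] == "layers" and parts[1].isdigit()

    assert is_layer_name("layers.0")
    assert is_layer_name("layers.31")
    assert not is_layer_name("layers.0.attn")   # submodules are not sharded separately
    assert not is_layer_name("tok_embeddings")  # only whole decoder layers match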
Both - # of this should happen only on Rank 0 - model.load_state_dict(base_model_state_dict, strict=False) - if lora_weights_state_dict: - model.load_state_dict(lora_weights_state_dict, strict=False) + training.shard_model( + model=model, + shard_conditions=[_is_layer_name], + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + if lora_weights_state_dict: + lora_missing, lora_unexpected = training.load_from_full_model_state_dict( + model, + lora_weights_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) else: - # For non-zero ranks, load the model on meta device - with utils.set_default_dtype(self._dtype), torch.device("meta"): - model = config.instantiate(cfg_model) - - if self._dtype == torch.bfloat16: - model = model.to(torch.bfloat16) + lora_missing, lora_unexpected = None, None - # LoRA hyper-params needed for merging weights while saving checkpoints - self._lora_rank = cfg_model.lora_rank - self._lora_alpha = cfg_model.lora_alpha - - # Note: this needs to be set before wrapping with FSDP - self.adapter_params = get_adapter_params(model) - set_trainable_params(model, self.adapter_params) - - model = FSDP( - module=model, - auto_wrap_policy=utils.lora_fsdp_wrap_policy( - modules_to_wrap={modules.TransformerDecoderLayer} - ), - sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, - device_id=self._device, - # this recipe does not currently support mixed precision training - mixed_precision=None, - # Ensure we broadcast params and buffers from rank 0 - sync_module_states=True, - # Initialize empty modules on all non-zero ranks - param_init_fn=( - lambda module: module.to_empty( - device=torch.device("cuda"), recurse=False - ) - if not self._is_rank_zero - else None - ), + # Initialize LoRA params and RoPE buffers + with training.set_default_dtype(self._dtype), self._device: + lora_device = "cpu" if fsdp_cpu_offload else self._device + for m in model.modules(): + if ( + isinstance(m, LoRALinear) or isinstance(m, DoRALinear) + ) and not lora_weights_state_dict: + # lora may not be covered in state dict + # if finetune for the 1st time + m.lora_a.to_empty(device=lora_device) + m.lora_b.to_empty(device=lora_device) + m.initialize_parameters() + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + base_missing, base_unexpected = training.load_from_full_model_state_dict( + model, + base_model_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + is_dora = False + for m in model.modules(): + if hasattr(m, "initialize_dora_magnitude"): + is_dora = True + m.initialize_dora_magnitude() + if is_dora: + load_dora_magnitudes(model) + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, ) - # Ensure no params and buffers are on meta device - utils.validate_no_params_on_meta_device(model) + training.validate_no_params_on_meta_device(model) - if enable_activation_checkpointing: - utils.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerDecoderLayer} - ) if self._is_rank_zero: - memory_stats = utils.get_memory_stats(device=self._device) - utils.log_memory_stats(memory_stats) + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} 
secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) # synchronize before training begins torch.distributed.barrier() @@ -456,15 +506,14 @@ def _setup_optimizer( ) -> Optimizer: optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) if opt_state_dict: - # Note: technically we should check _contains_fsdp for - # just the state dict of the adapter cfg, but should be equivalent - opt_state_dict = FSDP.optim_state_dict_to_load( - self._model, optimizer, opt_state_dict + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, ) - optimizer.load_state_dict(opt_state_dict) if self._is_rank_zero: - log.info("Optimizer and loss are initialized.") + log.info("Optimizer is initialized.") return optimizer def _setup_lr_scheduler( @@ -494,7 +543,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = utils.get_world_size_and_rank() + world_size, rank = training.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -516,12 +565,14 @@ def _setup_data( batch_size=batch_size, sampler=sampler, collate_fn=partial( - utils.padded_collate, + padded_collate_sft, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else None, + else partial( + padded_collate_packed, + ), ) if self._is_rank_zero: @@ -539,6 +590,7 @@ def save_checkpoint( - Merged weights with key MODEL_KEY - Adapter weights with key ADAPTER_KEY - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights Checkpointer will save the merged weights, adapter weights and recipe state in different checkpoint files. 
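Merging here means folding each low-rank adapter update into its base weight with the usual LoRA scaling, roughly W_merged = W + (alpha / r) * B @ A per adapted linear layer; a small numeric sketch (shapes and values are illustrative):

    import torch

    out_dim, in_dim, rank, alpha = 16, 32, 4, 8
    base_w = torch.randn(out_dim, in_dim)
    lora_a = torch.randn(rank, in_dim) * 0.01
    lora_b = torch.zeros(out_dim, rank)   # B is zero-initialized, so merging at init is a no-op

    merged = base_w + (alpha / rank) * (lora_b @ lora_a)
    assert merged.shape == base_w.shape
    assert torch.allclose(merged, base_w)  # holds only while B is still zero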
To correctly resume from training, the adapter weights @@ -550,17 +602,20 @@ def save_checkpoint( intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - with FSDP.state_dict_type( + cpu_state_dict = training.get_full_model_state_dict( self._model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), - ): - cpu_state_dict = self._model.state_dict() - if intermediate_checkpoint: - opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) - else: - opt_state_dict = None + self._is_rank_zero, + device=self._device, + ) + + if intermediate_checkpoint: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = None # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file @@ -572,7 +627,7 @@ def save_checkpoint( adapter_state_dict = { k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) } - checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) # merge the adapter weights and base weights to create the model checkpoint merged_state_dict = get_merged_lora_ckpt( @@ -580,18 +635,18 @@ def save_checkpoint( rank=self._lora_rank, alpha=self._lora_alpha, ) - checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict}) + checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) # if training is in-progress, checkpoint the optimizer state and recipe state # as well. if intermediate_checkpoint: checkpoint_dict.update( { - utils.OPT_KEY: opt_state_dict, - utils.SEED_KEY: self.seed, - utils.EPOCHS_KEY: self.epochs_run, - utils.TOTAL_EPOCHS_KEY: self.total_epochs, - utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) @@ -605,12 +660,12 @@ def save_checkpoint( ), "peft_type": "LORA", } - checkpoint_dict.update({utils.ADAPTER_CONFIG: adapter_config}) - + checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config}) self._checkpointer.save_checkpoint( checkpoint_dict, epoch=epoch, intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, ) def train(self) -> None: @@ -618,9 +673,9 @@ def train(self) -> None: The core training loop. 
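In the loop that follows, next-token labels are built by shifting and appending a column of the ignore index instead of slicing the logits; a tiny demonstration of that reshaping (the -100 value is the usual cross-entropy ignore_index convention):

    import torch

    ignore_index = -100
    batch, seq = 2, 5
    labels = torch.randint(0, 10, (batch, seq))
    ignore_col = torch.full((batch, 1), ignore_index)

    # Equivalent to dropping the first label position, but the shape still lines up
    # with logits of shape [batch, seq, vocab], so the logits never need slicing.
    shifted = torch.hstack((labels[..., 1:], ignore_col))
    assert shifted.shape == labels.shape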
""" # clean up before training begins - utils.cleanup_before_training() + training.cleanup_before_training() - _, rank = utils.get_world_size_and_rank() + _, rank = training.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() @@ -647,6 +702,15 @@ def train(self) -> None: ): break + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] # Get the attention mask and position ids from the dataset if they @@ -661,14 +725,21 @@ def train(self) -> None: input_pos = ( input_pos.to(self._device) if input_pos is not None else None ) - logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory del logits @@ -689,7 +760,7 @@ def train(self) -> None: self.log_loss(loss_to_log) pbar.update(1) pbar.set_description( - f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -704,7 +775,9 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if self._log_peak_memory_stats: - log_dict.update(utils.get_memory_stats(device=self._device)) + log_dict.update( + training.get_memory_stats(device=self._device) + ) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -715,6 +788,18 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + # Step profiler # Note that this is called within gradient accumulation block, hence # will include multiple forward / backward passes if gradient accumulation > 1 @@ -733,7 +818,8 @@ def cleanup(self) -> None: def log_loss(self, loss): pass -def prepare_voir(recipe): + +def prepare_voir(recipe:LoRAFinetuneRecipeDistributed): from benchmate.observer import BenchObserver from benchmate.monitor import bench_monitor @@ -767,12 +853,15 @@ def recipe_main(cfg: DictConfig) -> None: - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ - if not utils.is_distributed(): + if not training.is_distributed(): raise RuntimeError( "Distributed finetune recipe should be run via a distributed launcher." 
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg) diff --git a/benchmarks/llm/recipes/lora_finetune_single_device.py b/benchmarks/llm/recipes/lora_finetune_single_device.py index cf5256ead..f08793f52 100755 --- a/benchmarks/llm/recipes/lora_finetune_single_device.py +++ b/benchmarks/llm/recipes/lora_finetune_single_device.py @@ -6,7 +6,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os import sys import time @@ -15,22 +14,25 @@ from warnings import warn import torch +import torchtune.modules.common_utils as common_utils from omegaconf import DictConfig, ListConfig from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, utils +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset -from torchtune.modules.peft.peft_utils import ( +from torchtune.modules.peft import ( get_adapter_params, get_lora_module_names, get_merged_lora_ckpt, + load_dora_magnitudes, set_trainable_params, validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.utils import DummyProfiler, PROFILER_KEY +from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -88,6 +90,10 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config has example commands for how to kick-off training. @@ -104,7 +110,7 @@ def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) # Reduced precision logic - self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) # fp16 precision is explicitly disabled as it is not supported in this # recipe (for example, no gradient scaling). 
if self._dtype == torch.float16: @@ -125,14 +131,15 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=cfg.seed) + self.seed = training.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._clip_grad_norm = cfg.get("clip_grad_norm", None) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -147,7 +154,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: checkpoint_dict = self._checkpointer.load_checkpoint() if self._resume_from_checkpoint: - if utils.ADAPTER_KEY not in checkpoint_dict: + if training.ADAPTER_KEY not in checkpoint_dict: raise ValueError( "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." ) @@ -161,28 +168,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. """ try: - self.epochs_run = ckpt_dict.get(utils.EPOCHS_KEY, 0) + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict.get(utils.SEED_KEY, self.seed): + if self.seed != ckpt_dict[training.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict.get(utils.SEED_KEY, 0)}" + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" ) ) - self.seed = ckpt_dict.get(utils.SEED_KEY, self.seed) - if self.max_steps_per_epoch != ckpt_dict.get(utils.MAX_STEPS_KEY, self.max_steps_per_epoch): + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict.get(utils.MAX_STEPS_KEY, self.max_steps_per_epoch) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict.get(utils.TOTAL_EPOCHS_KEY, self.total_epochs): + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -206,16 +213,21 @@ def setup(self, cfg: DictConfig) -> None: # log config with parameter override self._metric_logger.log_config(cfg) - self._model_compile = cfg.compile + self._compile = cfg.compile checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + # hack to toggle to the low cpu ram version of the reparametrize_as_dtype + # hook based on the config. 
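The float16 guard at the top of this constructor exists because fp16, unlike bf16, gives up fp32's exponent range and overflows without gradient scaling; a two-line illustration:

    import torch

    x = torch.tensor(70000.0)
    print(x.to(torch.float16))    # inf: fp16 tops out around 65504
    print(x.to(torch.bfloat16))   # finite (about 7e4), just rounded more coarsely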
+ common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False) + + # set up model self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, compile_model=cfg.compile, - base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], + base_model_state_dict=checkpoint_dict[training.MODEL_KEY], lora_weights_state_dict=( - checkpoint_dict[utils.ADAPTER_KEY] + checkpoint_dict[training.ADAPTER_KEY] if self._resume_from_checkpoint else None ), @@ -227,11 +239,21 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, opt_state_dict=( - checkpoint_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None ), ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) + if self._compile: + self._loss_fn = training.compile_loss(self._loss_fn) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + log.info("Loss is initialized.") # Dataloader depends on the tokenizer and loss_fn and should be @@ -271,14 +293,20 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + def _setup_profiler( - self, cfg_profiler: DictConfig + self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: """ Parses the `profiler` section of top-level `cfg` and sets up profiler Args: - cfg_profiler: DictConfig - `profiler` section of the top-level `cfg` (the main config passed to `recipe.main`) + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. 
Returns: profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods @@ -318,17 +346,23 @@ def _setup_profiler( # Check that component is included and set correctly if cfg_profiler.get("_component_", None) is None: - cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" else: assert ( cfg_profiler.get("_component_") - == "torchtune.utils.setup_torch_profiler" - ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" profiler, profiler_cfg = config.instantiate(cfg_profiler) log.info(f" Profiler config after instantiation: {profiler_cfg}") + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + return profiler def _setup_model( @@ -339,7 +373,7 @@ def _setup_model( base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, ) -> nn.Module: - with utils.set_default_dtype(self._dtype), self._device: + with training.set_default_dtype(self._dtype), self._device: model = config.instantiate(cfg_model) self._lora_rank = cfg_model.lora_rank @@ -348,23 +382,30 @@ def _setup_model( self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) self.adapter_params = get_adapter_params(model) + self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()]) set_trainable_params(model, self.adapter_params) + if compile_model: + training.compile_model(model) + if enable_activation_checkpointing: - utils.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerDecoderLayer} + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) base_missing, base_unexpected = model.load_state_dict( base_model_state_dict, strict=False ) + # This is for any adapters that need to be initialized after base weights + # have been loaded (e.g. DoRA). + if self._is_dora: + load_dora_magnitudes(model) if lora_weights_state_dict: lora_missing, lora_unexpected = model.load_state_dict( lora_weights_state_dict, strict=False ) else: lora_missing, lora_unexpected = None, None - validate_missing_and_unexpected_for_lora( lora_attn_modules=self._lora_attn_modules, apply_lora_to_mlp=self._apply_lora_to_mlp, @@ -376,21 +417,16 @@ def _setup_model( ) # Validate model adapter params were loaded in with the expected dtype # TODO (rohan-varma): Further validation to ensure the appropriate base params - # are NF4 vs bf16 based on the quantization config. - utils.validate_expected_param_dtype( + training.validate_expected_param_dtype( self.adapter_params.items(), dtype=self._dtype ) log.info(f"Model is initialized with precision {self._dtype}.") - # Compile model, if enabled. 
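The lines removed just below used to compile the model in place with a backend taken from the environment; roughly that pattern on a toy module, using the `torch.compile` function form (the module itself is a placeholder):

    import os
    import torch
    from torch import nn

    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
    compiled = torch.compile(model, backend=backend)
    print(compiled(torch.randn(2, 8)).shape)   # first call triggers compilation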
- if compile_model: - log.info("Compiling model with torch.compile...") - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - model.compile(backend=backend) + if self._device.type == "cuda": - memory_stats = utils.get_memory_stats(device=self._device) - utils.log_memory_stats(memory_stats) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) return model def _setup_optimizer( @@ -432,13 +468,13 @@ def _setup_data( """ if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + config.instantiate(single_cfg_dataset, self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + ds = config.instantiate(cfg_dataset, self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( @@ -453,12 +489,14 @@ def _setup_data( sampler=sampler, batch_size=batch_size, collate_fn=partial( - utils.padded_collate, + padded_collate_sft, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else None, + else partial( + padded_collate_packed, + ), ) log.info("Dataset and Sampler are initialized.") @@ -472,41 +510,45 @@ def save_checkpoint(self, epoch: int) -> None: - Merged weights with key MODEL_KEY - Adapter weights with key ADAPTER_KEY - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights - Checkpointer will save the merged weights, adapter weights and recipe state in - different checkpoint files. To correctly resume from training, the adapter weights - and recipe state must be provided along with the base model weights. + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. 
""" ckpt_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs # if training is in-progress, checkpoint the optimizer state as well - if epoch + 1 < self.total_epochs: + if intermediate_checkpoint: ckpt_dict.update( { - utils.OPT_KEY: self._optimizer.state_dict(), - utils.SEED_KEY: self.seed, - utils.EPOCHS_KEY: self.epochs_run, - utils.TOTAL_EPOCHS_KEY: self.total_epochs, - utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + training.OPT_KEY: self._optimizer.state_dict(), + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) # Move to CPU to avoid a copy on GPU state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} + # Construct the adapter weights + # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice + # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in state_dict.items() if adapter_key_filter(k) + } + # Construct the full state dict with LoRA weights merged into base LLM weights merged_state_dict = get_merged_lora_ckpt( state_dict, rank=self._lora_rank, alpha=self._lora_alpha, ) - ckpt_dict.update({utils.MODEL_KEY: merged_state_dict}) + ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) - # Construct the adapter weights - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) - } - ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) adapter_config = { "r": self._lora_rank, "lora_alpha": self._lora_alpha, @@ -517,19 +559,51 @@ def save_checkpoint(self, epoch: int) -> None: ), "peft_type": "LORA", } - ckpt_dict.update({utils.ADAPTER_CONFIG: adapter_config}) + ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config}) + self._checkpointer.save_checkpoint( ckpt_dict, epoch=epoch, - intermediate_checkpoint=(epoch + 1 < self.total_epochs), + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, ) + def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + # run model + logits = self._model(tokens, mask=mask, input_pos=input_pos) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + loss = self._loss_fn(logits, labels) + + # free logits otherwise it peaks backward memory + del logits + + return loss + def train(self) -> None: """ The core training loop. """ - if self._model_compile: + if self._compile: log.info( "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." 
) @@ -555,33 +629,29 @@ def train(self) -> None: ): break - # Both are shape [b, s] - tokens, labels = batch["tokens"], batch["labels"] - # Get the attention mask and position ids from the dataset if they - # exist. Currently, only sample packing in PackedDataset returns these - mask = batch.get("mask", None) # shape [b, s, s] - input_pos = batch.get("input_pos", None) # shape [b, s] - - tokens = tokens.to(self._device) - num_tokens += tokens.numel() - labels = labels.to(self._device) - mask = mask.to(self._device) if mask is not None else None - input_pos = ( - input_pos.to(self._device) if input_pos is not None else None - ) + # Start tracking CUDA memory for active steps for just the first epoch + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() - logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) - # Compute loss - loss = self._loss_fn(logits, labels) / self._gradient_accumulation_steps + batch = {k: v.to(self._device) for k, v in batch.items()} + num_tokens += batch["tokens"].numel() + + loss = self._loss_step(batch) + loss = loss / self._gradient_accumulation_steps running_loss += loss loss.backward() # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self._lr_scheduler.step() @@ -592,7 +662,7 @@ def train(self) -> None: self.log_loss(loss_to_log) pbar.update(1) pbar.set_description( - f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -608,8 +678,10 @@ def train(self) -> None: and self._log_peak_memory_stats ): log_dict.update( - utils.get_memory_stats(device=self._device) + training.get_memory_stats(device=self._device) ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -620,13 +692,31 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() + # Stop tracking CUDA memory now that active steps are complete + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + # Step the profiler # Note we are stepping each batch, which might not include optimizer step in the trace # if the schedule cycle doesn't align with gradient accumulation. 
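Stepping the profiler once per batch means it only records during the active window of its wait/warmup/active schedule, which is also how the memory-history start and stop points above are chosen; a minimal stand-alone version (step counts are illustrative, not the recipe defaults):

    import torch
    from torch.profiler import ProfilerActivity, profile, schedule

    wait, warmup, active = 1, 1, 3
    prof = profile(
        activities=[ProfilerActivity.CPU],
        schedule=schedule(wait=wait, warmup=warmup, active=active),
    )
    prof.start()
    for step in range(wait + warmup + active + 2):
        torch.randn(64, 64) @ torch.randn(64, 64)   # stand-in for one training batch
        prof.step()                                  # stepped once per batch
    prof.stop()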
prof.step() self.epochs_run += 1 + start_save_checkpoint = time.perf_counter() + log.info("Starting checkpoint save...") self.save_checkpoint(epoch=curr_epoch) + log.info( + "Checkpoint saved in {:.2f} seconds.".format( + time.perf_counter() - start_save_checkpoint + ) + ) def cleanup(self) -> None: self._metric_logger.close() @@ -635,8 +725,7 @@ def log_loss(self, loss): pass - -def prepare_voir(recipe): +def prepare_voir(recipe:LoRAFinetuneRecipeSingleDevice): from benchmate.observer import BenchObserver from benchmate.monitor import bench_monitor diff --git a/benchmarks/llm/recipes/ppo_full_finetune_single_device.py b/benchmarks/llm/recipes/ppo_full_finetune_single_device.py index 8ee77c06a..bdd63e8cd 100644 --- a/benchmarks/llm/recipes/ppo_full_finetune_single_device.py +++ b/benchmarks/llm/recipes/ppo_full_finetune_single_device.py @@ -17,7 +17,8 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, utils +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate from torchtune.datasets import ConcatDataset from torchtune.modules import rlhf from torchtune.modules.rlhf import PPOStats, Trajectory @@ -106,7 +107,7 @@ class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) - self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor # enabled necessary features such as gradient scaling. @@ -122,7 +123,7 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=cfg.seed) + self.seed = training.set_seed(seed=cfg.seed) # manually setting up a generator for the recipe self._rng = torch.Generator(self._device).manual_seed(self.seed) self._total_steps = 0 @@ -177,15 +178,15 @@ def setup(self, cfg: DictConfig) -> None: self._value_model, self._reward_model, self._ref_policy_model, - ) = self._setup_model( + ) = self._setup_models( cfg_model=cfg.policy_model, cfg_reward_value_model=cfg.reward_and_value_model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, compile_model=self._model_compile, - policy_state_dict=policy_model_checkpoint_dict[utils.MODEL_KEY], - ref_policy_state_dict=ref_policy_state_dict[utils.MODEL_KEY], - value_model_state_dict=value_model_checkpoint_dict[utils.MODEL_KEY], - reward_model_state_dict=reward_model_state_dict[utils.MODEL_KEY], + policy_state_dict=policy_model_checkpoint_dict[training.MODEL_KEY], + ref_policy_state_dict=ref_policy_state_dict[training.MODEL_KEY], + value_model_state_dict=value_model_checkpoint_dict[training.MODEL_KEY], + reward_model_state_dict=reward_model_state_dict[training.MODEL_KEY], ) # setup tokenizer @@ -198,7 +199,7 @@ def setup(self, cfg: DictConfig) -> None: cfg_optimizer=cfg.optimizer, optimizer_in_bwd=cfg.optimizer_in_bwd, opt_state_dict=( - policy_model_checkpoint_dict[utils.OPT_KEY] + policy_model_checkpoint_dict[training.OPT_KEY] if self._resume_from_checkpoint else None ), @@ -348,7 +349,10 @@ def _setup_checkpointers( value_cfg: DictConfig, reward_cfg: DictConfig, ) -> Tuple[ - utils.Checkpointer, utils.Checkpointer, utils.Checkpointer, utils.Checkpointer + 
training.Checkpointer, + training.Checkpointer, + training.Checkpointer, + training.Checkpointer, ]: """ Sets up checkpointers for policy, reference policy, value, and reward models. @@ -394,7 +398,7 @@ def _setup_checkpointers( reward_checkpointer, ) - def _setup_model( + def _setup_models( self, cfg_model: DictConfig, cfg_reward_value_model: DictConfig, @@ -409,53 +413,49 @@ def _setup_model( Sets up the policy model, reference policy model, reward model, and value model. """ - with utils.set_default_dtype(self._dtype), self._device: + with training.set_default_dtype(self._dtype), self._device: policy_model = config.instantiate(cfg_model) ref_policy_model = config.instantiate(cfg_model) reward_model = config.instantiate(cfg_reward_value_model) value_model = config.instantiate(cfg_reward_value_model) if enable_activation_checkpointing: - utils.set_activation_checkpointing( - policy_model, auto_wrap_policy={modules.TransformerDecoderLayer} + training.set_activation_checkpointing( + policy_model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - utils.set_activation_checkpointing( - value_model, auto_wrap_policy={modules.TransformerDecoderLayer} + training.set_activation_checkpointing( + value_model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) policy_model.load_state_dict(policy_state_dict) ref_policy_model.load_state_dict(ref_policy_state_dict) - reward_missing, reward_unexpected = reward_model.load_state_dict( - reward_model_state_dict, strict=False + # since we should be loading a classifier checkpoint into + # a classifier model, this function should just ensure + # output.weight appears in the state_dict and the model's parameters, + # and removes output.bias from the state dict if found + training.update_state_dict_for_classifier( + reward_model_state_dict, reward_model.named_parameters() ) - value_missing, value_unexpected = value_model.load_state_dict( - value_model_state_dict, strict=False - ) - - # some extra validation for HF classifier checkpoints with a `score.bias` present - assert ( - reward_missing == value_missing == [] - ), f"Missing keys in reward ({reward_missing}) and value model ({value_missing}) state dicts." + reward_model.load_state_dict(reward_model_state_dict) - if reward_unexpected or value_unexpected: - # the only unexpected keys should be when pre-trained HF models were saved with - # bias=True in final classification layers. This happens when training a reward model with TRL. - assert ( - reward_unexpected == value_unexpected == ["output.bias"] - ), f"Unexpected keys in reward ({reward_unexpected}) and value model ({value_unexpected}) state dicts." + # same as above + training.update_state_dict_for_classifier( + value_model_state_dict, value_model.named_parameters() + ) + value_model.load_state_dict(value_model_state_dict) # Validate models were loaded in with the expected dtype. 
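The classifier handling above replaces the old missing/unexpected-key assertions with a simpler contract: `output.weight` must be present, and a stray `output.bias` (saved by trainers that used bias=True) is dropped before a strict load. A small sketch of that filtering (keys are illustrative):

    import torch

    reward_sd = {"output.weight": torch.randn(1, 8), "output.bias": torch.randn(1)}
    model_keys = {"output.weight"}                  # the classifier head here has no bias

    assert "output.weight" in reward_sd
    cleaned = {k: v for k, v in reward_sd.items() if k in model_keys}
    assert "output.bias" not in cleaned             # now safe to load strictly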
- utils.validate_expected_param_dtype( + training.validate_expected_param_dtype( value_model.named_parameters(), dtype=self._dtype ) - utils.validate_expected_param_dtype( + training.validate_expected_param_dtype( reward_model.named_parameters(), dtype=self._dtype ) - utils.validate_expected_param_dtype( + training.validate_expected_param_dtype( value_model.named_parameters(), dtype=self._dtype ) - utils.validate_expected_param_dtype( + training.validate_expected_param_dtype( ref_policy_model.named_parameters(), dtype=self._dtype ) @@ -497,8 +497,8 @@ def _setup_model( value_model.compile(backend=backend) if self._device.type == "cuda": - memory_stats = utils.get_memory_stats(device=self._device) - utils.log_memory_stats(memory_stats) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) return policy_model, value_model, reward_model, ref_policy_model @@ -518,17 +518,17 @@ def _setup_optimizer( ) } # Register optimizer step hooks on the models to run optimizer in backward. - utils.register_optim_in_bwd_hooks( + training.register_optim_in_bwd_hooks( model=self._policy_model, optim_dict=optim_dict ) - utils.register_optim_in_bwd_hooks( + training.register_optim_in_bwd_hooks( model=self._value_model, optim_dict=optim_dict ) # Create a wrapper for checkpoint save/load of optimizer states when running in backward. - self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( model=self._policy_model, optim_dict=optim_dict ) - self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( + self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( model=self._value_model, optim_dict=optim_dict ) # Load optimizer states. If optimizer states are being restored in an optimizer in backward @@ -582,7 +582,9 @@ def _setup_data( sampler=sampler, batch_size=batch_size, collate_fn=partial( - rlhf.left_padded_collate, + padded_collate, + pad_direction="left", + keys_to_pad=["tokens", "labels"], padding_idx=self._tokenizer.pad_id, ), drop_last=True, @@ -597,25 +599,27 @@ def save_checkpoint( Save state dict to file. The recipe save_checkpoint method is responsible for correctly creating the checkpoint dict and passing to the checkpointer. 
""" - policy_ckpt_dict = {utils.MODEL_KEY: self._policy_model.state_dict()} - value_ckpt_dict = {utils.MODEL_KEY: self._value_model.state_dict()} + policy_ckpt_dict = {training.MODEL_KEY: self._policy_model.state_dict()} + value_ckpt_dict = {training.MODEL_KEY: self._value_model.state_dict()} # if training is in-progress, checkpoint the optimizer state and rng state as well if is_intermediate_checkpoint: policy_ckpt_dict.update( { - utils.SEED_KEY: self.seed, - utils.EPOCHS_KEY: self._epochs_run, - utils.TOTAL_EPOCHS_KEY: self._total_epochs, - utils.MAX_STEPS_KEY: self._total_steps, - utils.STEPS_KEY: self._steps_run, - utils.RNG_KEY: self._rng.get_state(), + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self._epochs_run, + training.TOTAL_EPOCHS_KEY: self._total_epochs, + training.MAX_STEPS_KEY: self._total_steps, + training.STEPS_KEY: self._steps_run, + training.RNG_KEY: self._rng.get_state(), } ) if not self._optimizer_in_bwd: - policy_ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() + policy_ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict() else: - policy_ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() + policy_ckpt_dict[ + training.OPT_KEY + ] = self._optim_ckpt_wrapper.state_dict() self._policy_checkpointer.save_checkpoint( policy_ckpt_dict, @@ -637,20 +641,20 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: # warn the user and overwrite. try: if ( - self.seed != ckpt_dict[utils.SEED_KEY] - or self._total_steps != ckpt_dict[utils.MAX_STEPS_KEY] - or self._total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] + self.seed != ckpt_dict[training.SEED_KEY] + or self._total_steps != ckpt_dict[training.MAX_STEPS_KEY] + or self._total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY] ): warn( message="""Configured value for seed, total_steps, or total_epochs does not match the value stored in checkpoint.""" ) - self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) - self._rng.set_state(ckpt_dict[utils.RNG_KEY]) - self._steps_run = ckpt_dict[utils.STEPS_KEY] - self._total_steps = ckpt_dict[utils.MAX_STEPS_KEY] - self._total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] - self._epochs_run = ckpt_dict[utils.EPOCHS_KEY] + self.seed = training.set_seed(seed=ckpt_dict[training.SEED_KEY]) + self._rng.set_state(ckpt_dict[training.RNG_KEY]) + self._steps_run = ckpt_dict[training.STEPS_KEY] + self._total_steps = ckpt_dict[training.MAX_STEPS_KEY] + self._total_epochs = ckpt_dict[training.TOTAL_EPOCHS_KEY] + self._epochs_run = ckpt_dict[training.EPOCHS_KEY] except KeyError as e: raise KeyError from e( @@ -740,7 +744,7 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: # step 5.1 the scores from the reward model are the logits for the last non-padding token in # each (query, truncated-response) pair - seq_lens = utils.get_unmasked_sequence_lengths(response_padding_masks) + seq_lens = training.get_unmasked_sequence_lengths(response_padding_masks) scores = scores[torch.arange(batch_size), seq_lens + context_length].squeeze(-1) # step 5.2 if configured, apply any penalties for sequences without EOS tokens @@ -828,7 +832,7 @@ def train(self) -> None: self._sampler.set_epoch(curr_epoch) for _, batch in enumerate(self._dataloader): - batch = batch.to(self._device) + batch = batch["tokens"].to(self._device) _, context_length = batch.shape # step 1. 
generate the trajectory using: @@ -1032,7 +1036,7 @@ def log_metrics( "response_lengths": trajectory.seq_lens.float().mean(), } if self._device.type == "cuda" and self._log_peak_memory_stats: - log_dict.update(utils.get_memory_stats(device=self._device)) + log_dict.update(training.get_memory_stats(device=self._device)) self._metric_logger.log_dict(log_dict, step=self.global_step) diff --git a/benchmarks/llm/recipes/qat_distributed.py b/benchmarks/llm/recipes/qat_distributed.py index 211433835..578669ed8 100755 --- a/benchmarks/llm/recipes/qat_distributed.py +++ b/benchmarks/llm/recipes/qat_distributed.py @@ -6,36 +6,31 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import sys import time from functools import partial -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from warnings import warn import torch from omegaconf import DictConfig, ListConfig from torch import nn -from torch.distributed import init_process_group -from torch.distributed.fsdp import ( - CPUOffload, - FullOptimStateDictConfig, - FullStateDictConfig, - FullyShardedDataParallel as FSDP, - StateDictType, -) +from torch.distributed import destroy_process_group, init_process_group + from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler - -from torchtune import config, modules, utils +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.utils.activations import apply_selective_activation_checkpointing +from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training.activations import apply_selective_activation_checkpointing from tqdm import tqdm - log = utils.get_logger("DEBUG") @@ -56,8 +51,11 @@ class QATRecipeDistributed(FTRecipeInterface): weight and activation values to stabilize before fake quantizing them, potentially leading to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``. - - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU - is not supported. + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep @@ -105,12 +103,12 @@ class QATRecipeDistributed(FTRecipeInterface): Raises: ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. 
""" def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=cfg.device) - self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + self._dtype = training.get_dtype(cfg.dtype, device=self._device) if self._dtype == torch.float16: raise ValueError( @@ -119,7 +117,7 @@ def __init__(self, cfg: DictConfig) -> None: if ( cfg.get("fsdp_cpu_offload", False) - and cfg.get("fused", False) + and cfg.optimizer.get("fused", False) and not utils.torch_version_ge("2.4.0") ): raise RuntimeError( @@ -133,18 +131,21 @@ def __init__(self, cfg: DictConfig) -> None: # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this - _, rank = utils.get_world_size_and_rank() + _, rank = training.get_world_size_and_rank() self._is_rank_zero = rank == 0 # Training cfg self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[ + cfg.get("fsdp_sharding_strategy", "FULL_SHARD") + ] self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None) self._quantizer_mode = None # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=cfg.seed) + self.seed = training.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch @@ -170,28 +171,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. """ try: - self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[utils.SEED_KEY]: + if self.seed != ckpt_dict[training.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" ) ) - self.seed = ckpt_dict[utils.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -207,8 +208,8 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: def setup(self, cfg: DictConfig) -> None: """ - Sets up the recipe state correctly. This includes setting recipe attributes based - on the ``resume_from_checkpoint`` flag. + Setup the recipe. This includes training state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, sampler, and dataloader. 
""" if self._is_rank_zero: self._metric_logger = config.instantiate(cfg.metric_logger) @@ -216,34 +217,48 @@ def setup(self, cfg: DictConfig) -> None: # log config with parameter override self._metric_logger.log_config(cfg) - ckpt_dict = self.load_checkpoint(cfg.checkpointer) + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - # ``_setup_model`` handles initialization and loading the state dict. This method - # should be called before ``_setup_optimizer`` since transforming the optimizer - # state dict requires the model + self._model_compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False), + custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), - model_state_dict=ckpt_dict[utils.MODEL_KEY], + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + model_state_dict=checkpoint_dict[training.MODEL_KEY], ac_mode=cfg.get("ac_mode", None), ac_option=cfg.get("ac_option", None), quantizer_cfg=cfg.get("quantizer", None), ) - self._tokenizer = config.instantiate(cfg.tokenizer) - # _setup_optimizer should take in ckpt_dict only if training is resumed from - # checkpoint. Transforming the opt state dict is handled by this method self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=ckpt_dict[utils.OPT_KEY] + opt_state_dict=checkpoint_dict[training.OPT_KEY] if self._resume_from_checkpoint else None, ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._model_compile: + log.info("Compiling loss with torch.compile...") + # For CEWithChunkedOutputLoss, if we compile the entire class + # we lose the benefits from the chunked loss. + # Therefore, we only compile the cross entropy function + upcasting + self._loss_fn.compute_cross_entropy = torch.compile( + self._loss_fn.compute_cross_entropy, backend=backend + ) + else: + if self._model_compile: + log.info("Compiling loss with torch.compile...") + self._loss_fn = torch.compile(self._loss_fn, backend=backend) + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized @@ -270,12 +285,89 @@ def setup(self, cfg: DictConfig) -> None: self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. 
+ + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, - memory_efficient_fsdp_wrap: bool, + custom_sharded_layers: Optional[List[str]], fsdp_cpu_offload: bool, + reshard_after_forward: bool, model_state_dict: Dict[str, Any], ac_mode: Optional[str] = None, ac_option: Optional[int] = None, @@ -283,37 +375,20 @@ def _setup_model( ) -> nn.Module: """ Model initialization has some important considerations: - a. To minimize GPU peak memory, we load the model on CPU with the right - dtype. To ensure that we don't instantiate ``world_size`` number of models, - we initialize on meta_device for all ranks other than rank 0. - b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the - model weights from checkpoint. - c. While wrapping the model with FSDP, we set ``sync_module_states`` - to TRUE and broadcast module params and buffers from rank 0. - d. The ``device_id`` param ensures that the FSDP initialization happens on - the correct device. + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` """ - if self._is_rank_zero: - log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") - init_start = time.perf_counter() - - with utils.set_default_dtype(self._dtype): - model = config.instantiate(cfg_model) + if self._is_rank_zero: log.info( - f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" + "FSDP is enabled. 
Instantiating model and loading checkpoint on Rank 0 ..." ) + init_start = time.perf_counter() - # Load both the model weights. This should happen only on Rank 0 - model.load_state_dict(model_state_dict) - - else: - # For non-zero ranks, load the model on meta device - with utils.set_default_dtype(self._dtype), torch.device("meta"): - model = config.instantiate(cfg_model) - - if self._dtype == torch.bfloat16: - model = model.to(torch.bfloat16) + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) # We currently have two versions of activation checkpointing in this recipe # for testing and BC purposes. ``enable_activation_checkpointing`` controls @@ -321,9 +396,6 @@ def _setup_model( # ac_mode and ac_option together control selective AC. This is only enabled # when these are set AND ``enable_activation_checkpointing`` is set to False # We'll clean this up as soon as testing of AC is complete - ac_mode = ac_mode - ac_option = ac_option - if (not enable_activation_checkpointing) and (ac_mode is not None): apply_selective_activation_checkpointing( model, @@ -331,12 +403,18 @@ def _setup_model( ac_option, ) + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + # Apply quantization-aware training during finetuning if quantizer_cfg is None: raise ValueError("Quantizer must be specified for QAT recipe.") quantizer = config.instantiate(quantizer_cfg) quantizer.precision = self._dtype - quantizer_mode = utils.quantization.get_quantizer_mode(quantizer) + quantizer_mode = training.quantization.get_quantizer_mode(quantizer) if "qat" not in quantizer_mode: raise ValueError( "Quantizer mode '%s' is not supported for finetuning" % quantizer_mode @@ -344,43 +422,57 @@ def _setup_model( self._quantizer_mode = quantizer_mode model = quantizer.prepare(model) - # Wrap the model with FSDP. This will ensure that the model is sharded - # across all available GPUs. 
- model = FSDP( - module=model, - auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy( - memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap, - modules_to_wrap={modules.TransformerDecoderLayer}, - ), - cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload), - sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, - device_id=self._device, - # this recipe does not currently support mixed precision training - mixed_precision=None, - # Ensure we broadcast params and buffers from rank 0 - sync_module_states=True, - # Initialize empty modules on all non-zero ranks - param_init_fn=( - lambda module: module.to_empty( - device=torch.device("cuda"), recurse=False - ) - if not self._is_rank_zero - else None - ), + # For FSDP sharding, we can condition on either the module or its name + # Shard conditions should be callables taking name (relative to model root) + # and the module itself and returning a bool on whether to shard the given module + fsdp_shard_conditions = [] + + # Shard transformer decoder layers (or AC-wrapped versions) + # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) + # But directly using the name is more concise + def _is_layer_fqn(s: str) -> bool: + """ + Return True for layers.i and False for all other module names + Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot + """ + s_list = s.split(".") + return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1]) + + fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)] + + # If wrapping any layers separately, we can add another shard condition + # A layer will be sharded if any of the fsdp_shard_conditions are met + if custom_sharded_layers: + fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers] + + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, ) - # Ensure no params and buffers are on meta device - utils.validate_no_params_on_meta_device(model) + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() - # original activation checkpointing (full) - flip the condition above - if enable_activation_checkpointing and ac_mode is None: - utils.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerDecoderLayer} - ) + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, model_state_dict, self._device, self._is_rank_zero, strict=True + ) + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) if self._is_rank_zero: - memory_stats = utils.get_memory_stats(device=self._device) - utils.log_memory_stats(memory_stats) + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) # synchronize before training begins torch.distributed.barrier() @@ -390,17 +482,13 @@ def _setup_model( def _setup_optimizer( self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None ) -> Optimizer: - """ - Set up the optimizer. This method also handles transforing the state dict - for FSDP. 
- """ optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) - if opt_state_dict: - opt_state_dict = FSDP.optim_state_dict_to_load( - self._model, optimizer, opt_state_dict + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, ) - optimizer.load_state_dict(opt_state_dict) if self._is_rank_zero: log.info("Optimizer is initialized.") @@ -417,7 +505,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = utils.get_world_size_and_rank() + world_size, rank = training.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -431,23 +519,21 @@ def _setup_data( packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( - ds, - num_replicas=world_size, - rank=rank, - shuffle=shuffle, - seed=0, + ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 ) dataloader = DataLoader( dataset=ds, batch_size=batch_size, sampler=sampler, collate_fn=partial( - utils.padded_collate, + padded_collate_sft, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else None, + else partial( + padded_collate_packed, + ), ) if self._is_rank_zero: @@ -455,57 +541,72 @@ def _setup_data( return sampler, dataloader - def save_checkpoint(self, epoch: int) -> None: + def save_checkpoint( + self, + epoch: int, + ) -> None: """ - Save state dict to file. The recipe save_checkpoint method is responsible for - correctly creating the checkpoint dict and passing to the checkpointer. + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Model weights with key training.MODEL_KEY + - Relevant recipe state if training is not complete + + Checkpointer will save the model weights and recipe state in + different checkpoint files. To correctly resume training from an intermediate checkpoint, + the model weights and recipe state must be provided. """ + # final dict passed onto the checkpointer checkpoint_dict = {} + intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - with FSDP.state_dict_type( + cpu_state_dict = training.get_full_model_state_dict( self._model, - StateDictType.FULL_STATE_DICT, - FullStateDictConfig(offload_to_cpu=True, rank0_only=True), - FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), - ): - cpu_state_dict = self._model.state_dict() - opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) + self._is_rank_zero, + ) + + if intermediate_checkpoint: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + ) + else: + opt_state_dict = None # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file if self._is_rank_zero: - checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict}) + checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) - # if training is in-progress, checkpoint the optimizer state as well - if epoch + 1 < self.total_epochs: + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well. 
+ if intermediate_checkpoint: checkpoint_dict.update( { - utils.OPT_KEY: opt_state_dict, - utils.SEED_KEY: self.seed, - utils.EPOCHS_KEY: self.epochs_run, - utils.TOTAL_EPOCHS_KEY: self.total_epochs, - utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) self._checkpointer.save_checkpoint( checkpoint_dict, epoch=epoch, - intermediate_checkpoint=(epoch + 1 < self.total_epochs), + intermediate_checkpoint=intermediate_checkpoint, ) def train(self) -> None: """ - The core training loop. Supports training on subsets of the dataset using the - ``max_steps_per_epoch``. + The core training loop. """ # clean up before training begins - utils.cleanup_before_training() + training.cleanup_before_training() - _, rank = utils.get_world_size_and_rank() + _, rank = training.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() @@ -515,6 +616,7 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 + self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -531,6 +633,15 @@ def train(self) -> None: ): break + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] # Get the attention mask and position ids from the dataset if they @@ -545,7 +656,7 @@ def train(self) -> None: "Step 0: Disabling fake quant, will re-enable in step %s" % self._fake_quant_after_n_steps ) - disable_fq = utils.quantization._get_disable_fake_quant( + disable_fq = training.quantization._get_disable_fake_quant( self._quantizer_mode ) self._model.apply(disable_fq) @@ -554,7 +665,7 @@ def train(self) -> None: "Step %s: Enabling fake quant" % self._fake_quant_after_n_steps ) - enable_fq = utils.quantization._get_enable_fake_quant( + enable_fq = training.quantization._get_enable_fake_quant( self._quantizer_mode ) self._model.apply(enable_fq) @@ -568,12 +679,21 @@ def train(self) -> None: ) logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. 
+ labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory + del logits loss = loss / self._gradient_accumulation_steps running_loss += loss @@ -590,7 +710,7 @@ def train(self) -> None: loss_to_log = running_loss.item() pbar.update(1) pbar.set_description( - f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -605,7 +725,9 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if self._log_peak_memory_stats: - log_dict.update(utils.get_memory_stats(device=self._device)) + log_dict.update( + training.get_memory_stats(device=self._device) + ) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -616,13 +738,32 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step profiler + # Note that this is called within gradient accumulation block, hence + # will include multiple forward / backward passes if gradient accumulation > 1 + self._profiler.step() + self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) + self._profiler.stop() + def cleanup(self) -> None: if self._is_rank_zero: self._metric_logger.close() - torch.distributed.destroy_process_group() + destroy_process_group() @config.parse @@ -634,17 +775,16 @@ def recipe_main(cfg: DictConfig) -> None: - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ - if not utils.is_distributed(): + if not training.is_distributed(): raise RuntimeError( "Distributed QAT recipe should be run via a distributed launcher." "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) - init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") if cfg.get("fsdp_cpu_offload", False): # Utilize all available CPU cores for intra-op parallelism. 
This provides ~2x # speed up when benchmarking fused AdamW on CPU - utils.set_torch_num_threads() + training.set_torch_num_threads() config.log_config(recipe_name="QATRecipeDistributed", cfg=cfg) diff --git a/benchmate/benchmate/monitor.py b/benchmate/benchmate/monitor.py index 0ad34a3d3..a7f1dd0f3 100644 --- a/benchmate/benchmate/monitor.py +++ b/benchmate/benchmate/monitor.py @@ -126,7 +126,6 @@ def monogpu_monitor(*args, **kwargs): yield log - @contextmanager def bench_monitor(*args, **kwargs): if int(os.getenv("RANK", -1)) == -1: From 6caac29e4c75e9fff050969194672e77db1187bd Mon Sep 17 00:00:00 2001 From: Satya Ortiz-Gagne Date: Mon, 23 Sep 2024 16:49:17 -0400 Subject: [PATCH 3/7] Fix rlhf on trl v0.11.0 --- benchmarks/rlhf/main.py | 4 ++-- benchmarks/rlhf/prepare.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/rlhf/main.py b/benchmarks/rlhf/main.py index 0be12d282..a46b9579b 100755 --- a/benchmarks/rlhf/main.py +++ b/benchmarks/rlhf/main.py @@ -13,7 +13,7 @@ from trl import ModelConfig from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer -from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE class PPOv2TrainerIntrumented(PPOv2Trainer): @@ -62,7 +62,7 @@ def main(): ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE value_model = AutoModelForSequenceClassification.from_pretrained( config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 ) diff --git a/benchmarks/rlhf/prepare.py b/benchmarks/rlhf/prepare.py index 4c9aa631f..5e4cb4eba 100755 --- a/benchmarks/rlhf/prepare.py +++ b/benchmarks/rlhf/prepare.py @@ -11,7 +11,7 @@ from datasets import load_dataset from trl import ModelConfig from trl.trainer.ppov2_trainer import PPOv2Config -from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE if __name__ == "__main__": @@ -30,7 +30,7 @@ tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE value_model = AutoModelForSequenceClassification.from_pretrained( config.reward_model_path, From 6a1c120d2af1f2eb9b3904f2ea9a3a14e33f900c Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Wed, 25 Sep 2024 21:51:34 -0400 Subject: [PATCH 4/7] Add missing monitor tag --- config/base.yaml | 2 ++ milabench/_version.py | 6 +++--- milabench/cli/compare.py | 7 +++++-- milabench/compare.py | 14 ++++++++++++-- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/config/base.yaml b/config/base.yaml index d6a1dc5fb..d7926799f 100644 --- a/config/base.yaml +++ b/config/base.yaml @@ -881,6 +881,8 @@ cleanrljax: inherits: _defaults install_group: torch definition: ../benchmarks/cleanrl_jax + tags: + - monogpu plan: method: per_gpu diff --git a/milabench/_version.py b/milabench/_version.py index 4b49d0506..4da614fc7 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,5 @@ """This file is generated, do not modify""" -__tag__ = "v0.1.0-113-g9a5dfe3e" -__commit__ = "9a5dfe3ef36e6baab6584faa3fa939e63ba2aed5" -__date__ = "2024-09-16 09:08:28 -0400" +__tag__ = "v1.0.0_RC1-6-g4639c19d" +__commit__ = "4639c19d34c8b15349e2dd265a493374c84ed3aa" +__date__ = "2024-09-25 15:22:55 +0000" diff --git a/milabench/cli/compare.py 
b/milabench/cli/compare.py index b2992857c..83f0c59ce 100644 --- a/milabench/cli/compare.py +++ b/milabench/cli/compare.py @@ -15,6 +15,7 @@ class Arguments: last : int = None metric : str = "train_rate" stat : str = "median" + filter : str = None # fmt: on @@ -23,13 +24,15 @@ def arguments(): # [positional: ?] folder: Option = None + filter: Option & str = None + last: Option & int = None metric: Option & str = "train_rate" stat: Option & str = "median" - return Arguments(folder, last, metric, stat) + return Arguments(folder, last, metric, stat, filter) @tooled @@ -66,7 +69,7 @@ def cli_compare(args=None): if base is not None: args.folder = os.path.join(base, "runs") - runs = fetch_runs(args.folder) + runs = fetch_runs(args.folder, args.filter) for run in runs: all_data = _read_reports(run.path) diff --git a/milabench/compare.py b/milabench/compare.py index e3b88b10c..cae068203 100644 --- a/milabench/compare.py +++ b/milabench/compare.py @@ -21,14 +21,22 @@ def retrieve_datetime_from_name(date): pass -def fetch_runs(folder): +def fetch_runs(folder, filter): + import fnmatch + runs = [] + ignored = 0 for run in os.listdir(folder): + if filter is not None and (not fnmatch.fnmatch(run, filter)): + ignored += 1 + continue + pth = os.path.join(folder, run) if not os.path.isdir(pth): continue if "." in run: - name, date = run.split(".", maxsplit=1) + name, fractional_seconds = run.rsplit(".", maxsplit=1) + name, date = name.rsplit(".", maxsplit=1) date = retrieve_datetime_from_name(date) else: name = run @@ -39,6 +47,8 @@ def fetch_runs(folder): out = _Output(pth, name, date) runs.append(out) + if ignored > 0: + print(f"Ignoring run {ignored} runs because of filter {filter}") runs.sort(key=lambda out: out.date) return runs From 770ca62e578875c4ccfe146e91dc65e0de0948df Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Mon, 30 Sep 2024 10:17:36 -0400 Subject: [PATCH 5/7] Revert "Fix llm with torchtune v0.3" This reverts commit fd99d8ab642519da5d4ff223c7a2de535f4d41bc. 
--- benchmarks/llm/configs/llama3_70B_full.yaml | 11 +- benchmarks/llm/configs/llama3_70B_lora.yaml | 10 +- benchmarks/llm/configs/llama3_8B_lora.yaml | 9 +- .../configs/llama3_8B_lora_single_device.yaml | 11 +- .../llm/configs/llama3_8B_qat_full.yaml | 9 +- .../llama3_8B_qlora_single_device.yaml | 14 +- .../llm/recipes/full_finetune_distributed.py | 447 ++++++----------- .../recipes/full_finetune_single_device.py | 269 +++-------- benchmarks/llm/recipes/generate.py | 24 +- .../llm/recipes/lora_finetune_distributed.py | 391 ++++++--------- .../recipes/lora_finetune_single_device.py | 261 ++++------ .../ppo_full_finetune_single_device.py | 134 +++--- benchmarks/llm/recipes/qat_distributed.py | 448 ++++++------------ benchmate/benchmate/monitor.py | 1 + 14 files changed, 711 insertions(+), 1328 deletions(-) diff --git a/benchmarks/llm/configs/llama3_70B_full.yaml b/benchmarks/llm/configs/llama3_70B_full.yaml index 2cfe8ec92..ae5cf2afb 100644 --- a/benchmarks/llm/configs/llama3_70B_full.yaml +++ b/benchmarks/llm/configs/llama3_70B_full.yaml @@ -20,7 +20,6 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model - max_seq_len: null # Dataset dataset: @@ -34,7 +33,7 @@ model: safetensors: true checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer + _component_: torchtune.utils.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ model-00001-of-00030.safetensors, @@ -86,7 +85,7 @@ optimizer: fused: True loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + _component_: torch.nn.CrossEntropyLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 @@ -96,7 +95,7 @@ device: cuda # Memory management enable_activation_checkpointing: True -custom_sharded_layers: ['tok_embeddings', 'output'] +memory_efficient_fsdp_wrap: True fsdp_cpu_offload: True # Reduced precision @@ -104,8 +103,8 @@ dtype: bf16 # Logging metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} -output_dir: /tmp/full-llama3_1-finetune +output_dir: /tmp/alpaca-llama3-finetune log_every_n_steps: 1 log_peak_memory_stats: False diff --git a/benchmarks/llm/configs/llama3_70B_lora.yaml b/benchmarks/llm/configs/llama3_70B_lora.yaml index 9a8f1680a..7821e174a 100644 --- a/benchmarks/llm/configs/llama3_70B_lora.yaml +++ b/benchmarks/llm/configs/llama3_70B_lora.yaml @@ -16,16 +16,14 @@ model: apply_lora_to_output: False lora_rank: 16 lora_alpha: 32 - lora_dropout: 0.0 tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model - max_seq_len: null safetensors: true checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer + _component_: torchtune.utils.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ checkpoint_files: [ model-00001-of-00030.safetensors, @@ -82,7 +80,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + _component_: torch.nn.CrossEntropyLoss # Training epochs: 1 @@ -90,9 +88,9 @@ max_steps_per_epoch: null gradient_accumulation_steps: 1 # Logging -output_dir: /tmp/lora-llama3_1-finetune-output +output_dir: /tmp/lora_finetune_output metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 
log_peak_memory_stats: False diff --git a/benchmarks/llm/configs/llama3_8B_lora.yaml b/benchmarks/llm/configs/llama3_8B_lora.yaml index f499b712c..7bae8d036 100644 --- a/benchmarks/llm/configs/llama3_8B_lora.yaml +++ b/benchmarks/llm/configs/llama3_8B_lora.yaml @@ -21,7 +21,6 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model - max_seq_len: null # Model Arguments model: @@ -31,10 +30,9 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 - lora_dropout: 0.0 checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer + _component_: torchtune.utils.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ checkpoint_files: [ consolidated.00.pth @@ -43,7 +41,6 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False -save_adapter_weights_only: False # Dataset and Sampler dataset: @@ -62,7 +59,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + _component_: torch.nn.CrossEntropyLoss # Training epochs: 1 @@ -72,7 +69,7 @@ gradient_accumulation_steps: 32 # Logging output_dir: /tmp/lora_finetune_output metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False diff --git a/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml b/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml index f5d8e3efa..b341f5afc 100644 --- a/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml +++ b/benchmarks/llm/configs/llama3_8B_lora_single_device.yaml @@ -24,16 +24,14 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 - lora_dropout: 0.0 # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model - max_seq_len: null checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer + _component_: torchtune.utils.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ checkpoint_files: [ consolidated.00.pth @@ -42,7 +40,6 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False -save_adapter_weights_only: False # Dataset and Sampler dataset: @@ -61,7 +58,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + _component_: torch.nn.CrossEntropyLoss # Training epochs: 1 @@ -72,7 +69,7 @@ compile: False # Logging output_dir: /tmp/lora_finetune_output metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False @@ -84,7 +81,7 @@ enable_activation_checkpointing: True # Profiler (disabled) profiler: - _component_: torchtune.training.setup_torch_profiler + _component_: torchtune.utils.setup_torch_profiler enabled: False #Output directory of trace artifacts diff --git a/benchmarks/llm/configs/llama3_8B_qat_full.yaml b/benchmarks/llm/configs/llama3_8B_qat_full.yaml index c9d99f98a..23f60f779 100644 --- a/benchmarks/llm/configs/llama3_8B_qat_full.yaml +++ b/benchmarks/llm/configs/llama3_8B_qat_full.yaml @@ -17,7 +17,6 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model - max_seq_len: null # 
Dataset dataset: @@ -30,7 +29,7 @@ model: _component_: torchtune.models.llama3_1.llama3_1_8b checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer + _component_: torchtune.utils.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ checkpoint_files: [ consolidated.00.pth @@ -46,7 +45,7 @@ epochs: 3 # QAT arguments quantizer: - _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer + _component_: torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer groupsize: 256 optimizer: @@ -55,7 +54,7 @@ optimizer: foreach: False loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + _component_: torch.nn.CrossEntropyLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 @@ -71,7 +70,7 @@ dtype: bf16 # Logging metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} output_dir: /tmp/alpaca-llama3-finetune log_every_n_steps: 1 diff --git a/benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml b/benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml index 1f50aa9d4..fc30f458f 100644 --- a/benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml +++ b/benchmarks/llm/configs/llama3_8B_qlora_single_device.yaml @@ -23,16 +23,14 @@ model: apply_lora_to_output: False lora_rank: 8 lora_alpha: 16 - lora_dropout: 0.0 # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model - max_seq_len: null checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer + _component_: torchtune.utils.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3-8B-Instruct/original/ checkpoint_files: [ consolidated.00.pth @@ -41,7 +39,6 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False -save_adapter_weights_only: False # Dataset and Sampler dataset: @@ -60,7 +57,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + _component_: torch.nn.CrossEntropyLoss # Training epochs: 1 @@ -71,7 +68,7 @@ compile: False # Logging output_dir: /tmp/qlora_finetune_output/ metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger + _component_: torchtune.utils.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False @@ -83,7 +80,7 @@ enable_activation_checkpointing: True # Profiler (disabled) profiler: - _component_: torchtune.training.setup_torch_profiler + _component_: torchtune.utils.setup_torch_profiler enabled: False #Output directory of trace artifacts @@ -105,6 +102,3 @@ profiler: warmup_steps: 5 active_steps: 2 num_cycles: 1 - -# For colab use True -low_cpu_ram: False diff --git a/benchmarks/llm/recipes/full_finetune_distributed.py b/benchmarks/llm/recipes/full_finetune_distributed.py index a46ff0a91..3a51842da 100755 --- a/benchmarks/llm/recipes/full_finetune_distributed.py +++ b/benchmarks/llm/recipes/full_finetune_distributed.py @@ -10,26 +10,32 @@ import time from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple from warnings import warn import torch from omegaconf import DictConfig, ListConfig from torch import nn -from torch.distributed import destroy_process_group, init_process_group - +from torch.distributed import init_process_group +from torch.distributed.fsdp import ( + CPUOffload, + 
FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_packed, padded_collate_sft + +from torchtune import config, modules, utils from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY -from torchtune.training.activations import apply_selective_activation_checkpointing +from torchtune.utils.activations import apply_selective_activation_checkpointing from tqdm import tqdm + log = utils.get_logger("DEBUG") @@ -39,11 +45,8 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): distributed training and can be run on a single node (1 to 8 GPUs). Features: - - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states - is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is - done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config - ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). - DDP is currently not supported. Training on CPU is not supported. + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU + is not supported. - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep @@ -91,12 +94,12 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): Raises: ValueError: If ``dtype`` is set to fp16. - RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. """ def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) - self._dtype = training.get_dtype(cfg.dtype, device=self._device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) if self._dtype == torch.float16: raise ValueError( @@ -119,7 +122,7 @@ def __init__(self, cfg: DictConfig) -> None: # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 # Training cfg @@ -128,7 +131,7 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = training.set_seed(seed=cfg.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch @@ -154,28 +157,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. 
""" try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[training.SEED_KEY]: + if self.seed != ckpt_dict[utils.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" ) ) - self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -191,8 +194,8 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: def setup(self, cfg: DictConfig) -> None: """ - Setup the recipe. This includes training state (if resume_from_checkpoint is True), - model, tokenizer, loss, optimizer, sampler, and dataloader. + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. """ if self._is_rank_zero: self._metric_logger = config.instantiate(cfg.metric_logger) @@ -200,41 +203,34 @@ def setup(self, cfg: DictConfig) -> None: # log config with parameter override self._metric_logger.log_config(cfg) - checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + ckpt_dict = self.load_checkpoint(cfg.checkpointer) - self._compile = cfg.get("compile", False) + # ``_setup_model`` handles initialization and loading the state dict. This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - custom_sharded_layers=cfg.get("custom_sharded_layers", None), + memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), - reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), - model_state_dict=checkpoint_dict[training.MODEL_KEY], + model_state_dict=ckpt_dict[utils.MODEL_KEY], ac_mode=cfg.get("ac_mode", None), ac_option=cfg.get("ac_option", None), ) + self._tokenizer = config.instantiate(cfg.tokenizer) + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. 
Transforming the opt state dict is handled by this method self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=checkpoint_dict[training.OPT_KEY] + opt_state_dict=ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None, ) - # initialize loss self._loss_fn = config.instantiate(cfg.loss) - if self._compile: - training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) - - if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": - # set num_output_chunks for model - self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - - if self._is_rank_zero: - log.info("Loss is initialized.") - # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized self._sampler, self._dataloader = self._setup_data( @@ -260,109 +256,49 @@ def setup(self, cfg: DictConfig) -> None: self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch - # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) - # if cfg is missing profiler key or if `cfg.profiler.enabled = False` - self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) - - # Used to ignore labels for loss computation - self.ignore_labels_cache = torch.full( - (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device - ) - - def _setup_profiler( - self, cfg_profiler: Optional[DictConfig] = None - ) -> Union[torch.profiler.profile, DummyProfiler]: - """ - Parses the `profiler` section of top-level `cfg` and sets up profiler - - Args: - cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to - `recipe.main`). Default None. - - Returns: - profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods - for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such - that the instrumented training loop does not need to be changed profiling is disabled. - - The profiler config can be provided in configs under the `profiler` key with the following layout: - - .. 
code-block:: yaml - profiler: - enabled: bool - - #Output directory of trace artifacts - output_dir: str - - #`torch.profiler.ProfilerActivity` types to trace - cpu: bool - cuda: bool - - #Trace options - profile_memory: bool - with_stack: bool - record_shapes: bool - with_flops: bool - - # `torch.profiler.schedule` options: - # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: int - warmup_steps: int - active_steps: int - num_cycles: int - """ - # Missing profiler section in config, assume disabled - if cfg_profiler is None: - cfg_profiler = DictConfig({"enabled": False}) - - # Check that component is included and set correctly - if cfg_profiler.get("_component_", None) is None: - cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" - else: - assert ( - cfg_profiler.get("_component_") - == "torchtune.training.setup_torch_profiler" - ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" - - profiler, profiler_cfg = config.instantiate(cfg_profiler) - - if self._is_rank_zero: - log.info(f" Profiler config after instantiation: {profiler_cfg}") - - self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) - if profiler_cfg["enabled"]: - self.profiler_wait_steps = profiler_cfg["wait_steps"] - self.profiler_warmup_steps = profiler_cfg["warmup_steps"] - self.profiler_active_steps = profiler_cfg["active_steps"] - - return profiler - def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, - custom_sharded_layers: Optional[List[str]], + memory_efficient_fsdp_wrap: bool, fsdp_cpu_offload: bool, - reshard_after_forward: bool, model_state_dict: Dict[str, Any], ac_mode: Optional[str] = None, ac_option: Optional[int] = None, ) -> nn.Module: """ Model initialization has some important considerations: - a. To minimize GPU peak memory, we initialize the model on meta device with - the right dtype - b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since - full state dicts are loaded with ``torch.load(mmap=True)`` + a. To minimize GPU peak memory, we load the model on CPU with the right + dtype. To ensure that we don't instantiate ``world_size`` number of models, + we initialize on meta_device for all ranks other than rank 0. + b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the + model weights from checkpoint. + c. While wrapping the model with FSDP, we set ``sync_module_states`` + to TRUE and broadcast module params and buffers from rank 0. + d. The ``device_id`` param ensures that the FSDP initialization happens on + the correct device. """ - if self._is_rank_zero: + log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") + init_start = time.perf_counter() + + with utils.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) + log.info( - "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." + f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" ) - init_start = time.perf_counter() - with training.set_default_dtype(self._dtype), torch.device("meta"): - model = config.instantiate(cfg_model) + # Load both the model weights. 
This should happen only on Rank 0 + model.load_state_dict(model_state_dict) + + else: + # For non-zero ranks, load the model on meta device + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._dtype == torch.bfloat16: + model = model.to(torch.bfloat16) # We currently have two versions of activation checkpointing in this recipe # for testing and BC purposes. ``enable_activation_checkpointing`` controls @@ -370,6 +306,9 @@ def _setup_model( # ac_mode and ac_option together control selective AC. This is only enabled # when these are set AND ``enable_activation_checkpointing`` is set to False # We'll clean this up as soon as testing of AC is complete + ac_mode = ac_mode + ac_option = ac_option + if (not enable_activation_checkpointing) and (ac_mode is not None): apply_selective_activation_checkpointing( model, @@ -377,68 +316,43 @@ def _setup_model( ac_option, ) - # original activation checkpointing (full) - flip the condition above - if enable_activation_checkpointing and ac_mode is None: - training.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} - ) - - # For FSDP sharding, we can condition on either the module or its name - # Shard conditions should be callables taking name (relative to model root) - # and the module itself and returning a bool on whether to shard the given module - fsdp_shard_conditions = [] - - # Shard transformer decoder layers (or AC-wrapped versions) - # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) - # But directly using the name is more concise - def _is_layer_fqn(s: str) -> bool: - """ - Return True for layers.i and False for all other module names - Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot - """ - s_list = s.split(".") - return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1]) - - fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)] - - # If wrapping any layers separately, we can add another shard condition - # A layer will be sharded if any of the fsdp_shard_conditions are met - if custom_sharded_layers: - fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers] - - training.shard_model( - model=model, - shard_conditions=fsdp_shard_conditions, - cpu_offload=fsdp_cpu_offload, - reshard_after_forward=reshard_after_forward, - ) - - with training.set_default_dtype(self._dtype), self._device: - for m in model.modules(): - # RoPE is not covered in state dict - if hasattr(m, "rope_init"): - m.rope_init() - - # This method will convert the full model state dict into a sharded state - # dict and load into the model - training.load_from_full_model_state_dict( - model, - model_state_dict, - self._device, - self._is_rank_zero, - strict=True, - cpu_offload=fsdp_cpu_offload, + # Wrap the model with FSDP. This will ensure that the model is sharded + # across all available GPUs. 
+ model = FSDP( + module=model, + auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy( + memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap, + modules_to_wrap={modules.TransformerDecoderLayer}, + ), + cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload), + sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, + device_id=self._device, + # this recipe does not currently support mixed precision training + mixed_precision=None, + # Ensure we broadcast params and buffers from rank 0 + sync_module_states=True, + # Initialize empty modules on all non-zero ranks + param_init_fn=( + lambda module: module.to_empty( + device=torch.device("cuda"), recurse=False + ) + if not self._is_rank_zero + else None + ), ) # Ensure no params and buffers are on meta device - training.validate_no_params_on_meta_device(model) + utils.validate_no_params_on_meta_device(model) - if self._is_rank_zero: - log.info( - f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} ) - memory_stats = training.get_memory_stats(device=self._device) - training.log_memory_stats(memory_stats) + + if self._is_rank_zero: + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) # synchronize before training begins torch.distributed.barrier() @@ -448,13 +362,17 @@ def _is_layer_fqn(s: str) -> bool: def _setup_optimizer( self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None ) -> Optimizer: + """ + Set up the optimizer. This method also handles transforing the state dict + for FSDP. + """ optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: - training.load_from_full_optimizer_state_dict( - optimizer, - opt_state_dict, - self._device, + opt_state_dict = FSDP.optim_state_dict_to_load( + self._model, optimizer, opt_state_dict ) + optimizer.load_state_dict(opt_state_dict) if self._is_rank_zero: log.info("Optimizer is initialized.") @@ -471,7 +389,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -485,21 +403,23 @@ def _setup_data( packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( - ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ds, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=0, ) dataloader = DataLoader( dataset=ds, batch_size=batch_size, sampler=sampler, collate_fn=partial( - padded_collate_sft, + utils.padded_collate, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else partial( - padded_collate_packed, - ), + else None, ) if self._is_rank_zero: @@ -507,74 +427,57 @@ def _setup_data( return sampler, dataloader - def save_checkpoint( - self, - epoch: int, - ) -> None: + def save_checkpoint(self, epoch: int) -> None: """ - Checkpoint the state of the recipe. 
The constructed checkpoint state dict - contains the following information: - - Model weights with key training.MODEL_KEY - - Relevant recipe state if training is not complete - - Checkpointer will save the model weights and recipe state in - different checkpoint files. To correctly resume training from an intermediate checkpoint, - the model weights and recipe state must be provided. + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. """ - # final dict passed onto the checkpointer checkpoint_dict = {} - intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( + with FSDP.state_dict_type( self._model, - self._is_rank_zero, - device=self._device, - ) - - if intermediate_checkpoint: - opt_state_dict = training.get_full_optimizer_state_dict( - self._optimizer, - self._is_rank_zero, - device=self._device, - ) - else: - opt_state_dict = None + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state_dict = self._model.state_dict() + opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file if self._is_rank_zero: - checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) + checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict}) - # if training is in-progress, checkpoint the optimizer state and recipe state - # as well. - if intermediate_checkpoint: + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: checkpoint_dict.update( { - training.OPT_KEY: opt_state_dict, - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self.epochs_run, - training.TOTAL_EPOCHS_KEY: self.total_epochs, - training.MAX_STEPS_KEY: self.max_steps_per_epoch, + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) self._checkpointer.save_checkpoint( checkpoint_dict, epoch=epoch, - intermediate_checkpoint=intermediate_checkpoint, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), ) def train(self) -> None: """ - The core training loop. + The core training loop. Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. 
""" # clean up before training begins - training.cleanup_before_training() + utils.cleanup_before_training() - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() @@ -584,7 +487,6 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 - self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -601,15 +503,6 @@ def train(self) -> None: ): break - # Start tracking CUDA memory for active steps for just the first epoch - if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and idx == self.profiler_wait_steps + self.profiler_warmup_steps - ): - torch.cuda.memory._record_memory_history() - # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] # Get the attention mask and position ids from the dataset if they @@ -626,25 +519,13 @@ def train(self) -> None: ) logits = self._model(tokens, mask=mask, input_pos=input_pos) - - # Shift labels to compute loss - # equivalent to doing labels[..., 1:] and logits[..., :-1, :] - # But this way we dont need to slice the logits. We just add an ignore index to labels. - labels = torch.hstack( - (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) - ) - if not isinstance(logits, list): - labels = labels.reshape(-1) - logits = logits.reshape(-1, logits.size(-1)) - + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) # Compute loss loss = self._loss_fn(logits, labels) - - # free logits otherwise it peaks backward memory - del logits - - loss = loss / self._gradient_accumulation_steps - running_loss += loss + running_loss += loss / self._gradient_accumulation_steps loss.backward() # Step with optimizer @@ -659,7 +540,7 @@ def train(self) -> None: self.log_loss(loss_to_log) pbar.update(1) pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -674,9 +555,7 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if self._log_peak_memory_stats: - log_dict.update( - training.get_memory_stats(device=self._device) - ) + log_dict.update(utils.get_memory_stats(device=self._device)) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -687,38 +566,18 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() - # Stop tracking CUDA memory now that active steps are complete - if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and idx - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - ): - torch.cuda.memory._record_memory_history(enabled=None) - - # Step profiler - # Note that this is called within gradient accumulation block, hence - # will include multiple forward / backward passes if gradient accumulation > 1 - self._profiler.step() - self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) - self._profiler.stop() - def cleanup(self) -> None: if self._is_rank_zero: self._metric_logger.close() - destroy_process_group() - + torch.distributed.destroy_process_group() def log_loss(self, loss): pass -def prepare_voir(recipe:FullFinetuneRecipeDistributed): +def prepare_voir(recipe): from benchmate.observer import BenchObserver from benchmate.monitor 
import bench_monitor @@ -743,6 +602,7 @@ def on_loss(loss): return observer, bench_monitor + @config.parse def recipe_main(cfg: DictConfig) -> None: """ @@ -752,22 +612,22 @@ def recipe_main(cfg: DictConfig) -> None: - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ - if not training.is_distributed(): + if not utils.is_distributed(): raise RuntimeError( "Distributed finetune recipe should be run via a distributed launcher." "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") if cfg.get("fsdp_cpu_offload", False): # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU - training.set_torch_num_threads() + utils.set_torch_num_threads() config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg) recipe = FullFinetuneRecipeDistributed(cfg=cfg) recipe.setup(cfg=cfg) - from voir.phase import StopProgram try: _, monitor = prepare_voir(recipe) @@ -775,7 +635,6 @@ def recipe_main(cfg: DictConfig) -> None: recipe.train() except StopProgram: print("early stopping") - recipe.cleanup() diff --git a/benchmarks/llm/recipes/full_finetune_single_device.py b/benchmarks/llm/recipes/full_finetune_single_device.py index f4d0df7cf..98322579f 100755 --- a/benchmarks/llm/recipes/full_finetune_single_device.py +++ b/benchmarks/llm/recipes/full_finetune_single_device.py @@ -6,10 +6,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import sys import time from functools import partial -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple from warnings import warn import torch @@ -19,11 +20,9 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_packed, padded_collate_sft +from torchtune import config, modules, utils from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -85,10 +84,6 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): - Logging. Terminal, Disk, WandB and TensorBoard are all supported. - - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, - ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set - ``clip_grad_norm='inf'``. - For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config has example commands for how to kick-off training. @@ -103,7 +98,7 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) - self._dtype = training.get_dtype(cfg.dtype, device=self._device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor # enabled necessary features such as gradient scaling. 
if self._dtype == torch.float16: @@ -131,12 +126,11 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = training.set_seed(seed=cfg.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self._clip_grad_norm = cfg.get("clip_grad_norm", None) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -158,28 +152,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. """ try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[training.SEED_KEY]: + if self.seed != ckpt_dict[utils.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" ) ) - self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -208,12 +202,12 @@ def setup(self, cfg: DictConfig) -> None: # ``_setup_model`` handles initialization and loading the state dict. 
This method # should be called before ``_setup_optimizer`` since transforming the optimizer # state dict requires the model - self._compile = cfg.compile + self._model_compile = cfg.compile self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - compile_model=self._compile, - model_state_dict=ckpt_dict[training.MODEL_KEY], + compile_model=self._model_compile, + model_state_dict=ckpt_dict[utils.MODEL_KEY], ) self._tokenizer = config.instantiate(cfg.tokenizer) log.info("Tokenizer is initialized from file.") @@ -224,20 +218,11 @@ def setup(self, cfg: DictConfig) -> None: cfg_optimizer=cfg.optimizer, optimizer_in_bwd=cfg.optimizer_in_bwd, opt_state_dict=( - ckpt_dict[training.OPT_KEY] if self._resume_from_checkpoint else None + ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None ), ) - # initialize loss self._loss_fn = config.instantiate(cfg.loss) - - if self._compile: - training.compile_loss(self._loss_fn) - - if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": - # set num_output_chunks for model - self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be @@ -265,82 +250,6 @@ def setup(self, cfg: DictConfig) -> None: self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch - # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) - # if cfg is missing profiler key or if `cfg.profiler.enabled = False` - self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) - - # Used to ignore labels for loss computation - self.ignore_labels_cache = torch.full( - (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device - ) - - def _setup_profiler( - self, cfg_profiler: Optional[DictConfig] = None - ) -> Union[torch.profiler.profile, DummyProfiler]: - """ - Parses the `profiler` section of top-level `cfg` and sets up profiler - - Args: - cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to - `recipe.main`). Default None. - - Returns: - profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods - for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such - that the instrumented training loop does not need to be changed profiling is disabled. - - The profiler config can be provided in configs under the `profiler` key with the following layout: - - .. 
code-block:: yaml - profiler: - enabled: bool - - #Output directory of trace artifacts - output_dir: str - - #`torch.profiler.ProfilerActivity` types to trace - cpu: bool - cuda: bool - - #Trace options - profile_memory: bool - with_stack: bool - record_shapes: bool - with_flops: bool - - # `torch.profiler.schedule` options: - # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: int - warmup_steps: int - active_steps: int - num_cycles: int - """ - - # Missing profiler section in config, assume disabled - if cfg_profiler is None: - cfg_profiler = DictConfig({"enabled": False}) - - # Check that component is included and set correctly - if cfg_profiler.get("_component_", None) is None: - cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" - else: - assert ( - cfg_profiler.get("_component_") - == "torchtune.training.setup_torch_profiler" - ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" - - profiler, profiler_cfg = config.instantiate(cfg_profiler) - - log.info(f" Profiler config after instantiation: {profiler_cfg}") - - self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) - if profiler_cfg["enabled"]: - self.profiler_wait_steps = profiler_cfg["wait_steps"] - self.profiler_warmup_steps = profiler_cfg["warmup_steps"] - self.profiler_active_steps = profiler_cfg["active_steps"] - - return profiler - def _setup_model( self, cfg_model: DictConfig, @@ -351,28 +260,28 @@ def _setup_model( """ Set up the model including enabling activation checkpointing. """ - with training.set_default_dtype(self._dtype), self._device: + with utils.set_default_dtype(self._dtype), self._device: model = config.instantiate(cfg_model) - if compile_model: - training.compile_model(model) - if enable_activation_checkpointing: - training.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} ) model.load_state_dict(model_state_dict) # Validate model was loaded in with the expected dtype. - training.validate_expected_param_dtype( - model.named_parameters(), dtype=self._dtype - ) + utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) log.info(f"Model is initialized with precision {self._dtype}.") + # Compile model, if enabled. + if compile_model: + log.info("Compiling model with torch.compile...") + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + model.compile(backend=backend) if self._device.type == "cuda": - memory_stats = training.get_memory_stats(device=self._device) - training.log_memory_stats(memory_stats) + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) return model @@ -392,11 +301,9 @@ def _setup_optimizer( for p in self._model.parameters() } # Register optimizer step hooks on the model to run optimizer in backward. - training.register_optim_in_bwd_hooks( - model=self._model, optim_dict=optim_dict - ) + utils.register_optim_in_bwd_hooks(model=self._model, optim_dict=optim_dict) # Create a wrapper for checkpoint save/load of optimizer states when running in backward. - self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( model=self._model, optim_dict=optim_dict ) # Load optimizer states. 
If optimizer states are being restored in an optimizer in backward @@ -433,13 +340,13 @@ def _setup_data( """ if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, self._tokenizer) + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, self._tokenizer) + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( @@ -454,14 +361,12 @@ def _setup_data( batch_size=batch_size, sampler=sampler, collate_fn=partial( - padded_collate_sft, + utils.padded_collate, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else partial( - padded_collate_packed, - ), + else None, ) log.info("Dataset and Sampler are initialized.") @@ -473,60 +378,33 @@ def save_checkpoint(self, epoch: int) -> None: Save state dict to file. The recipe save_checkpoint method is responsible for correctly creating the checkpoint dict and passing to the checkpointer. """ - ckpt_dict = {training.MODEL_KEY: self._model.state_dict()} + ckpt_dict = {utils.MODEL_KEY: self._model.state_dict()} # if training is in-progress, checkpoint the optimizer state as well if epoch + 1 < self.total_epochs: ckpt_dict.update( { - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self.epochs_run, - training.TOTAL_EPOCHS_KEY: self.total_epochs, - training.MAX_STEPS_KEY: self.max_steps_per_epoch, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) if not self._optimizer_in_bwd: - ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict() + ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() else: - ckpt_dict[training.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() + ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() self._checkpointer.save_checkpoint( ckpt_dict, epoch=epoch, intermediate_checkpoint=(epoch + 1 < self.total_epochs), ) - def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - # Both are shape [b, s] - tokens, labels = batch["tokens"], batch["labels"] - # Get the attention mask and position ids from the dataset if they - # exist. Currently, only sample packing in PackedDataset returns these - mask = batch.get("mask", None) # shape [b, s, s] - input_pos = batch.get("input_pos", None) # shape [b, s] - - logits = self._model(tokens, mask=mask, input_pos=input_pos) - - # Shift labels to compute loss - # equivalent to doing labels[..., 1:] and logits[..., :-1, :] - # But this way we dont need to slice the logits. We just add an ignore index to labels. - labels = torch.hstack( - (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) - ) - if not isinstance(logits, list): - labels = labels.reshape(-1) - logits = logits.reshape(-1, logits.size(-1)) - - # Compute loss - loss = self._loss_fn(logits, labels) - # free logits otherwise it peaks backward memory - del logits - - return loss - def train(self) -> None: """ The core training loop. Supports training on subsets of the dataset using the ``max_steps_per_epoch``. """ - if self._compile: + if self._model_compile: log.info( "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." 
) @@ -539,7 +417,6 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 - self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): # Update the sampler to ensure data is correctly shuffled across epochs @@ -555,29 +432,35 @@ def train(self) -> None: ): break - # Start tracking CUDA memory for active steps for just the first epoch - if ( - curr_epoch == 0 - and self.profiler_profile_memory - and idx == self.profiler_wait_steps + self.profiler_warmup_steps - ): - torch.cuda.memory._record_memory_history() + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) - batch = {k: v.to(self._device) for k, v in batch.items()} - num_tokens += batch["tokens"].numel() + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) - loss = self._loss_step(batch) loss = loss / self._gradient_accumulation_steps running_loss += loss loss.backward() # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: - if self._clip_grad_norm is not None: - grad_norm = torch.nn.utils.clip_grad_norm_( - self._model.parameters(), - max_norm=float(self._clip_grad_norm), - ) if not self._optimizer_in_bwd: self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) @@ -587,7 +470,7 @@ def train(self) -> None: loss_to_log = running_loss.item() pbar.update(1) pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -605,11 +488,7 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if self._device.type == "cuda" and self._log_peak_memory_stats: - log_dict.update( - training.get_memory_stats(device=self._device) - ) - if self._clip_grad_norm is not None: - log_dict.update({"grad_norm": grad_norm}) + log_dict.update(utils.get_memory_stats(device=self._device)) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -620,27 +499,9 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() - # Stop tracking CUDA memory now that active steps are complete - if ( - curr_epoch == 0 - and self.profiler_profile_memory - and idx - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - ): - torch.cuda.memory._record_memory_history(enabled=None) - - # Step the profiler - # Note we are stepping each batch, which might not include optimizer step in the trace - # if the schedule cycle doesn't align with gradient accumulation. 
- self._profiler.step() - self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) - self._profiler.stop() - def cleanup(self) -> None: self._metric_logger.close() diff --git a/benchmarks/llm/recipes/generate.py b/benchmarks/llm/recipes/generate.py index 7334d81b0..883a75444 100755 --- a/benchmarks/llm/recipes/generate.py +++ b/benchmarks/llm/recipes/generate.py @@ -14,7 +14,7 @@ from omegaconf import DictConfig from torch import nn -from torchtune import config, generation, training, utils +from torchtune import config, utils from torchtune.config._utils import _get_component_from_path from torchtune.data import ChatFormat, InstructTemplate, Message @@ -38,11 +38,11 @@ class InferenceRecipe: def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) - self._dtype = training.get_dtype(dtype=cfg.dtype, device=self._device) + self._dtype = utils.get_dtype(dtype=cfg.dtype) self._quantizer = config.instantiate(cfg.quantizer) - self._quantization_mode = training.get_quantizer_mode(self._quantizer) + self._quantization_mode = utils.get_quantizer_mode(self._quantizer) - training.set_seed(seed=cfg.seed) + utils.set_seed(seed=cfg.seed) def setup(self, cfg: DictConfig) -> None: checkpointer = config.instantiate(cfg.checkpointer) @@ -56,7 +56,7 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( model_cfg=cfg.model, - model_state_dict=ckpt_dict[training.MODEL_KEY], + model_state_dict=ckpt_dict[utils.MODEL_KEY], enable_kv_cache=cfg.enable_kv_cache, ) self._tokenizer = config.instantiate(cfg.tokenizer) @@ -67,7 +67,7 @@ def _setup_model( model_state_dict: Dict[str, Any], enable_kv_cache: bool = True, ) -> nn.Module: - with training.set_default_dtype(self._dtype), self._device: + with utils.set_default_dtype(self._dtype), self._device: model = config.instantiate(model_cfg) if self._quantization_mode is not None: @@ -77,9 +77,7 @@ def _setup_model( model.load_state_dict(model_state_dict) # Validate model was loaded in with the expected dtype. 
- training.validate_expected_param_dtype( - model.named_parameters(), dtype=self._dtype - ) + utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype) logger.info(f"Model is initialized with precision {self._dtype}.") # Ensure the cache is setup on the right device @@ -149,29 +147,31 @@ def generate(self, cfg: DictConfig): if self._quantization_mode is not None: logger.info("Starting compilation to improve generation performance ...") custom_generate_next_token = torch.compile( - generation.generate_next_token, mode="max-autotune", fullgraph=True + utils.generate_next_token, mode="max-autotune", fullgraph=True ) t0 = time.perf_counter() - _ = generation.generate( + _ = utils.generate( model=self._model, prompt=prompt, max_generated_tokens=2, temperature=cfg.temperature, top_k=cfg.top_k, stop_tokens=self._tokenizer.stop_tokens, + pad_id=self._tokenizer.pad_id, custom_generate_next_token=custom_generate_next_token, ) t = time.perf_counter() - t0 logger.info(f"Warmup run for quantized model takes: {t:.02f} sec") t0 = time.perf_counter() - generated_tokens = generation.generate( + generated_tokens = utils.generate( model=self._model, prompt=prompt, max_generated_tokens=cfg.max_new_tokens, temperature=cfg.temperature, top_k=cfg.top_k, stop_tokens=self._tokenizer.stop_tokens, + pad_id=self._tokenizer.pad_id, custom_generate_next_token=custom_generate_next_token, ) t = time.perf_counter() - t0 diff --git a/benchmarks/llm/recipes/lora_finetune_distributed.py b/benchmarks/llm/recipes/lora_finetune_distributed.py index fdea3871c..18b736fbf 100755 --- a/benchmarks/llm/recipes/lora_finetune_distributed.py +++ b/benchmarks/llm/recipes/lora_finetune_distributed.py @@ -6,6 +6,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import sys import time @@ -18,24 +19,25 @@ from torch import nn from torch.distributed import destroy_process_group, init_process_group - +from torch.distributed.fsdp import ( + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_packed, padded_collate_sft +from torchtune import config, modules, utils from torchtune.datasets import ConcatDataset -from torchtune.modules.peft import ( - DoRALinear, +from torchtune.modules.peft.peft_utils import ( get_adapter_params, get_lora_module_names, get_merged_lora_ckpt, - load_dora_magnitudes, - LoRALinear, set_trainable_params, - validate_missing_and_unexpected_for_lora, + validate_state_dict_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.utils import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -48,11 +50,8 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): distributed training and can be run on a single node (1 to 8 GPUs). Features: - - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states - is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is - done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config - ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). - DDP is currently not supported. 
Training on CPU is not supported. + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Traning on CPU is not + supported. - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep @@ -110,14 +109,14 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) - self._dtype = training.get_dtype(cfg.dtype, device=self._device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) if self._dtype == torch.float16: raise ValueError( "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." ) - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this @@ -133,13 +132,12 @@ def __init__(self, cfg: DictConfig) -> None: # These attributes constitute the recipe state and are updated by ``load_checkpoint`` # when ``resume_from_checkpoint`` is ``True`` - self.seed = training.set_seed(seed=cfg.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps @@ -159,7 +157,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: # and recipe state to be present. The keys should match up with what ``save_checkpoint`` # used to create these intermediate checkpoints if self._resume_from_checkpoint: - if training.ADAPTER_KEY not in checkpoint_dict: + if utils.ADAPTER_KEY not in checkpoint_dict: raise ValueError( "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." ) @@ -173,28 +171,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. 
""" try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[training.SEED_KEY]: + if self.seed != ckpt_dict[utils.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" ) ) - self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -220,16 +218,13 @@ def setup(self, cfg: DictConfig) -> None: self._metric_logger.log_config(cfg) checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - self._compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), - reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), - base_model_state_dict=checkpoint_dict[training.MODEL_KEY], + base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], lora_weights_state_dict=( - checkpoint_dict[training.ADAPTER_KEY] + checkpoint_dict[utils.ADAPTER_KEY] if self._resume_from_checkpoint else None ), @@ -238,25 +233,13 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=( - checkpoint_dict[training.OPT_KEY] - if self._resume_from_checkpoint - else None - ), + opt_state_dict=checkpoint_dict[utils.OPT_KEY] + if self._resume_from_checkpoint + else None, ) - # initialize loss self._loss_fn = config.instantiate(cfg.loss) - if self._compile: - training.compile_loss(self._loss_fn, verbose=self._is_rank_zero) - - if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": - # set num_output_chunks for model - self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - if self._is_rank_zero: - log.info("Loss is initialized.") - # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after all of these are setup self._sampler, self._dataloader = self._setup_data( @@ -294,20 +277,14 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False` self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) - # Used to ignore labels for loss computation - self.ignore_labels_cache = torch.full( - (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device - ) - def _setup_profiler( - self, cfg_profiler: Optional[DictConfig] = None + self, cfg_profiler: DictConfig ) -> Union[torch.profiler.profile, DummyProfiler]: """ Parses the `profiler` section of top-level `cfg` and sets up profiler Args: - cfg_profiler 
(Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to - `recipe.main`). Default None. + cfg_profiler: DictConfig - `profiler` section of the top-level `cfg` (the main config passed to `recipe.main`) Returns: profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods @@ -346,42 +323,38 @@ def _setup_profiler( # Check that component is included and set correctly if cfg_profiler.get("_component_", None) is None: - cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" else: assert ( cfg_profiler.get("_component_") - == "torchtune.training.setup_torch_profiler" - ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + == "torchtune.utils.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" profiler, profiler_cfg = config.instantiate(cfg_profiler) if self._is_rank_zero: log.info(f" Profiler config after instantiation: {profiler_cfg}") - self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) - if profiler_cfg["enabled"]: - self.profiler_wait_steps = profiler_cfg["wait_steps"] - self.profiler_warmup_steps = profiler_cfg["warmup_steps"] - self.profiler_active_steps = profiler_cfg["active_steps"] - return profiler def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, - fsdp_cpu_offload: bool, - reshard_after_forward: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, ) -> nn.Module: """ Model initialization has some important considerations: - a. To minimize GPU peak memory, we initialize the model on meta device with - the right dtype - b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since - full state dicts are loaded with ``torch.load(mmap=True)`` - c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module` + a. To minimize GPU peak memory, we load the model on CPU with the right + dtype. To ensure that we don't instantiate ``world_size`` number of models, + we initialize on meta_device for all ranks other than rank 0. + b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the + model weights from checkpoint. + c. While wrapping the model with FSDP, we set ``sync_module_states`` + to TRUE and broadcast module params and buffers from rank 0. + d. The ``device_id`` param ensures that the FSDP initialization happens on + the correct device. """ self._lora_rank = cfg_model.lora_rank @@ -391,110 +364,87 @@ def _setup_model( self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) if self._is_rank_zero: - log.info( - "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." - ) + log.info("FSDP is enabled. 
Instantiating Model on CPU for Rank 0 ...") init_start = time.perf_counter() - with training.set_default_dtype(self._dtype), torch.device("meta"): - model = config.instantiate(cfg_model) - - self.adapter_params = get_adapter_params(model) - set_trainable_params(model, self.adapter_params) - - if self._compile: - training.compile_model(model, verbose=self._is_rank_zero) + with utils.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) - if enable_activation_checkpointing: - training.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + log.info( + f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" ) - # For FSDP sharding, we can condition on either the module or its name - # Shard conditions should be callables taking name (relative to model root) - # and the module itself and returning a bool on whether to shard the given module - - # Shard transformer decoder layers (or AC-wrapped versions) - # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) - # But directly using the name is more concise - def _is_layer_name(name: str, module: nn.Module) -> bool: - """ - Return True for layers.i and False for all other module names - Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot - """ - name_list = name.split(".") - return ( - len(name_list) == 2 - and name_list[0] == "layers" - and str.isdigit(name_list[1]) + # The model contains LoRA params which won't have any matching keys in + # the state dict. As a result, we need to load with strict=False. + # Before loading the state dict, ensure the state dict keys for the base + # model and adapters (if available) match the keys in the full LoRA model + # This is a good sanity check to prevent silent errors + validate_state_dict_for_lora( + lora_attn_modules=cfg_model.lora_attn_modules, + apply_lora_to_mlp=cfg_model.apply_lora_to_mlp, + apply_lora_to_output=getattr(cfg_model, "apply_lora_to_output", False), + full_model_state_dict_keys=model.state_dict().keys(), + lora_state_dict_keys=( + lora_weights_state_dict.keys() + if lora_weights_state_dict is not None + else None + ), + base_model_state_dict_keys=base_model_state_dict.keys(), ) - training.shard_model( - model=model, - shard_conditions=[_is_layer_name], - cpu_offload=fsdp_cpu_offload, - reshard_after_forward=reshard_after_forward, - ) + # Load both the base model weights and (if available) the adapter weights. 
Both + # of this should happen only on Rank 0 + model.load_state_dict(base_model_state_dict, strict=False) + if lora_weights_state_dict: + model.load_state_dict(lora_weights_state_dict, strict=False) - if lora_weights_state_dict: - lora_missing, lora_unexpected = training.load_from_full_model_state_dict( - model, - lora_weights_state_dict, - self._device, - self._is_rank_zero, - cpu_offload=fsdp_cpu_offload, - ) else: - lora_missing, lora_unexpected = None, None + # For non-zero ranks, load the model on meta device + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) - # Initialize LoRA params and RoPE buffers - with training.set_default_dtype(self._dtype), self._device: - lora_device = "cpu" if fsdp_cpu_offload else self._device - for m in model.modules(): - if ( - isinstance(m, LoRALinear) or isinstance(m, DoRALinear) - ) and not lora_weights_state_dict: - # lora may not be covered in state dict - # if finetune for the 1st time - m.lora_a.to_empty(device=lora_device) - m.lora_b.to_empty(device=lora_device) - m.initialize_parameters() - # RoPE is not covered in state dict - if hasattr(m, "rope_init"): - m.rope_init() - - base_missing, base_unexpected = training.load_from_full_model_state_dict( - model, - base_model_state_dict, - self._device, - self._is_rank_zero, - cpu_offload=fsdp_cpu_offload, - ) - is_dora = False - for m in model.modules(): - if hasattr(m, "initialize_dora_magnitude"): - is_dora = True - m.initialize_dora_magnitude() - if is_dora: - load_dora_magnitudes(model) - validate_missing_and_unexpected_for_lora( - lora_attn_modules=self._lora_attn_modules, - apply_lora_to_mlp=self._apply_lora_to_mlp, - apply_lora_to_output=self._apply_lora_to_output, - base_missing=base_missing, - base_unexpected=base_unexpected, - lora_missing=lora_missing, - lora_unexpected=lora_unexpected, + if self._dtype == torch.bfloat16: + model = model.to(torch.bfloat16) + + # LoRA hyper-params needed for merging weights while saving checkpoints + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + + # Note: this needs to be set before wrapping with FSDP + self.adapter_params = get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + model = FSDP( + module=model, + auto_wrap_policy=utils.lora_fsdp_wrap_policy( + modules_to_wrap={modules.TransformerDecoderLayer} + ), + sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, + device_id=self._device, + # this recipe does not currently support mixed precision training + mixed_precision=None, + # Ensure we broadcast params and buffers from rank 0 + sync_module_states=True, + # Initialize empty modules on all non-zero ranks + param_init_fn=( + lambda module: module.to_empty( + device=torch.device("cuda"), recurse=False + ) + if not self._is_rank_zero + else None + ), ) + # Ensure no params and buffers are on meta device - training.validate_no_params_on_meta_device(model) + utils.validate_no_params_on_meta_device(model) - if self._is_rank_zero: - log.info( - f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} ) - memory_stats = training.get_memory_stats(device=self._device) - training.log_memory_stats(memory_stats) + if self._is_rank_zero: + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) # synchronize before 
training begins torch.distributed.barrier() @@ -506,14 +456,15 @@ def _setup_optimizer( ) -> Optimizer: optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) if opt_state_dict: - training.load_from_full_optimizer_state_dict( - optimizer, - opt_state_dict, - self._device, + # Note: technically we should check _contains_fsdp for + # just the state dict of the adapter cfg, but should be equivalent + opt_state_dict = FSDP.optim_state_dict_to_load( + self._model, optimizer, opt_state_dict ) + optimizer.load_state_dict(opt_state_dict) if self._is_rank_zero: - log.info("Optimizer is initialized.") + log.info("Optimizer and loss are initialized.") return optimizer def _setup_lr_scheduler( @@ -543,7 +494,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. """ - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -565,14 +516,12 @@ def _setup_data( batch_size=batch_size, sampler=sampler, collate_fn=partial( - padded_collate_sft, + utils.padded_collate, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else partial( - padded_collate_packed, - ), + else None, ) if self._is_rank_zero: @@ -590,7 +539,6 @@ def save_checkpoint( - Merged weights with key MODEL_KEY - Adapter weights with key ADAPTER_KEY - Relevant recipe state if training is not complete - - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights Checkpointer will save the merged weights, adapter weights and recipe state in different checkpoint files. To correctly resume from training, the adapter weights @@ -602,20 +550,17 @@ def save_checkpoint( intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( + with FSDP.state_dict_type( self._model, - self._is_rank_zero, - device=self._device, - ) - - if intermediate_checkpoint: - opt_state_dict = training.get_full_optimizer_state_dict( - self._optimizer, - self._is_rank_zero, - device=self._device, - ) - else: - opt_state_dict = None + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state_dict = self._model.state_dict() + if intermediate_checkpoint: + opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) + else: + opt_state_dict = None # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file @@ -627,7 +572,7 @@ def save_checkpoint( adapter_state_dict = { k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) } - checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) # merge the adapter weights and base weights to create the model checkpoint merged_state_dict = get_merged_lora_ckpt( @@ -635,18 +580,18 @@ def save_checkpoint( rank=self._lora_rank, alpha=self._lora_alpha, ) - checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict}) # if training is in-progress, checkpoint the optimizer state and recipe 
state # as well. if intermediate_checkpoint: checkpoint_dict.update( { - training.OPT_KEY: opt_state_dict, - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self.epochs_run, - training.TOTAL_EPOCHS_KEY: self.total_epochs, - training.MAX_STEPS_KEY: self.max_steps_per_epoch, + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) @@ -660,12 +605,12 @@ def save_checkpoint( ), "peft_type": "LORA", } - checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config}) + checkpoint_dict.update({utils.ADAPTER_CONFIG: adapter_config}) + self._checkpointer.save_checkpoint( checkpoint_dict, epoch=epoch, intermediate_checkpoint=intermediate_checkpoint, - adapter_only=self._save_adapter_weights_only, ) def train(self) -> None: @@ -673,9 +618,9 @@ def train(self) -> None: The core training loop. """ # clean up before training begins - training.cleanup_before_training() + utils.cleanup_before_training() - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() @@ -702,15 +647,6 @@ def train(self) -> None: ): break - # Start tracking CUDA memory for active steps for just the first epoch - if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and idx == self.profiler_wait_steps + self.profiler_warmup_steps - ): - torch.cuda.memory._record_memory_history() - # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] # Get the attention mask and position ids from the dataset if they @@ -725,21 +661,14 @@ def train(self) -> None: input_pos = ( input_pos.to(self._device) if input_pos is not None else None ) - logits = self._model(tokens, mask=mask, input_pos=input_pos) - - # Shift labels to compute loss - # equivalent to doing labels[..., 1:] and logits[..., :-1, :] - # But this way we dont need to slice the logits. We just add an ignore index to labels. 
- labels = torch.hstack( - (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) - ) - if not isinstance(logits, list): - labels = labels.reshape(-1) - logits = logits.reshape(-1, logits.size(-1)) + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) # Compute loss loss = self._loss_fn(logits, labels) - # free logits otherwise it peaks backward memory del logits @@ -760,7 +689,7 @@ def train(self) -> None: self.log_loss(loss_to_log) pbar.update(1) pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -775,9 +704,7 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if self._log_peak_memory_stats: - log_dict.update( - training.get_memory_stats(device=self._device) - ) + log_dict.update(utils.get_memory_stats(device=self._device)) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -788,18 +715,6 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() - # Stop tracking CUDA memory now that active steps are complete - if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and idx - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - ): - torch.cuda.memory._record_memory_history(enabled=None) - # Step profiler # Note that this is called within gradient accumulation block, hence # will include multiple forward / backward passes if gradient accumulation > 1 @@ -818,8 +733,7 @@ def cleanup(self) -> None: def log_loss(self, loss): pass - -def prepare_voir(recipe:LoRAFinetuneRecipeDistributed): +def prepare_voir(recipe): from benchmate.observer import BenchObserver from benchmate.monitor import bench_monitor @@ -853,15 +767,12 @@ def recipe_main(cfg: DictConfig) -> None: - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ - if not training.is_distributed(): + if not utils.is_distributed(): raise RuntimeError( "Distributed finetune recipe should be run via a distributed launcher." "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) - if cfg.get("fsdp_cpu_offload", False): - # Utilize all available CPU cores for intra-op parallelism. This provides ~2x - # speed up when benchmarking fused AdamW on CPU - training.set_torch_num_threads() + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") config.log_config(recipe_name="LoRAFinetuneRecipeDistributed", cfg=cfg) diff --git a/benchmarks/llm/recipes/lora_finetune_single_device.py b/benchmarks/llm/recipes/lora_finetune_single_device.py index f08793f52..cf5256ead 100755 --- a/benchmarks/llm/recipes/lora_finetune_single_device.py +++ b/benchmarks/llm/recipes/lora_finetune_single_device.py @@ -6,6 +6,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import os import sys import time @@ -14,25 +15,22 @@ from warnings import warn import torch -import torchtune.modules.common_utils as common_utils from omegaconf import DictConfig, ListConfig from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_packed, padded_collate_sft +from torchtune import config, modules, utils from torchtune.datasets import ConcatDataset -from torchtune.modules.peft import ( +from torchtune.modules.peft.peft_utils import ( get_adapter_params, get_lora_module_names, get_merged_lora_ckpt, - load_dora_magnitudes, set_trainable_params, validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.utils import DummyProfiler, PROFILER_KEY from tqdm import tqdm @@ -90,10 +88,6 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): - Logging. Terminal, Disk, WandB and TensorBoard are all supported. - - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, - ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set - ``clip_grad_norm='inf'``. - For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config has example commands for how to kick-off training. @@ -110,7 +104,7 @@ def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) # Reduced precision logic - self._dtype = training.get_dtype(cfg.dtype, device=self._device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) # fp16 precision is explicitly disabled as it is not supported in this # recipe (for example, no gradient scaling). if self._dtype == torch.float16: @@ -131,15 +125,14 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = training.set_seed(seed=cfg.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 + self._resume_from_checkpoint = cfg.resume_from_checkpoint - self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps - self._clip_grad_norm = cfg.get("clip_grad_norm", None) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -154,7 +147,7 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: checkpoint_dict = self._checkpointer.load_checkpoint() if self._resume_from_checkpoint: - if training.ADAPTER_KEY not in checkpoint_dict: + if utils.ADAPTER_KEY not in checkpoint_dict: raise ValueError( "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." ) @@ -168,28 +161,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. 
""" try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + self.epochs_run = ckpt_dict.get(utils.EPOCHS_KEY, 0) # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[training.SEED_KEY]: + if self.seed != ckpt_dict.get(utils.SEED_KEY, self.seed): warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict.get(utils.SEED_KEY, 0)}" ) ) - self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + self.seed = ckpt_dict.get(utils.SEED_KEY, self.seed) + if self.max_steps_per_epoch != ckpt_dict.get(utils.MAX_STEPS_KEY, self.max_steps_per_epoch): warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict.get(utils.MAX_STEPS_KEY, self.max_steps_per_epoch) # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict.get(utils.TOTAL_EPOCHS_KEY, self.total_epochs): warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -213,21 +206,16 @@ def setup(self, cfg: DictConfig) -> None: # log config with parameter override self._metric_logger.log_config(cfg) - self._compile = cfg.compile + self._model_compile = cfg.compile checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - # hack to toggle to the low cpu ram version of the reparametrize_as_dtype - # hook based on the config. 
- common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False) - - # set up model self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, compile_model=cfg.compile, - base_model_state_dict=checkpoint_dict[training.MODEL_KEY], + base_model_state_dict=checkpoint_dict[utils.MODEL_KEY], lora_weights_state_dict=( - checkpoint_dict[training.ADAPTER_KEY] + checkpoint_dict[utils.ADAPTER_KEY] if self._resume_from_checkpoint else None ), @@ -239,21 +227,11 @@ def setup(self, cfg: DictConfig) -> None: self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, opt_state_dict=( - checkpoint_dict[training.OPT_KEY] - if self._resume_from_checkpoint - else None + checkpoint_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None ), ) - # initialize loss self._loss_fn = config.instantiate(cfg.loss) - if self._compile: - self._loss_fn = training.compile_loss(self._loss_fn) - - if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": - # set num_output_chunks for model - self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - log.info("Loss is initialized.") # Dataloader depends on the tokenizer and loss_fn and should be @@ -293,20 +271,14 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) - # Used to ignore labels for loss computation - self.ignore_labels_cache = torch.full( - (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device - ) - def _setup_profiler( - self, cfg_profiler: Optional[DictConfig] = None + self, cfg_profiler: DictConfig ) -> Union[torch.profiler.profile, DummyProfiler]: """ Parses the `profiler` section of top-level `cfg` and sets up profiler Args: - cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to - `recipe.main`). Default None. 
+ cfg_profiler: DictConfig - `profiler` section of the top-level `cfg` (the main config passed to `recipe.main`) Returns: profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods @@ -346,23 +318,17 @@ def _setup_profiler( # Check that component is included and set correctly if cfg_profiler.get("_component_", None) is None: - cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + cfg_profiler["_component_"] = "torchtune.utils.setup_torch_profiler" else: assert ( cfg_profiler.get("_component_") - == "torchtune.training.setup_torch_profiler" - ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + == "torchtune.utils.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.utils.setup_torch_profiler`" profiler, profiler_cfg = config.instantiate(cfg_profiler) log.info(f" Profiler config after instantiation: {profiler_cfg}") - self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) - if profiler_cfg["enabled"]: - self.profiler_wait_steps = profiler_cfg["wait_steps"] - self.profiler_warmup_steps = profiler_cfg["warmup_steps"] - self.profiler_active_steps = profiler_cfg["active_steps"] - return profiler def _setup_model( @@ -373,7 +339,7 @@ def _setup_model( base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, ) -> nn.Module: - with training.set_default_dtype(self._dtype), self._device: + with utils.set_default_dtype(self._dtype), self._device: model = config.instantiate(cfg_model) self._lora_rank = cfg_model.lora_rank @@ -382,30 +348,23 @@ def _setup_model( self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) self.adapter_params = get_adapter_params(model) - self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()]) set_trainable_params(model, self.adapter_params) - if compile_model: - training.compile_model(model) - if enable_activation_checkpointing: - training.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} ) base_missing, base_unexpected = model.load_state_dict( base_model_state_dict, strict=False ) - # This is for any adapters that need to be initialized after base weights - # have been loaded (e.g. DoRA). - if self._is_dora: - load_dora_magnitudes(model) if lora_weights_state_dict: lora_missing, lora_unexpected = model.load_state_dict( lora_weights_state_dict, strict=False ) else: lora_missing, lora_unexpected = None, None + validate_missing_and_unexpected_for_lora( lora_attn_modules=self._lora_attn_modules, apply_lora_to_mlp=self._apply_lora_to_mlp, @@ -417,16 +376,21 @@ def _setup_model( ) # Validate model adapter params were loaded in with the expected dtype # TODO (rohan-varma): Further validation to ensure the appropriate base params + # are NF4 vs bf16 based on the quantization config. - training.validate_expected_param_dtype( + utils.validate_expected_param_dtype( self.adapter_params.items(), dtype=self._dtype ) log.info(f"Model is initialized with precision {self._dtype}.") - + # Compile model, if enabled. 
+ if compile_model: + log.info("Compiling model with torch.compile...") + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + model.compile(backend=backend) if self._device.type == "cuda": - memory_stats = training.get_memory_stats(device=self._device) - training.log_memory_stats(memory_stats) + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) return model def _setup_optimizer( @@ -468,13 +432,13 @@ def _setup_data( """ if isinstance(cfg_dataset, ListConfig): datasets = [ - config.instantiate(single_cfg_dataset, self._tokenizer) + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) for single_cfg_dataset in cfg_dataset ] ds = ConcatDataset(datasets=datasets) packed = False else: - ds = config.instantiate(cfg_dataset, self._tokenizer) + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( @@ -489,14 +453,12 @@ def _setup_data( sampler=sampler, batch_size=batch_size, collate_fn=partial( - padded_collate_sft, + utils.padded_collate, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else partial( - padded_collate_packed, - ), + else None, ) log.info("Dataset and Sampler are initialized.") @@ -510,45 +472,41 @@ def save_checkpoint(self, epoch: int) -> None: - Merged weights with key MODEL_KEY - Adapter weights with key ADAPTER_KEY - Relevant recipe state if training is not complete - - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights - To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + Checkpointer will save the merged weights, adapter weights and recipe state in + different checkpoint files. To correctly resume from training, the adapter weights + and recipe state must be provided along with the base model weights. 
""" ckpt_dict = {} - - intermediate_checkpoint = epoch + 1 < self.total_epochs # if training is in-progress, checkpoint the optimizer state as well - if intermediate_checkpoint: + if epoch + 1 < self.total_epochs: ckpt_dict.update( { - training.OPT_KEY: self._optimizer.state_dict(), - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self.epochs_run, - training.TOTAL_EPOCHS_KEY: self.total_epochs, - training.MAX_STEPS_KEY: self.max_steps_per_epoch, + utils.OPT_KEY: self._optimizer.state_dict(), + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) # Move to CPU to avoid a copy on GPU state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} - # Construct the adapter weights - # Do this using the state_dict to avoid running upcast and H2D in state_dict post hook twice - # Must be before get_merged_lora_ckpt because get_merged_lora_ckpt will remove lora keys - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in state_dict.items() if adapter_key_filter(k) - } - # Construct the full state dict with LoRA weights merged into base LLM weights merged_state_dict = get_merged_lora_ckpt( state_dict, rank=self._lora_rank, alpha=self._lora_alpha, ) - ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) + ckpt_dict.update({utils.MODEL_KEY: merged_state_dict}) - ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + # Construct the adapter weights + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) + } + ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) adapter_config = { "r": self._lora_rank, "lora_alpha": self._lora_alpha, @@ -559,51 +517,19 @@ def save_checkpoint(self, epoch: int) -> None: ), "peft_type": "LORA", } - ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config}) - + ckpt_dict.update({utils.ADAPTER_CONFIG: adapter_config}) self._checkpointer.save_checkpoint( ckpt_dict, epoch=epoch, - intermediate_checkpoint=intermediate_checkpoint, - adapter_only=self._save_adapter_weights_only, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), ) - def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - # Both are shape [b, s] - tokens, labels = batch["tokens"], batch["labels"] - - # Get the attention mask and position ids from the dataset if they - # exist. Currently, only sample packing in PackedDataset returns these - mask = batch.get("mask", None) # shape [b, s, s] - input_pos = batch.get("input_pos", None) # shape [b, s] - - # run model - logits = self._model(tokens, mask=mask, input_pos=input_pos) - - # Shift labels to compute loss - # equivalent to doing labels[..., 1:] and logits[..., :-1, :] - # But this way we dont need to slice the logits. We just add an ignore index to labels. - labels = torch.hstack( - (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) - ) - if not isinstance(logits, list): - labels = labels.reshape(-1) - logits = logits.reshape(-1, logits.size(-1)) - - # Compute loss - loss = self._loss_fn(logits, labels) - - # free logits otherwise it peaks backward memory - del logits - - return loss - def train(self) -> None: """ The core training loop. """ - if self._compile: + if self._model_compile: log.info( "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." 
) @@ -629,29 +555,33 @@ def train(self) -> None: ): break - # Start tracking CUDA memory for active steps for just the first epoch - if ( - curr_epoch == 0 - and self.profiler_profile_memory - and idx == self.profiler_wait_steps + self.profiler_warmup_steps - ): - torch.cuda.memory._record_memory_history() - - batch = {k: v.to(self._device) for k, v in batch.items()} - num_tokens += batch["tokens"].numel() + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) - loss = self._loss_step(batch) - loss = loss / self._gradient_accumulation_steps + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) / self._gradient_accumulation_steps running_loss += loss loss.backward() # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: - if self._clip_grad_norm is not None: - grad_norm = torch.nn.utils.clip_grad_norm_( - self._model.parameters(), - max_norm=float(self._clip_grad_norm), - ) self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self._lr_scheduler.step() @@ -662,7 +592,7 @@ def train(self) -> None: self.log_loss(loss_to_log) pbar.update(1) pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -678,10 +608,8 @@ def train(self) -> None: and self._log_peak_memory_stats ): log_dict.update( - training.get_memory_stats(device=self._device) + utils.get_memory_stats(device=self._device) ) - if self._clip_grad_norm is not None: - log_dict.update({"grad_norm": grad_norm}) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -692,31 +620,13 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() - # Stop tracking CUDA memory now that active steps are complete - if ( - curr_epoch == 0 - and self.profiler_profile_memory - and idx - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - ): - torch.cuda.memory._record_memory_history(enabled=None) - # Step the profiler # Note we are stepping each batch, which might not include optimizer step in the trace # if the schedule cycle doesn't align with gradient accumulation. 
prof.step() self.epochs_run += 1 - start_save_checkpoint = time.perf_counter() - log.info("Starting checkpoint save...") self.save_checkpoint(epoch=curr_epoch) - log.info( - "Checkpoint saved in {:.2f} seconds.".format( - time.perf_counter() - start_save_checkpoint - ) - ) def cleanup(self) -> None: self._metric_logger.close() @@ -725,7 +635,8 @@ def log_loss(self, loss): pass -def prepare_voir(recipe:LoRAFinetuneRecipeSingleDevice): + +def prepare_voir(recipe): from benchmate.observer import BenchObserver from benchmate.monitor import bench_monitor diff --git a/benchmarks/llm/recipes/ppo_full_finetune_single_device.py b/benchmarks/llm/recipes/ppo_full_finetune_single_device.py index bdd63e8cd..8ee77c06a 100644 --- a/benchmarks/llm/recipes/ppo_full_finetune_single_device.py +++ b/benchmarks/llm/recipes/ppo_full_finetune_single_device.py @@ -17,8 +17,7 @@ from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils -from torchtune.data import padded_collate +from torchtune import config, modules, utils from torchtune.datasets import ConcatDataset from torchtune.modules import rlhf from torchtune.modules.rlhf import PPOStats, Trajectory @@ -107,7 +106,7 @@ class PPOFullFinetuneRecipeSingleDevice(FTRecipeInterface): def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) - self._dtype = training.get_dtype(cfg.dtype, device=self._device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor # enabled necessary features such as gradient scaling. @@ -123,7 +122,7 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = training.set_seed(seed=cfg.seed) + self.seed = utils.set_seed(seed=cfg.seed) # manually setting up a generator for the recipe self._rng = torch.Generator(self._device).manual_seed(self.seed) self._total_steps = 0 @@ -178,15 +177,15 @@ def setup(self, cfg: DictConfig) -> None: self._value_model, self._reward_model, self._ref_policy_model, - ) = self._setup_models( + ) = self._setup_model( cfg_model=cfg.policy_model, cfg_reward_value_model=cfg.reward_and_value_model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, compile_model=self._model_compile, - policy_state_dict=policy_model_checkpoint_dict[training.MODEL_KEY], - ref_policy_state_dict=ref_policy_state_dict[training.MODEL_KEY], - value_model_state_dict=value_model_checkpoint_dict[training.MODEL_KEY], - reward_model_state_dict=reward_model_state_dict[training.MODEL_KEY], + policy_state_dict=policy_model_checkpoint_dict[utils.MODEL_KEY], + ref_policy_state_dict=ref_policy_state_dict[utils.MODEL_KEY], + value_model_state_dict=value_model_checkpoint_dict[utils.MODEL_KEY], + reward_model_state_dict=reward_model_state_dict[utils.MODEL_KEY], ) # setup tokenizer @@ -199,7 +198,7 @@ def setup(self, cfg: DictConfig) -> None: cfg_optimizer=cfg.optimizer, optimizer_in_bwd=cfg.optimizer_in_bwd, opt_state_dict=( - policy_model_checkpoint_dict[training.OPT_KEY] + policy_model_checkpoint_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None ), @@ -349,10 +348,7 @@ def _setup_checkpointers( value_cfg: DictConfig, reward_cfg: DictConfig, ) -> Tuple[ - training.Checkpointer, - training.Checkpointer, - training.Checkpointer, - 
training.Checkpointer, + utils.Checkpointer, utils.Checkpointer, utils.Checkpointer, utils.Checkpointer ]: """ Sets up checkpointers for policy, reference policy, value, and reward models. @@ -398,7 +394,7 @@ def _setup_checkpointers( reward_checkpointer, ) - def _setup_models( + def _setup_model( self, cfg_model: DictConfig, cfg_reward_value_model: DictConfig, @@ -413,49 +409,53 @@ def _setup_models( Sets up the policy model, reference policy model, reward model, and value model. """ - with training.set_default_dtype(self._dtype), self._device: + with utils.set_default_dtype(self._dtype), self._device: policy_model = config.instantiate(cfg_model) ref_policy_model = config.instantiate(cfg_model) reward_model = config.instantiate(cfg_reward_value_model) value_model = config.instantiate(cfg_reward_value_model) if enable_activation_checkpointing: - training.set_activation_checkpointing( - policy_model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + utils.set_activation_checkpointing( + policy_model, auto_wrap_policy={modules.TransformerDecoderLayer} ) - training.set_activation_checkpointing( - value_model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + utils.set_activation_checkpointing( + value_model, auto_wrap_policy={modules.TransformerDecoderLayer} ) policy_model.load_state_dict(policy_state_dict) ref_policy_model.load_state_dict(ref_policy_state_dict) - # since we should be loading a classifier checkpoint into - # a classifier model, this function should just ensure - # output.weight appears in the state_dict and the model's parameters, - # and removes output.bias from the state dict if found - training.update_state_dict_for_classifier( - reward_model_state_dict, reward_model.named_parameters() + reward_missing, reward_unexpected = reward_model.load_state_dict( + reward_model_state_dict, strict=False ) - reward_model.load_state_dict(reward_model_state_dict) - - # same as above - training.update_state_dict_for_classifier( - value_model_state_dict, value_model.named_parameters() + value_missing, value_unexpected = value_model.load_state_dict( + value_model_state_dict, strict=False ) - value_model.load_state_dict(value_model_state_dict) + + # some extra validation for HF classifier checkpoints with a `score.bias` present + assert ( + reward_missing == value_missing == [] + ), f"Missing keys in reward ({reward_missing}) and value model ({value_missing}) state dicts." + + if reward_unexpected or value_unexpected: + # the only unexpected keys should be when pre-trained HF models were saved with + # bias=True in final classification layers. This happens when training a reward model with TRL. + assert ( + reward_unexpected == value_unexpected == ["output.bias"] + ), f"Unexpected keys in reward ({reward_unexpected}) and value model ({value_unexpected}) state dicts." # Validate models were loaded in with the expected dtype. 
- training.validate_expected_param_dtype( + utils.validate_expected_param_dtype( value_model.named_parameters(), dtype=self._dtype ) - training.validate_expected_param_dtype( + utils.validate_expected_param_dtype( reward_model.named_parameters(), dtype=self._dtype ) - training.validate_expected_param_dtype( + utils.validate_expected_param_dtype( value_model.named_parameters(), dtype=self._dtype ) - training.validate_expected_param_dtype( + utils.validate_expected_param_dtype( ref_policy_model.named_parameters(), dtype=self._dtype ) @@ -497,8 +497,8 @@ def _setup_models( value_model.compile(backend=backend) if self._device.type == "cuda": - memory_stats = training.get_memory_stats(device=self._device) - training.log_memory_stats(memory_stats) + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) return policy_model, value_model, reward_model, ref_policy_model @@ -518,17 +518,17 @@ def _setup_optimizer( ) } # Register optimizer step hooks on the models to run optimizer in backward. - training.register_optim_in_bwd_hooks( + utils.register_optim_in_bwd_hooks( model=self._policy_model, optim_dict=optim_dict ) - training.register_optim_in_bwd_hooks( + utils.register_optim_in_bwd_hooks( model=self._value_model, optim_dict=optim_dict ) # Create a wrapper for checkpoint save/load of optimizer states when running in backward. - self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( model=self._policy_model, optim_dict=optim_dict ) - self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper( + self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper( model=self._value_model, optim_dict=optim_dict ) # Load optimizer states. If optimizer states are being restored in an optimizer in backward @@ -582,9 +582,7 @@ def _setup_data( sampler=sampler, batch_size=batch_size, collate_fn=partial( - padded_collate, - pad_direction="left", - keys_to_pad=["tokens", "labels"], + rlhf.left_padded_collate, padding_idx=self._tokenizer.pad_id, ), drop_last=True, @@ -599,27 +597,25 @@ def save_checkpoint( Save state dict to file. The recipe save_checkpoint method is responsible for correctly creating the checkpoint dict and passing to the checkpointer. 
""" - policy_ckpt_dict = {training.MODEL_KEY: self._policy_model.state_dict()} - value_ckpt_dict = {training.MODEL_KEY: self._value_model.state_dict()} + policy_ckpt_dict = {utils.MODEL_KEY: self._policy_model.state_dict()} + value_ckpt_dict = {utils.MODEL_KEY: self._value_model.state_dict()} # if training is in-progress, checkpoint the optimizer state and rng state as well if is_intermediate_checkpoint: policy_ckpt_dict.update( { - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self._epochs_run, - training.TOTAL_EPOCHS_KEY: self._total_epochs, - training.MAX_STEPS_KEY: self._total_steps, - training.STEPS_KEY: self._steps_run, - training.RNG_KEY: self._rng.get_state(), + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self._epochs_run, + utils.TOTAL_EPOCHS_KEY: self._total_epochs, + utils.MAX_STEPS_KEY: self._total_steps, + utils.STEPS_KEY: self._steps_run, + utils.RNG_KEY: self._rng.get_state(), } ) if not self._optimizer_in_bwd: - policy_ckpt_dict[training.OPT_KEY] = self._optimizer.state_dict() + policy_ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict() else: - policy_ckpt_dict[ - training.OPT_KEY - ] = self._optim_ckpt_wrapper.state_dict() + policy_ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict() self._policy_checkpointer.save_checkpoint( policy_ckpt_dict, @@ -641,20 +637,20 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: # warn the user and overwrite. try: if ( - self.seed != ckpt_dict[training.SEED_KEY] - or self._total_steps != ckpt_dict[training.MAX_STEPS_KEY] - or self._total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY] + self.seed != ckpt_dict[utils.SEED_KEY] + or self._total_steps != ckpt_dict[utils.MAX_STEPS_KEY] + or self._total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY] ): warn( message="""Configured value for seed, total_steps, or total_epochs does not match the value stored in checkpoint.""" ) - self.seed = training.set_seed(seed=ckpt_dict[training.SEED_KEY]) - self._rng.set_state(ckpt_dict[training.RNG_KEY]) - self._steps_run = ckpt_dict[training.STEPS_KEY] - self._total_steps = ckpt_dict[training.MAX_STEPS_KEY] - self._total_epochs = ckpt_dict[training.TOTAL_EPOCHS_KEY] - self._epochs_run = ckpt_dict[training.EPOCHS_KEY] + self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY]) + self._rng.set_state(ckpt_dict[utils.RNG_KEY]) + self._steps_run = ckpt_dict[utils.STEPS_KEY] + self._total_steps = ckpt_dict[utils.MAX_STEPS_KEY] + self._total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY] + self._epochs_run = ckpt_dict[utils.EPOCHS_KEY] except KeyError as e: raise KeyError from e( @@ -744,7 +740,7 @@ def generate_trajectory(self, input_ids: torch.Tensor) -> Trajectory: # step 5.1 the scores from the reward model are the logits for the last non-padding token in # each (query, truncated-response) pair - seq_lens = training.get_unmasked_sequence_lengths(response_padding_masks) + seq_lens = utils.get_unmasked_sequence_lengths(response_padding_masks) scores = scores[torch.arange(batch_size), seq_lens + context_length].squeeze(-1) # step 5.2 if configured, apply any penalties for sequences without EOS tokens @@ -832,7 +828,7 @@ def train(self) -> None: self._sampler.set_epoch(curr_epoch) for _, batch in enumerate(self._dataloader): - batch = batch["tokens"].to(self._device) + batch = batch.to(self._device) _, context_length = batch.shape # step 1. 
generate the trajectory using: @@ -1036,7 +1032,7 @@ def log_metrics( "response_lengths": trajectory.seq_lens.float().mean(), } if self._device.type == "cuda" and self._log_peak_memory_stats: - log_dict.update(training.get_memory_stats(device=self._device)) + log_dict.update(utils.get_memory_stats(device=self._device)) self._metric_logger.log_dict(log_dict, step=self.global_step) diff --git a/benchmarks/llm/recipes/qat_distributed.py b/benchmarks/llm/recipes/qat_distributed.py index 578669ed8..211433835 100755 --- a/benchmarks/llm/recipes/qat_distributed.py +++ b/benchmarks/llm/recipes/qat_distributed.py @@ -6,31 +6,36 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os import sys import time from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple from warnings import warn import torch from omegaconf import DictConfig, ListConfig from torch import nn -from torch.distributed import destroy_process_group, init_process_group - +from torch.distributed import init_process_group +from torch.distributed.fsdp import ( + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_packed, padded_collate_sft + +from torchtune import config, modules, utils from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY -from torchtune.training.activations import apply_selective_activation_checkpointing +from torchtune.utils.activations import apply_selective_activation_checkpointing from tqdm import tqdm + log = utils.get_logger("DEBUG") @@ -51,11 +56,8 @@ class QATRecipeDistributed(FTRecipeInterface): weight and activation values to stabilize before fake quantizing them, potentially leading to improved quantized accuracy. This can be specified through ``fake_quant_after_n_steps``. - - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states - is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is - done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config - ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). - DDP is currently not supported. Training on CPU is not supported. + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU + is not supported. - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep @@ -103,12 +105,12 @@ class QATRecipeDistributed(FTRecipeInterface): Raises: ValueError: If ``dtype`` is set to fp16. - RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. 
""" def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) - self._dtype = training.get_dtype(cfg.dtype, device=self._device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) if self._dtype == torch.float16: raise ValueError( @@ -117,7 +119,7 @@ def __init__(self, cfg: DictConfig) -> None: if ( cfg.get("fsdp_cpu_offload", False) - and cfg.optimizer.get("fused", False) + and cfg.get("fused", False) and not utils.torch_version_ge("2.4.0") ): raise RuntimeError( @@ -131,21 +133,18 @@ def __init__(self, cfg: DictConfig) -> None: # _is_rank_zero is used primarily for logging. In the future, the logger # should directly take care of this - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() self._is_rank_zero = rank == 0 # Training cfg self._resume_from_checkpoint = cfg.resume_from_checkpoint self._gradient_accumulation_steps = cfg.gradient_accumulation_steps - self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[ - cfg.get("fsdp_sharding_strategy", "FULL_SHARD") - ] self._fake_quant_after_n_steps = cfg.get("fake_quant_after_n_steps", None) self._quantizer_mode = None # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = training.set_seed(seed=cfg.seed) + self.seed = utils.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch @@ -171,28 +170,28 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: Updates the recipe state from checkpoint. """ try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[training.SEED_KEY]: + if self.seed != ckpt_dict[utils.SEED_KEY]: warn( message=( "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" ) ) - self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: warn( message=( "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" ) ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: warn( message=( "Config value for total_epochs does not match the checkpoint value, " @@ -208,8 +207,8 @@ def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: def setup(self, cfg: DictConfig) -> None: """ - Setup the recipe. This includes training state (if resume_from_checkpoint is True), - model, tokenizer, loss, optimizer, sampler, and dataloader. + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. 
""" if self._is_rank_zero: self._metric_logger = config.instantiate(cfg.metric_logger) @@ -217,48 +216,34 @@ def setup(self, cfg: DictConfig) -> None: # log config with parameter override self._metric_logger.log_config(cfg) - checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + ckpt_dict = self.load_checkpoint(cfg.checkpointer) - self._model_compile = cfg.get("compile", False) + # ``_setup_model`` handles initialization and loading the state dict. This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, - custom_sharded_layers=cfg.get("custom_sharded_layers", None), + memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), - reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), - model_state_dict=checkpoint_dict[training.MODEL_KEY], + model_state_dict=ckpt_dict[utils.MODEL_KEY], ac_mode=cfg.get("ac_mode", None), ac_option=cfg.get("ac_option", None), quantizer_cfg=cfg.get("quantizer", None), ) + self._tokenizer = config.instantiate(cfg.tokenizer) + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. Transforming the opt state dict is handled by this method self._optimizer = self._setup_optimizer( cfg_optimizer=cfg.optimizer, - opt_state_dict=checkpoint_dict[training.OPT_KEY] + opt_state_dict=ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None, ) - # initialize loss self._loss_fn = config.instantiate(cfg.loss) - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": - # set num_output_chunks for model - self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - if self._model_compile: - log.info("Compiling loss with torch.compile...") - # For CEWithChunkedOutputLoss, if we compile the entire class - # we lose the benefits from the chunked loss. - # Therefore, we only compile the cross entropy function + upcasting - self._loss_fn.compute_cross_entropy = torch.compile( - self._loss_fn.compute_cross_entropy, backend=backend - ) - else: - if self._model_compile: - log.info("Compiling loss with torch.compile...") - self._loss_fn = torch.compile(self._loss_fn, backend=backend) - log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized @@ -285,89 +270,12 @@ def setup(self, cfg: DictConfig) -> None: self._steps_per_epoch = self.max_steps_per_epoch self.global_step = self.epochs_run * self._steps_per_epoch - # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) - # if cfg is missing profiler key or if `cfg.profiler.enabled = False` - self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) - - # Used to ignore labels for loss computation - self.ignore_labels_cache = torch.full( - (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device - ) - - def _setup_profiler( - self, cfg_profiler: Optional[DictConfig] = None - ) -> Union[torch.profiler.profile, DummyProfiler]: - """ - Parses the `profiler` section of top-level `cfg` and sets up profiler - - Args: - cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to - `recipe.main`). Default None. 
- - Returns: - profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods - for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such - that the instrumented training loop does not need to be changed profiling is disabled. - - The profiler config can be provided in configs under the `profiler` key with the following layout: - - .. code-block:: yaml - profiler: - enabled: bool - - #Output directory of trace artifacts - output_dir: str - - #`torch.profiler.ProfilerActivity` types to trace - cpu: bool - cuda: bool - - #Trace options - profile_memory: bool - with_stack: bool - record_shapes: bool - with_flops: bool - - # `torch.profiler.schedule` options: - # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: int - warmup_steps: int - active_steps: int - num_cycles: int - """ - # Missing profiler section in config, assume disabled - if cfg_profiler is None: - cfg_profiler = DictConfig({"enabled": False}) - - # Check that component is included and set correctly - if cfg_profiler.get("_component_", None) is None: - cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" - else: - assert ( - cfg_profiler.get("_component_") - == "torchtune.training.setup_torch_profiler" - ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" - - profiler, profiler_cfg = config.instantiate(cfg_profiler) - - if self._is_rank_zero: - log.info(f" Profiler config after instantiation: {profiler_cfg}") - - self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) - if profiler_cfg["enabled"]: - self.profiler_wait_steps = profiler_cfg["wait_steps"] - self.profiler_warmup_steps = profiler_cfg["warmup_steps"] - self.profiler_active_steps = profiler_cfg["active_steps"] - - return profiler - def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, - custom_sharded_layers: Optional[List[str]], + memory_efficient_fsdp_wrap: bool, fsdp_cpu_offload: bool, - reshard_after_forward: bool, model_state_dict: Dict[str, Any], ac_mode: Optional[str] = None, ac_option: Optional[int] = None, @@ -375,20 +283,37 @@ def _setup_model( ) -> nn.Module: """ Model initialization has some important considerations: - a. To minimize GPU peak memory, we initialize the model on meta device with - the right dtype - b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since - full state dicts are loaded with ``torch.load(mmap=True)`` + a. To minimize GPU peak memory, we load the model on CPU with the right + dtype. To ensure that we don't instantiate ``world_size`` number of models, + we initialize on meta_device for all ranks other than rank 0. + b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the + model weights from checkpoint. + c. While wrapping the model with FSDP, we set ``sync_module_states`` + to TRUE and broadcast module params and buffers from rank 0. + d. The ``device_id`` param ensures that the FSDP initialization happens on + the correct device. """ - if self._is_rank_zero: + log.info("FSDP is enabled. Instantiating Model on CPU for Rank 0 ...") + init_start = time.perf_counter() + + with utils.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) + log.info( - "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." 
+ f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" ) - init_start = time.perf_counter() - with training.set_default_dtype(self._dtype), torch.device("meta"): - model = config.instantiate(cfg_model) + # Load both the model weights. This should happen only on Rank 0 + model.load_state_dict(model_state_dict) + + else: + # For non-zero ranks, load the model on meta device + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._dtype == torch.bfloat16: + model = model.to(torch.bfloat16) # We currently have two versions of activation checkpointing in this recipe # for testing and BC purposes. ``enable_activation_checkpointing`` controls @@ -396,6 +321,9 @@ def _setup_model( # ac_mode and ac_option together control selective AC. This is only enabled # when these are set AND ``enable_activation_checkpointing`` is set to False # We'll clean this up as soon as testing of AC is complete + ac_mode = ac_mode + ac_option = ac_option + if (not enable_activation_checkpointing) and (ac_mode is not None): apply_selective_activation_checkpointing( model, @@ -403,18 +331,12 @@ def _setup_model( ac_option, ) - # original activation checkpointing (full) - flip the condition above - if enable_activation_checkpointing and ac_mode is None: - training.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} - ) - # Apply quantization-aware training during finetuning if quantizer_cfg is None: raise ValueError("Quantizer must be specified for QAT recipe.") quantizer = config.instantiate(quantizer_cfg) quantizer.precision = self._dtype - quantizer_mode = training.quantization.get_quantizer_mode(quantizer) + quantizer_mode = utils.quantization.get_quantizer_mode(quantizer) if "qat" not in quantizer_mode: raise ValueError( "Quantizer mode '%s' is not supported for finetuning" % quantizer_mode @@ -422,57 +344,43 @@ def _setup_model( self._quantizer_mode = quantizer_mode model = quantizer.prepare(model) - # For FSDP sharding, we can condition on either the module or its name - # Shard conditions should be callables taking name (relative to model root) - # and the module itself and returning a bool on whether to shard the given module - fsdp_shard_conditions = [] - - # Shard transformer decoder layers (or AC-wrapped versions) - # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) - # But directly using the name is more concise - def _is_layer_fqn(s: str) -> bool: - """ - Return True for layers.i and False for all other module names - Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot - """ - s_list = s.split(".") - return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1]) - - fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)] - - # If wrapping any layers separately, we can add another shard condition - # A layer will be sharded if any of the fsdp_shard_conditions are met - if custom_sharded_layers: - fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers] - - training.shard_model( - model=model, - shard_conditions=fsdp_shard_conditions, - cpu_offload=fsdp_cpu_offload, - reshard_after_forward=reshard_after_forward, - ) - - with training.set_default_dtype(self._dtype), self._device: - for m in model.modules(): - # RoPE is not covered in state dict - if hasattr(m, "rope_init"): - m.rope_init() - - # This method will convert the full model state dict into a sharded state - # dict and load into the model 
- training.load_from_full_model_state_dict( - model, model_state_dict, self._device, self._is_rank_zero, strict=True + # Wrap the model with FSDP. This will ensure that the model is sharded + # across all available GPUs. + model = FSDP( + module=model, + auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy( + memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap, + modules_to_wrap={modules.TransformerDecoderLayer}, + ), + cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload), + sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, + device_id=self._device, + # this recipe does not currently support mixed precision training + mixed_precision=None, + # Ensure we broadcast params and buffers from rank 0 + sync_module_states=True, + # Initialize empty modules on all non-zero ranks + param_init_fn=( + lambda module: module.to_empty( + device=torch.device("cuda"), recurse=False + ) + if not self._is_rank_zero + else None + ), ) # Ensure no params and buffers are on meta device - training.validate_no_params_on_meta_device(model) + utils.validate_no_params_on_meta_device(model) - if self._is_rank_zero: - log.info( - f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} ) - memory_stats = training.get_memory_stats(device=self._device) - training.log_memory_stats(memory_stats) + + if self._is_rank_zero: + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) # synchronize before training begins torch.distributed.barrier() @@ -482,13 +390,17 @@ def _is_layer_fqn(s: str) -> bool: def _setup_optimizer( self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None ) -> Optimizer: + """ + Set up the optimizer. This method also handles transforing the state dict + for FSDP. + """ optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: - training.load_from_full_optimizer_state_dict( - optimizer, - opt_state_dict, - self._device, + opt_state_dict = FSDP.optim_state_dict_to_load( + self._model, optimizer, opt_state_dict ) + optimizer.load_state_dict(opt_state_dict) if self._is_rank_zero: log.info("Optimizer is initialized.") @@ -505,7 +417,7 @@ def _setup_data( DistributedSamplers with Map-style Datasets which fit into memory. Other samplers, iterable datasets and streaming datasets are not supported. 
""" - world_size, rank = training.get_world_size_and_rank() + world_size, rank = utils.get_world_size_and_rank() if isinstance(cfg_dataset, ListConfig): datasets = [ @@ -519,21 +431,23 @@ def _setup_data( packed = cfg_dataset.get("packed", False) sampler = DistributedSampler( - ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0 + ds, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=0, ) dataloader = DataLoader( dataset=ds, batch_size=batch_size, sampler=sampler, collate_fn=partial( - padded_collate_sft, + utils.padded_collate, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) if not packed - else partial( - padded_collate_packed, - ), + else None, ) if self._is_rank_zero: @@ -541,72 +455,57 @@ def _setup_data( return sampler, dataloader - def save_checkpoint( - self, - epoch: int, - ) -> None: + def save_checkpoint(self, epoch: int) -> None: """ - Checkpoint the state of the recipe. The constructed checkpoint state dict - contains the following information: - - Model weights with key training.MODEL_KEY - - Relevant recipe state if training is not complete - - Checkpointer will save the model weights and recipe state in - different checkpoint files. To correctly resume training from an intermediate checkpoint, - the model weights and recipe state must be provided. + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. """ - # final dict passed onto the checkpointer checkpoint_dict = {} - intermediate_checkpoint = epoch + 1 < self.total_epochs # To prevent GPU memory from spiking during checkpoint save, # we consolidate the full model and optim state dicts on CPU for rank 0 - cpu_state_dict = training.get_full_model_state_dict( + with FSDP.state_dict_type( self._model, - self._is_rank_zero, - ) - - if intermediate_checkpoint: - opt_state_dict = training.get_full_optimizer_state_dict( - self._optimizer, - self._is_rank_zero, - ) - else: - opt_state_dict = None + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state_dict = self._model.state_dict() + opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) # Now that we have the model and opt state dict, create the actual checkpoint dict # to be sent to the checkpointer and ultimately written to file if self._is_rank_zero: - checkpoint_dict.update({training.MODEL_KEY: cpu_state_dict}) + checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict}) - # if training is in-progress, checkpoint the optimizer state and recipe state - # as well. - if intermediate_checkpoint: + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: checkpoint_dict.update( { - training.OPT_KEY: opt_state_dict, - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self.epochs_run, - training.TOTAL_EPOCHS_KEY: self.total_epochs, - training.MAX_STEPS_KEY: self.max_steps_per_epoch, + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, } ) self._checkpointer.save_checkpoint( checkpoint_dict, epoch=epoch, - intermediate_checkpoint=intermediate_checkpoint, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), ) def train(self) -> None: """ - The core training loop. + The core training loop. 
Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. """ # clean up before training begins - training.cleanup_before_training() + utils.cleanup_before_training() - _, rank = training.get_world_size_and_rank() + _, rank = utils.get_world_size_and_rank() # zero out the gradients before starting training self._optimizer.zero_grad() @@ -616,7 +515,6 @@ def train(self) -> None: running_loss = 0 num_tokens = 0 - self._profiler.start() # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -633,15 +531,6 @@ def train(self) -> None: ): break - # Start tracking CUDA memory for active steps for just the first epoch - if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and idx == self.profiler_wait_steps + self.profiler_warmup_steps - ): - torch.cuda.memory._record_memory_history() - # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] # Get the attention mask and position ids from the dataset if they @@ -656,7 +545,7 @@ def train(self) -> None: "Step 0: Disabling fake quant, will re-enable in step %s" % self._fake_quant_after_n_steps ) - disable_fq = training.quantization._get_disable_fake_quant( + disable_fq = utils.quantization._get_disable_fake_quant( self._quantizer_mode ) self._model.apply(disable_fq) @@ -665,7 +554,7 @@ def train(self) -> None: "Step %s: Enabling fake quant" % self._fake_quant_after_n_steps ) - enable_fq = training.quantization._get_enable_fake_quant( + enable_fq = utils.quantization._get_enable_fake_quant( self._quantizer_mode ) self._model.apply(enable_fq) @@ -679,21 +568,12 @@ def train(self) -> None: ) logits = self._model(tokens, mask=mask, input_pos=input_pos) - - # Shift labels to compute loss - # equivalent to doing labels[..., 1:] and logits[..., :-1, :] - # But this way we dont need to slice the logits. We just add an ignore index to labels. 
- labels = torch.hstack( - (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) - ) - if not isinstance(logits, list): - labels = labels.reshape(-1) - logits = logits.reshape(-1, logits.size(-1)) - + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) # Compute loss loss = self._loss_fn(logits, labels) - # free logits otherwise it peaks backward memory - del logits loss = loss / self._gradient_accumulation_steps running_loss += loss @@ -710,7 +590,7 @@ def train(self) -> None: loss_to_log = running_loss.item() pbar.update(1) pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" ) # Log per-step metrics @@ -725,9 +605,7 @@ def train(self) -> None: "tokens_per_second_per_gpu": num_tokens / time_per_step, } if self._log_peak_memory_stats: - log_dict.update( - training.get_memory_stats(device=self._device) - ) + log_dict.update(utils.get_memory_stats(device=self._device)) self._metric_logger.log_dict( log_dict, step=self.global_step, @@ -738,32 +616,13 @@ def train(self) -> None: num_tokens = 0 t0 = time.perf_counter() - # Stop tracking CUDA memory now that active steps are complete - if ( - self._is_rank_zero - and curr_epoch == 0 - and self.profiler_profile_memory - and idx - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - ): - torch.cuda.memory._record_memory_history(enabled=None) - - # Step profiler - # Note that this is called within gradient accumulation block, hence - # will include multiple forward / backward passes if gradient accumulation > 1 - self._profiler.step() - self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) - self._profiler.stop() - def cleanup(self) -> None: if self._is_rank_zero: self._metric_logger.close() - destroy_process_group() + torch.distributed.destroy_process_group() @config.parse @@ -775,16 +634,17 @@ def recipe_main(cfg: DictConfig) -> None: - Parameters specified in config (see available configs through ``tune ls``) - Overwritten by arguments from the command-line """ - if not training.is_distributed(): + if not utils.is_distributed(): raise RuntimeError( "Distributed QAT recipe should be run via a distributed launcher." "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" ) + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") if cfg.get("fsdp_cpu_offload", False): # Utilize all available CPU cores for intra-op parallelism. This provides ~2x # speed up when benchmarking fused AdamW on CPU - training.set_torch_num_threads() + utils.set_torch_num_threads() config.log_config(recipe_name="QATRecipeDistributed", cfg=cfg) diff --git a/benchmate/benchmate/monitor.py b/benchmate/benchmate/monitor.py index a7f1dd0f3..0ad34a3d3 100644 --- a/benchmate/benchmate/monitor.py +++ b/benchmate/benchmate/monitor.py @@ -126,6 +126,7 @@ def monogpu_monitor(*args, **kwargs): yield log + @contextmanager def bench_monitor(*args, **kwargs): if int(os.getenv("RANK", -1)) == -1: From 6d1e114000cc4200ea307330032234db6696e40d Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Mon, 30 Sep 2024 14:39:43 -0400 Subject: [PATCH 6/7] Revert "Fix rlhf on trl v0.11.0" This reverts commit 6caac29e4c75e9fff050969194672e77db1187bd. 
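A minimal compatibility sketch, not part of this patch: instead of pinning the benchmark to a single trl release, the chat-template constant (renamed between the trl version targeted after this revert and trl v0.11.0, per the reverted commit) could be imported defensively. The alias name CHAT_TEMPLATE below is made up for illustration; only the two trl constant names shown in the diff are assumed to exist.

    # Sketch only: tolerate the constant rename across trl releases.
    try:  # older trl, as targeted after this revert
        from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE as CHAT_TEMPLATE
    except ImportError:  # trl v0.11.0 exposes SIMPLE_CHAT_TEMPLATE instead
        from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE as CHAT_TEMPLATE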
--- benchmarks/rlhf/main.py | 4 ++-- benchmarks/rlhf/prepare.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/benchmarks/rlhf/main.py b/benchmarks/rlhf/main.py index a46b9579b..0be12d282 100755 --- a/benchmarks/rlhf/main.py +++ b/benchmarks/rlhf/main.py @@ -13,7 +13,7 @@ from trl import ModelConfig from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer -from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE +from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE class PPOv2TrainerIntrumented(PPOv2Trainer): @@ -62,7 +62,7 @@ def main(): ) tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE value_model = AutoModelForSequenceClassification.from_pretrained( config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1 ) diff --git a/benchmarks/rlhf/prepare.py b/benchmarks/rlhf/prepare.py index 5e4cb4eba..4c9aa631f 100755 --- a/benchmarks/rlhf/prepare.py +++ b/benchmarks/rlhf/prepare.py @@ -11,7 +11,7 @@ from datasets import load_dataset from trl import ModelConfig from trl.trainer.ppov2_trainer import PPOv2Config -from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE +from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE if __name__ == "__main__": @@ -30,7 +30,7 @@ tokenizer.add_special_tokens({"pad_token": "[PAD]"}) if tokenizer.chat_template is None: - tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE value_model = AutoModelForSequenceClassification.from_pretrained( config.reward_model_path, From 93015e5748bf6ce3749aacec7d53d32db8730b93 Mon Sep 17 00:00:00 2001 From: "pierre.delaunay" Date: Wed, 2 Oct 2024 12:58:16 -0400 Subject: [PATCH 7/7] Add latex output as an option --- .pin/constraints-cuda-torch.txt | 77 ++++++++++--------- benchmarks/brax/requirements.cuda.txt | 16 ++-- benchmarks/diffusion/requirements.cuda.txt | 25 +++--- benchmarks/dinov2/requirements.cuda.txt | 10 +-- benchmarks/flops/requirements.cuda.txt | 10 +-- benchmarks/geo_gnn/requirements-pre.cuda.txt | 7 +- benchmarks/geo_gnn/requirements.cuda.txt | 22 +++--- benchmarks/huggingface/requirements.cuda.txt | 13 ++-- benchmarks/lightning/requirements.cuda.txt | 16 ++-- benchmarks/llama/requirements.cuda.txt | 28 +++---- benchmarks/llava/requirements.cuda.txt | 25 +++--- benchmarks/llm/requirements.cuda.txt | 37 +++++---- benchmarks/llm/requirements.in | 2 +- benchmarks/purejaxrl/requirements.cuda.txt | 32 ++++---- benchmarks/recursiongfn/requirements.cuda.txt | 32 ++++---- benchmarks/rlhf/requirements.cuda.txt | 28 +++---- benchmarks/timm/requirements.cuda.txt | 12 +-- benchmarks/torchatari/requirements.cuda.txt | 16 ++-- benchmarks/torchvision/requirements.cuda.txt | 10 +-- .../torchvision_ddp/requirements.cuda.txt | 10 +-- benchmarks/vjepa/requirements.cuda.txt | 18 ++--- constraints/cuda.txt | 11 +++ milabench/_version.py | 6 +- milabench/report.py | 34 +++++++- scripts/article/run_cuda.sh | 17 +++- 25 files changed, 291 insertions(+), 223 deletions(-) diff --git a/.pin/constraints-cuda-torch.txt b/.pin/constraints-cuda-torch.txt index 96876ac75..1886ffa21 100644 --- a/.pin/constraints-cuda-torch.txt +++ b/.pin/constraints-cuda-torch.txt @@ -32,9 +32,9 @@ accelerate==0.34.2 # -r benchmarks/rlhf/requirements.in # diffusers # trl -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # datasets # fsspec @@ -50,7 +50,7 @@ 
argklass==1.4.4 # -r benchmarks/diffusion/requirements.in # -r benchmarks/llm/requirements.in # -r benchmarks/purejaxrl/requirements.in -astroid==3.2.4 +astroid==3.3.4 # via pylint asttokens==2.4.1 # via giving @@ -58,7 +58,7 @@ async-timeout==4.0.3 # via aiohttp attrs==24.2.0 # via aiohttp -beartype==0.18.5 +beartype==0.19.0 # via -r benchmarks/vjepa/requirements.in black==24.8.0 # via navix @@ -88,7 +88,7 @@ certifi==2024.8.30 # sentry-sdk charset-normalizer==3.3.2 # via requests -chex==0.1.86 +chex==0.1.87 # via # distrax # evosax @@ -117,7 +117,7 @@ cvxopt==1.3.2 # via -r benchmarks/recursiongfn/requirements.in cycler==0.12.1 # via matplotlib -datasets==3.0.0 +datasets==3.0.1 # via # -r benchmarks/diffusion/requirements.in # -r benchmarks/llama/requirements.in @@ -188,7 +188,7 @@ filelock==3.16.1 # torch # transformers # triton -fire==0.6.0 +fire==0.7.0 # via # -r benchmarks/llama/requirements.in # -r benchmarks/llm/requirements.txt @@ -210,7 +210,7 @@ flax==0.9.0 # flashbax # gymnax # navix -fonttools==4.53.1 +fonttools==4.54.1 # via matplotlib frozenlist==1.4.1 # via @@ -245,7 +245,7 @@ gpytorch==1.13 # via # -r benchmarks/recursiongfn/requirements.in # botorch -grpcio==1.66.1 +grpcio==1.66.2 # via # brax # tensorboard @@ -267,7 +267,7 @@ gymnax==0.0.8 # -r benchmarks/purejaxrl/requirements.in hjson==3.1.0 # via argklass -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -r benchmarks/timm/requirements.in # accelerate @@ -420,7 +420,7 @@ mypy-extensions==1.0.0 # via black navix==0.7.0 # via -r benchmarks/purejaxrl/requirements.in -ndindex==1.8 +ndindex==1.9.2 # via blosc2 nest-asyncio==1.6.0 # via orbax-checkpoint @@ -466,7 +466,6 @@ numpy==1.26.4 # navix # numexpr # opencv-python - # opt-einsum # optax # orbax-checkpoint # pandas @@ -501,7 +500,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # via # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via jax-cuda12-plugin nvidia-cuda-nvrtc-cu12==12.1.105 # via torch @@ -534,7 +533,7 @@ nvidia-nccl-cu12==2.20.5 # via # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # jax-cuda12-plugin # nvidia-cusolver-cu12 @@ -549,7 +548,7 @@ omegaconf==2.3.0 # voir opencv-python==4.10.0.84 # via -r benchmarks/vjepa/requirements.in -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # jax # pyro-ppl @@ -584,7 +583,7 @@ packaging==24.1 # tensorboardx # torchmetrics # transformers -pandas==2.2.2 +pandas==2.2.3 # via # -r benchmarks/geo_gnn/requirements.in # -r benchmarks/recursiongfn/requirements.in @@ -637,13 +636,13 @@ pyarrow==17.0.0 # datasets pycodestyle==2.12.1 # via flake8 -pycryptodomex==3.20.0 +pycryptodomex==3.21.0 # via blobfile pyflakes==3.2.0 # via flake8 pygments==2.18.0 # via rich -pylint==3.2.7 +pylint==3.3.1 # via navix pyopengl==3.1.7 # via mujoco @@ -711,7 +710,7 @@ requests==2.32.3 # torch-geometric # transformers # wandb -rich==13.8.1 +rich==13.9.1 # via # flax # tyro @@ -749,7 +748,7 @@ sentencepiece==0.2.0 # via # -r benchmarks/llama/requirements.in # torchtune -sentry-sdk==2.14.0 +sentry-sdk==2.15.0 # via wandb setproctitle==1.3.3 # via wandb @@ -761,7 +760,6 @@ six==1.16.0 # via # asttokens # docker-pycreds - # fire # ml-collections # python-dateutil # tensorboard @@ -778,7 +776,7 @@ tables==3.10.1 # via -r benchmarks/recursiongfn/requirements.in tabulate==0.9.0 # via fvcore -tensorboard==2.17.1 +tensorboard==2.18.0 # via # -r benchmarks/recursiongfn/requirements.in # -r benchmarks/torchatari/requirements.in @@ -788,7 +786,7 @@ tensorboardx==2.6.2.2 
# via brax tensorflow-probability==0.24.0 # via distrax -tensorstore==0.1.65 +tensorstore==0.1.66 # via # flashbax # flax @@ -805,7 +803,7 @@ timm==1.0.9 # via -r benchmarks/vjepa/requirements.in tokenizers==0.19.1 # via transformers -tomli==2.0.1 +tomli==2.0.2 # via # black # pylint @@ -852,7 +850,7 @@ torch-cluster==1.6.3+pt24cu121 # via # -r benchmarks/geo_gnn/requirements.in # -r benchmarks/recursiongfn/requirements.in -torch-geometric==2.6.0 +torch-geometric==2.6.1 # via # -r benchmarks/geo_gnn/requirements.in # -r benchmarks/recursiongfn/requirements.in @@ -864,8 +862,11 @@ torch-sparse==0.6.18+pt24cu121 # via # -r benchmarks/geo_gnn/requirements.in # -r benchmarks/recursiongfn/requirements.in -torchao==0.5.0+cu121 - # via -r benchmarks/llm/requirements.in +torchao==0.3.1+cu121 + # via + # -c .pin/../constraints/cuda.txt + # -r benchmarks/llm/requirements.in + # torchtune torchcompat==1.1.4 # via # -c .pin/../constraints/cuda.txt @@ -879,8 +880,10 @@ torchmetrics==1.4.2 # -r benchmarks/dinov2/requirements.in # lightning # pytorch-lightning -torchtune==0.3.0+cu121 - # via -r benchmarks/llm/requirements.in +torchtune==0.2.1+cu121 + # via + # -c .pin/../constraints/cuda.txt + # -r benchmarks/llm/requirements.in torchvision==0.19.0+cu121 # via # -r benchmarks/diffusion/requirements.in @@ -910,6 +913,7 @@ tqdm==4.66.5 # transformers transformers==4.44.2 # via + # -c .pin/../constraints/cuda.txt # -r benchmarks/diffusion/requirements.in # -r benchmarks/huggingface/requirements.in # -r benchmarks/llama/requirements.in @@ -923,11 +927,13 @@ trimesh==4.4.9 # mujoco-mjx triton==3.0.0 # via torch -trl==0.11.0 - # via -r benchmarks/rlhf/requirements.in +trl==0.10.1 + # via + # -c .pin/../constraints/cuda.txt + # -r benchmarks/rlhf/requirements.in typeguard==4.3.0 # via jaxtyping -types-protobuf==5.27.0.20240907 +types-protobuf==5.28.0.20240924 # via envpool typing-extensions==4.12.2 # via @@ -952,6 +958,7 @@ typing-extensions==4.12.2 # orbax-checkpoint # pytorch-lightning # reactivex + # rich # submitit # tables # torch @@ -962,7 +969,7 @@ tyro==0.8.11 # -r benchmarks/torchatari/requirements.in # navix # trl -tzdata==2024.1 +tzdata==2024.2 # via pandas urllib3==2.2.3 # via @@ -992,7 +999,7 @@ voir==0.2.19 # -r benchmarks/torchvision/requirements.in # -r benchmarks/torchvision_ddp/requirements.in # -r benchmarks/vjepa/requirements.in -wandb==0.18.1 +wandb==0.18.3 # via # -r benchmarks/recursiongfn/requirements.in # navix @@ -1010,7 +1017,7 @@ xxhash==3.5.0 # via datasets yacs==0.1.8 # via fvcore -yarl==1.11.1 +yarl==1.13.1 # via aiohttp zipp==3.20.2 # via diff --git a/benchmarks/brax/requirements.cuda.txt b/benchmarks/brax/requirements.cuda.txt index 89ebe8840..aae485613 100644 --- a/benchmarks/brax/requirements.cuda.txt +++ b/benchmarks/brax/requirements.cuda.txt @@ -37,7 +37,7 @@ brax==0.10.5 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/brax/requirements.in -chex==0.1.86 +chex==0.1.87 # via # -c .pin/../.pin/constraints-cuda-torch.txt # optax @@ -109,7 +109,7 @@ glfw==2.7.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # mujoco -grpcio==1.66.1 +grpcio==1.66.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -234,7 +234,6 @@ numpy==1.26.4 # jaxopt # ml-dtypes # mujoco - # opt-einsum # optax # orbax-checkpoint # scipy @@ -254,7 +253,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # 
jax-cuda12-plugin @@ -301,7 +300,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -315,7 +314,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -377,7 +376,7 @@ reactivex==4.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # flax @@ -403,7 +402,7 @@ tensorboardx==2.6.2.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax -tensorstore==0.1.65 +tensorstore==0.1.66 # via # -c .pin/../.pin/constraints-cuda-torch.txt # flax @@ -435,6 +434,7 @@ typing-extensions==4.12.2 # flax # orbax-checkpoint # reactivex + # rich # torch varname==0.13.3 # via diff --git a/benchmarks/diffusion/requirements.cuda.txt b/benchmarks/diffusion/requirements.cuda.txt index 34a92c65d..676489f43 100644 --- a/benchmarks/diffusion/requirements.cuda.txt +++ b/benchmarks/diffusion/requirements.cuda.txt @@ -15,11 +15,11 @@ accelerate==0.34.2 # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/diffusion/requirements.in # diffusers -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -60,7 +60,7 @@ codefind==0.1.7 # via # -c .pin/../.pin/constraints-cuda-torch.txt # ptera -datasets==3.0.0 +datasets==3.0.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/diffusion/requirements.in @@ -106,7 +106,7 @@ hjson==3.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # argklass -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # accelerate @@ -190,7 +190,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # pandas # pyarrow # scipy @@ -209,7 +208,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -256,7 +255,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -270,7 +269,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -285,7 +284,7 @@ packaging==24.1 # datasets # huggingface-hub # transformers -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -343,7 +342,7 @@ requests==2.32.3 # diffusers # huggingface-hub # transformers -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -392,6 +391,7 @@ tqdm==4.66.5 transformers==4.44.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/diffusion/requirements.in triton==3.0.0 # via @@ -403,8 +403,9 @@ typing-extensions==4.12.2 # huggingface-hub # multidict # reactivex + # rich # torch -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas @@ -429,7 +430,7 @@ xxhash==3.5.0 
# via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -yarl==1.11.1 +yarl==1.13.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/dinov2/requirements.cuda.txt b/benchmarks/dinov2/requirements.cuda.txt index 9b3940ff2..bb0535894 100644 --- a/benchmarks/dinov2/requirements.cuda.txt +++ b/benchmarks/dinov2/requirements.cuda.txt @@ -109,7 +109,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # scipy # torchmetrics # torchvision @@ -126,7 +125,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -173,7 +172,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -188,7 +187,7 @@ omegaconf==2.3.0 # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/dinov2/requirements.in # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -232,7 +231,7 @@ reactivex==4.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -292,6 +291,7 @@ typing-extensions==4.12.2 # iopath # lightning-utilities # reactivex + # rich # submitit # torch varname==0.13.3 diff --git a/benchmarks/flops/requirements.cuda.txt b/benchmarks/flops/requirements.cuda.txt index e529152e3..fd027a8fb 100644 --- a/benchmarks/flops/requirements.cuda.txt +++ b/benchmarks/flops/requirements.cuda.txt @@ -95,7 +95,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # scipy # torchvision # xformers @@ -111,7 +110,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -158,7 +157,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -172,7 +171,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -204,7 +203,7 @@ reactivex==4.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -248,6 +247,7 @@ typing-extensions==4.12.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # reactivex + # rich # torch varname==0.13.3 # via diff --git a/benchmarks/geo_gnn/requirements-pre.cuda.txt b/benchmarks/geo_gnn/requirements-pre.cuda.txt index 6c76b0c91..f56bb4988 100644 --- a/benchmarks/geo_gnn/requirements-pre.cuda.txt +++ b/benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -62,7 +62,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # scipy # xformers nvidia-cublas-cu12==12.1.3.1 @@ -77,7 +76,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -120,7 +119,7 @@ 
nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -130,7 +129,7 @@ nvidia-nvtx-cu12==12.1.105 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax diff --git a/benchmarks/geo_gnn/requirements.cuda.txt b/benchmarks/geo_gnn/requirements.cuda.txt index 37b77babb..c4ffaa639 100644 --- a/benchmarks/geo_gnn/requirements.cuda.txt +++ b/benchmarks/geo_gnn/requirements.cuda.txt @@ -10,11 +10,11 @@ --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html --trusted-host pypi.ngc.nvidia.com -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch-geometric @@ -149,7 +149,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # pandas # rdkit # scipy @@ -169,7 +168,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -r benchmarks/geo_gnn/requirements-pre.cuda.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -225,7 +224,7 @@ nvidia-nccl-cu12==2.20.5 # -r benchmarks/geo_gnn/requirements-pre.cuda.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -241,7 +240,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt @@ -250,7 +249,7 @@ ovld==0.3.9 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements.in @@ -299,7 +298,7 @@ requests==2.32.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch-geometric -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -330,7 +329,7 @@ torch-cluster==1.6.3+pt24cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements.in -torch-geometric==2.6.0 +torch-geometric==2.6.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/geo_gnn/requirements.in @@ -357,8 +356,9 @@ typing-extensions==4.12.2 # -r benchmarks/geo_gnn/requirements-pre.cuda.txt # multidict # reactivex + # rich # torch -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas @@ -380,7 +380,7 @@ xformers==0.0.27.post2 # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt # -r benchmarks/geo_gnn/requirements-pre.cuda.txt -yarl==1.11.1 +yarl==1.13.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/huggingface/requirements.cuda.txt b/benchmarks/huggingface/requirements.cuda.txt index d4323b4af..45e68e325 100644 --- a/benchmarks/huggingface/requirements.cuda.txt +++ b/benchmarks/huggingface/requirements.cuda.txt @@ -51,7 +51,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tokenizers @@ -111,7 
+111,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # scipy # transformers # xformers @@ -127,7 +126,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -174,7 +173,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -188,7 +187,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -236,7 +235,7 @@ requests==2.32.3 # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub # transformers -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -274,6 +273,7 @@ tqdm==4.66.5 transformers==4.44.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/huggingface/requirements.in triton==3.0.0 # via @@ -284,6 +284,7 @@ typing-extensions==4.12.2 # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub # reactivex + # rich # torch urllib3==2.2.3 # via diff --git a/benchmarks/lightning/requirements.cuda.txt b/benchmarks/lightning/requirements.cuda.txt index db0745882..04b4eb4b3 100644 --- a/benchmarks/lightning/requirements.cuda.txt +++ b/benchmarks/lightning/requirements.cuda.txt @@ -10,11 +10,11 @@ --find-links https://data.pyg.org/whl/torch-2.4.0+cu121.html --trusted-host pypi.ngc.nvidia.com -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # fsspec @@ -141,7 +141,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # scipy # torchmetrics # torchvision @@ -158,7 +157,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -205,7 +204,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -219,7 +218,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -264,7 +263,7 @@ reactivex==4.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -321,6 +320,7 @@ typing-extensions==4.12.2 # multidict # pytorch-lightning # reactivex + # rich # torch varname==0.13.3 # via @@ -335,7 +335,7 @@ xformers==0.0.27.post2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -yarl==1.11.1 +yarl==1.13.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/llama/requirements.cuda.txt b/benchmarks/llama/requirements.cuda.txt index 1f52de100..0b3188482 100644 --- a/benchmarks/llama/requirements.cuda.txt +++ b/benchmarks/llama/requirements.cuda.txt @@ -10,11 +10,11 @@ --find-links 
https://data.pyg.org/whl/torch-2.4.0+cu121.html --trusted-host pypi.ngc.nvidia.com -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -51,7 +51,7 @@ codefind==0.1.7 # via # -c .pin/../.pin/constraints-cuda-torch.txt # ptera -datasets==3.0.0 +datasets==3.0.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llama/requirements.in @@ -76,7 +76,7 @@ filelock==3.16.1 # torch # transformers # triton -fire==0.6.0 +fire==0.7.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llama/requirements.in @@ -96,7 +96,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -169,7 +169,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # pandas # pyarrow # scipy @@ -187,7 +186,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -234,7 +233,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -248,7 +247,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -262,7 +261,7 @@ packaging==24.1 # datasets # huggingface-hub # transformers -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -311,7 +310,7 @@ requests==2.32.3 # datasets # huggingface-hub # transformers -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -332,7 +331,6 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens - # fire # python-dateutil sympy==1.13.3 # via @@ -361,6 +359,7 @@ tqdm==4.66.5 transformers==4.44.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/llama/requirements.in triton==3.0.0 # via @@ -372,8 +371,9 @@ typing-extensions==4.12.2 # huggingface-hub # multidict # reactivex + # rich # torch -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas @@ -398,7 +398,7 @@ xxhash==3.5.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -yarl==1.11.1 +yarl==1.13.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/llava/requirements.cuda.txt b/benchmarks/llava/requirements.cuda.txt index 91e94c4bf..5c6f9f64b 100644 --- a/benchmarks/llava/requirements.cuda.txt +++ b/benchmarks/llava/requirements.cuda.txt @@ -14,11 +14,11 @@ accelerate==0.34.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llava/requirements.in -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -55,7 +55,7 @@ codefind==0.1.7 # via # -c .pin/../.pin/constraints-cuda-torch.txt # ptera -datasets==3.0.0 +datasets==3.0.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llava/requirements.in @@ -92,7 +92,7 @@ 
giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # accelerate @@ -167,7 +167,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # pandas # pyarrow # scipy @@ -185,7 +184,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -232,7 +231,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -246,7 +245,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -261,7 +260,7 @@ packaging==24.1 # datasets # huggingface-hub # transformers -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -316,7 +315,7 @@ requests==2.32.3 # datasets # huggingface-hub # transformers -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -358,6 +357,7 @@ tqdm==4.66.5 transformers==4.44.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/llava/requirements.in triton==3.0.0 # via @@ -369,8 +369,9 @@ typing-extensions==4.12.2 # huggingface-hub # multidict # reactivex + # rich # torch -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas @@ -395,7 +396,7 @@ xxhash==3.5.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -yarl==1.11.1 +yarl==1.13.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/llm/requirements.cuda.txt b/benchmarks/llm/requirements.cuda.txt index 3abff6b50..db34901fd 100644 --- a/benchmarks/llm/requirements.cuda.txt +++ b/benchmarks/llm/requirements.cuda.txt @@ -14,11 +14,11 @@ accelerate==0.34.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llm/requirements.in -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -64,7 +64,7 @@ codefind==0.1.7 # via # -c .pin/../.pin/constraints-cuda-torch.txt # ptera -datasets==3.0.0 +datasets==3.0.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torchtune @@ -91,7 +91,7 @@ filelock==3.16.1 # torch # transformers # triton -fire==0.6.0 +fire==0.7.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/llm/requirements.txt @@ -115,7 +115,7 @@ hjson==3.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # argklass -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # accelerate @@ -199,7 +199,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # pandas # pyarrow # scipy @@ -218,7 +217,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -265,7 +264,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 
+nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -280,7 +279,7 @@ omegaconf==2.3.0 # -c .pin/../.pin/constraints-cuda-torch.txt # torchtune # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -295,7 +294,7 @@ packaging==24.1 # datasets # huggingface-hub # transformers -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -312,7 +311,7 @@ pyarrow==17.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -pycryptodomex==3.20.0 +pycryptodomex==3.21.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # blobfile @@ -353,7 +352,7 @@ requests==2.32.3 # huggingface-hub # tiktoken # transformers -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -376,7 +375,6 @@ six==1.16.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # asttokens - # fire # python-dateutil sympy==1.13.3 # via @@ -402,13 +400,16 @@ torch==2.4.0+cu121 # accelerate # fairscale # xformers -torchao==0.5.0+cu121 +torchao==0.3.1+cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/llm/requirements.in -torchtune==0.3.0+cu121 + # torchtune +torchtune==0.2.1+cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/llm/requirements.in tqdm==4.66.5 # via @@ -420,6 +421,7 @@ tqdm==4.66.5 transformers==4.44.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/llm/requirements.in triton==3.0.0 # via @@ -431,8 +433,9 @@ typing-extensions==4.12.2 # huggingface-hub # multidict # reactivex + # rich # torch -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas @@ -458,7 +461,7 @@ xxhash==3.5.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -yarl==1.11.1 +yarl==1.13.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/llm/requirements.in b/benchmarks/llm/requirements.in index 36832ad67..a3ab63c07 100644 --- a/benchmarks/llm/requirements.in +++ b/benchmarks/llm/requirements.in @@ -1,5 +1,5 @@ voir>=0.2.19,<0.3 -torchtune +torchtune<0.3.0 torch PyYAML argklass diff --git a/benchmarks/purejaxrl/requirements.cuda.txt b/benchmarks/purejaxrl/requirements.cuda.txt index d495163a9..3f09e47f7 100644 --- a/benchmarks/purejaxrl/requirements.cuda.txt +++ b/benchmarks/purejaxrl/requirements.cuda.txt @@ -32,7 +32,7 @@ argklass==1.4.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/purejaxrl/requirements.in -astroid==3.2.4 +astroid==3.3.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pylint @@ -61,7 +61,7 @@ charset-normalizer==3.3.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # requests -chex==0.1.86 +chex==0.1.87 # via # -c .pin/../.pin/constraints-cuda-torch.txt # distrax @@ -188,7 +188,7 @@ flax==0.9.0 # flashbax # gymnax # navix -fonttools==4.53.1 +fonttools==4.54.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # matplotlib @@ -218,7 +218,7 @@ glfw==2.7.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # mujoco -grpcio==1.66.1 +grpcio==1.66.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # brax @@ -411,7 +411,6 @@ numpy==1.26.4 # ml-dtypes # mujoco # navix - # opt-einsum # optax # orbax-checkpoint # pandas @@ -435,7 +434,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 
+nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -482,7 +481,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -496,7 +495,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -523,7 +522,7 @@ packaging==24.1 # pytest # setuptools-scm # tensorboardx -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # seaborn @@ -574,7 +573,7 @@ pygments==2.18.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # rich -pylint==3.2.7 +pylint==3.3.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # navix @@ -621,7 +620,7 @@ requests==2.32.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # wandb -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # flax @@ -643,7 +642,7 @@ seaborn==0.13.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # gymnax -sentry-sdk==2.14.0 +sentry-sdk==2.15.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # wandb @@ -683,13 +682,13 @@ tensorflow-probability==0.24.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # distrax -tensorstore==0.1.65 +tensorstore==0.1.66 # via # -c .pin/../.pin/constraints-cuda-torch.txt # flashbax # flax # orbax-checkpoint -tomli==2.0.1 +tomli==2.0.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # black @@ -732,13 +731,14 @@ typing-extensions==4.12.2 # navix # orbax-checkpoint # reactivex + # rich # torch # tyro tyro==0.8.11 # via # -c .pin/../.pin/constraints-cuda-torch.txt # navix -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas @@ -756,7 +756,7 @@ voir==0.2.19 # -c .pin/../.pin/constraints-cuda-torch.txt # -c .pin/../constraints/cuda.txt # -r benchmarks/purejaxrl/requirements.in -wandb==0.18.1 +wandb==0.18.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # navix diff --git a/benchmarks/recursiongfn/requirements.cuda.txt b/benchmarks/recursiongfn/requirements.cuda.txt index 2c852b71d..497f573ab 100644 --- a/benchmarks/recursiongfn/requirements.cuda.txt +++ b/benchmarks/recursiongfn/requirements.cuda.txt @@ -14,11 +14,11 @@ absl-py==2.1.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tensorboard -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch-geometric @@ -113,7 +113,7 @@ gpytorch==1.13 # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in # botorch -grpcio==1.66.1 +grpcio==1.66.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tensorboard @@ -199,7 +199,7 @@ multipledispatch==1.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # botorch -ndindex==1.8 +ndindex==1.9.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # blosc2 @@ -222,7 +222,6 @@ numpy==1.26.4 # jaxtyping # ml-dtypes # numexpr - # opt-einsum # pandas # pyarrow # pyro-ppl @@ -245,7 +244,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -292,7 +291,7 @@ nvidia-nccl-cu12==2.20.5 # -c 
.pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -307,7 +306,7 @@ omegaconf==2.3.0 # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -321,7 +320,7 @@ packaging==24.1 # -c .pin/../.pin/constraints-cuda-torch.txt # tables # tensorboard -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in @@ -400,7 +399,7 @@ requests==2.32.3 # -c .pin/../.pin/constraints-cuda-torch.txt # torch-geometric # wandb -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -420,7 +419,7 @@ scipy==1.14.1 # scikit-learn # torch-cluster # torch-sparse -sentry-sdk==2.14.0 +sentry-sdk==2.15.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # wandb @@ -447,7 +446,7 @@ tables==3.10.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in -tensorboard==2.17.1 +tensorboard==2.18.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in @@ -471,7 +470,7 @@ torch-cluster==1.6.3+pt24cu121 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in -torch-geometric==2.6.0 +torch-geometric==2.6.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in @@ -503,10 +502,11 @@ typing-extensions==4.12.2 # jaxtyping # multidict # reactivex + # rich # tables # torch # typeguard -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas @@ -524,7 +524,7 @@ voir==0.2.19 # -c .pin/../.pin/constraints-cuda-torch.txt # -c .pin/../constraints/cuda.txt # -r benchmarks/recursiongfn/requirements.in -wandb==0.18.1 +wandb==0.18.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/recursiongfn/requirements.in @@ -536,7 +536,7 @@ xformers==0.0.27.post2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r .pin/../constraints/extra/torch.cuda.txt -yarl==1.11.1 +yarl==1.13.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/rlhf/requirements.cuda.txt b/benchmarks/rlhf/requirements.cuda.txt index acc448aee..dee2ae27c 100644 --- a/benchmarks/rlhf/requirements.cuda.txt +++ b/benchmarks/rlhf/requirements.cuda.txt @@ -15,11 +15,11 @@ accelerate==0.34.2 # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/rlhf/requirements.in # trl -aiohappyeyeballs==2.4.0 +aiohappyeyeballs==2.4.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp -aiohttp==3.10.5 +aiohttp==3.10.8 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -56,7 +56,7 @@ codefind==0.1.7 # via # -c .pin/../.pin/constraints-cuda-torch.txt # ptera -datasets==3.0.0 +datasets==3.0.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/rlhf/requirements.in @@ -98,7 +98,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # accelerate @@ -172,7 +172,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # pandas # pyarrow # scipy @@ -191,7 +190,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 
+nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -238,7 +237,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -252,7 +251,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -267,7 +266,7 @@ packaging==24.1 # datasets # huggingface-hub # transformers -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets @@ -318,7 +317,7 @@ requests==2.32.3 # datasets # huggingface-hub # transformers -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tyro @@ -366,15 +365,17 @@ tqdm==4.66.5 transformers==4.44.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/rlhf/requirements.in # trl triton==3.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch -trl==0.11.0 +trl==0.10.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt + # -c .pin/../constraints/cuda.txt # -r benchmarks/rlhf/requirements.in typing-extensions==4.12.2 # via @@ -382,13 +383,14 @@ typing-extensions==4.12.2 # huggingface-hub # multidict # reactivex + # rich # torch # tyro tyro==0.8.11 # via # -c .pin/../.pin/constraints-cuda-torch.txt # trl -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas @@ -413,7 +415,7 @@ xxhash==3.5.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # datasets -yarl==1.11.1 +yarl==1.13.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # aiohttp diff --git a/benchmarks/timm/requirements.cuda.txt b/benchmarks/timm/requirements.cuda.txt index 1ac873600..b55428950 100644 --- a/benchmarks/timm/requirements.cuda.txt +++ b/benchmarks/timm/requirements.cuda.txt @@ -50,7 +50,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/timm/requirements.in @@ -109,7 +109,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # scipy # torchvision # xformers @@ -125,7 +124,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -172,7 +171,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -186,7 +185,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -228,7 +227,7 @@ requests==2.32.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -272,6 +271,7 @@ typing-extensions==4.12.2 # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub # reactivex + # rich # torch urllib3==2.2.3 # via diff --git a/benchmarks/torchatari/requirements.cuda.txt b/benchmarks/torchatari/requirements.cuda.txt index 0ed7f915c..1be36a969 100644 --- 
a/benchmarks/torchatari/requirements.cuda.txt +++ b/benchmarks/torchatari/requirements.cuda.txt @@ -78,7 +78,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -grpcio==1.66.1 +grpcio==1.66.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tensorboard @@ -161,7 +161,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # scipy # tensorboard # xformers @@ -177,7 +176,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -224,7 +223,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -238,7 +237,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -279,7 +278,7 @@ reactivex==4.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # tyro @@ -302,7 +301,7 @@ sympy==1.13.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch -tensorboard==2.17.1 +tensorboard==2.18.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/torchatari/requirements.in @@ -324,7 +323,7 @@ triton==3.0.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # torch -types-protobuf==5.27.0.20240907 +types-protobuf==5.28.0.20240924 # via # -c .pin/../.pin/constraints-cuda-torch.txt # envpool @@ -335,6 +334,7 @@ typing-extensions==4.12.2 # gymnasium # optree # reactivex + # rich # torch # tyro tyro==0.8.11 diff --git a/benchmarks/torchvision/requirements.cuda.txt b/benchmarks/torchvision/requirements.cuda.txt index 3b994c798..108cc0e69 100644 --- a/benchmarks/torchvision/requirements.cuda.txt +++ b/benchmarks/torchvision/requirements.cuda.txt @@ -95,7 +95,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # opt-einsum # scipy # torchvision # xformers @@ -111,7 +110,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -158,7 +157,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -172,7 +171,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -204,7 +203,7 @@ reactivex==4.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -248,6 +247,7 @@ typing-extensions==4.12.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # reactivex + # rich # torch varname==0.13.3 # via diff --git a/benchmarks/torchvision_ddp/requirements.cuda.txt b/benchmarks/torchvision_ddp/requirements.cuda.txt index 4e6a2a2b8..8572482df 100644 --- a/benchmarks/torchvision_ddp/requirements.cuda.txt +++ b/benchmarks/torchvision_ddp/requirements.cuda.txt @@ -95,7 +95,6 @@ numpy==1.26.4 # jax # jaxlib # ml-dtypes - # 
opt-einsum # scipy # torchvision # xformers @@ -111,7 +110,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -158,7 +157,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -172,7 +171,7 @@ omegaconf==2.3.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -204,7 +203,7 @@ reactivex==4.0.4 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -248,6 +247,7 @@ typing-extensions==4.12.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # reactivex + # rich # torch varname==0.13.3 # via diff --git a/benchmarks/vjepa/requirements.cuda.txt b/benchmarks/vjepa/requirements.cuda.txt index 2386bbd24..867c50b53 100644 --- a/benchmarks/vjepa/requirements.cuda.txt +++ b/benchmarks/vjepa/requirements.cuda.txt @@ -18,7 +18,7 @@ asttokens==2.4.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # giving -beartype==0.18.5 +beartype==0.19.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/vjepa/requirements.in @@ -71,7 +71,7 @@ giving==0.4.3 # -c .pin/../.pin/constraints-cuda-torch.txt # ptera # voir -huggingface-hub==0.25.0 +huggingface-hub==0.25.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # timm @@ -133,7 +133,6 @@ numpy==1.26.4 # jaxlib # ml-dtypes # opencv-python - # opt-einsum # pandas # scipy # torchvision @@ -151,7 +150,7 @@ nvidia-cuda-cupti-cu12==12.1.105 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-cuda-nvcc-cu12==12.6.68 +nvidia-cuda-nvcc-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -198,7 +197,7 @@ nvidia-nccl-cu12==2.20.5 # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin # torch -nvidia-nvjitlink-cu12==12.6.68 +nvidia-nvjitlink-cu12==12.6.77 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax-cuda12-plugin @@ -216,7 +215,7 @@ opencv-python==4.10.0.84 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/vjepa/requirements.in -opt-einsum==3.3.0 +opt-einsum==3.4.0 # via # -c .pin/../.pin/constraints-cuda-torch.txt # jax @@ -228,7 +227,7 @@ packaging==24.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub -pandas==2.2.2 +pandas==2.2.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # -r benchmarks/vjepa/requirements.in @@ -272,7 +271,7 @@ requests==2.32.3 # via # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub -rich==13.8.1 +rich==13.9.1 # via # -c .pin/../.pin/constraints-cuda-torch.txt # voir @@ -327,9 +326,10 @@ typing-extensions==4.12.2 # -c .pin/../.pin/constraints-cuda-torch.txt # huggingface-hub # reactivex + # rich # submitit # torch -tzdata==2024.1 +tzdata==2024.2 # via # -c .pin/../.pin/constraints-cuda-torch.txt # pandas diff --git a/constraints/cuda.txt b/constraints/cuda.txt index eb6bbcedf..49675b577 100644 --- a/constraints/cuda.txt +++ b/constraints/cuda.txt @@ -5,3 +5,14 @@ voir >= 0.2.19 torchcompat >= 1.0.0 gymnax >= 0.0.8 +trl<0.11.0 + +# latest torchtune is slower than before and cause failures +# next version of pytorch 
seems to work better +# so pending a new version of pytorch this is what we get +torchtune<0.3.0 + +# transformers added torchao support recently +# but only the most recent version we do not support +transformers<4.45.0 +torchao \ No newline at end of file diff --git a/milabench/_version.py b/milabench/_version.py index 4da614fc7..cdd2418dd 100644 --- a/milabench/_version.py +++ b/milabench/_version.py @@ -1,5 +1,5 @@ """This file is generated, do not modify""" -__tag__ = "v1.0.0_RC1-6-g4639c19d" -__commit__ = "4639c19d34c8b15349e2dd265a493374c84ed3aa" -__date__ = "2024-09-25 15:22:55 +0000" +__tag__ = "v1.0.0_RC1-9-g6d1e1140" +__commit__ = "6d1e114000cc4200ea307330032234db6696e40d" +__date__ = "2024-09-30 14:39:43 -0400" diff --git a/milabench/report.py b/milabench/report.py index aebcaf093..c54ed8ddd 100644 --- a/milabench/report.py +++ b/milabench/report.py @@ -342,6 +342,35 @@ def short_meta(out, meta): out.print(Table(stats)) +def to_latex(df): + from dataclasses import dataclass + from .system import option + + default_columns = [ + "ngpu", + "perf", + "sem%", + "std%" + ] + + @dataclass + class LatexTable: + output: str = option("latex.output", str, None) + columns: str = option("latex.columns", str, ",".join(default_columns)) + + options = LatexTable() + + columns = options.columns.split(",") + + df = df[columns] + + if options.output is not None: + with open(options.output, "w") as fp: + txt = df.to_latex(formatters=_formatters, escape=False) + txt = txt.replace("%", "\%").replace("_", "\_") + fp.write(txt) + + @error_guard({}) def make_report( summary: dict[str, Summary], @@ -376,7 +405,10 @@ def make_report( out.section("Breakdown") # Reorder columns - out.print(normalize_dataframe(df)) + normalized = normalize_dataframe(df) + out.print(normalized) + + to_latex(normalized) out.section("Scores") diff --git a/scripts/article/run_cuda.sh b/scripts/article/run_cuda.sh index 7328ca54b..ba4c1ae38 100644 --- a/scripts/article/run_cuda.sh +++ b/scripts/article/run_cuda.sh @@ -49,7 +49,7 @@ install_prepare() { # Install milabench's benchmarks in their venv # # pip install torch - # milabench pin --variant cuda --from-scratch $ARGS + milabench pin --variant cuda --from-scratch $ARGS milabench install --system $MILABENCH_WORDIR/system.yaml $ARGS which pip @@ -84,8 +84,19 @@ if [ "$MILABENCH_PREPARE" -eq 0 ]; then . $MILABENCH_WORDIR/env/bin/activate - milabench install --system $MILABENCH_WORDIR/system.yaml - # milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS + # pip install torch + # milabench pin --variant cuda --from-scratch + # rm -rf $MILABENCH_WORDIR/results/venv/ + # rm -rf $MILABENCH_WORDIR/results/extra + # milabench install --system $MILABENCH_WORDIR/system.yaml + milabench prepare --system $MILABENCH_WORDIR/system.yaml $ARGS + + ( + . $BENCHMARK_VENV/bin/activate + which pip + # pip uninstall torchao -y + # pip install torchao --no-input + ) # pip install torch # milabench pin --variant cuda --from-scratch
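
Note on the LaTeX output added in PATCH 7/7: the `to_latex` helper in `milabench/report.py` selects the configured columns, renders them with pandas, and escapes `%` and `_` before writing the file named by the `latex.output` option. Below is a minimal standalone sketch of that flow; the sample DataFrame, the "bench-example" row name, and the "report.tex" path are illustrative assumptions, and the patch's `_formatters` mapping and option machinery are omitted for brevity.

# Illustrative sketch only: mirrors the escaping done in milabench/report.py::to_latex,
# without milabench's option handling or its _formatters mapping.
import pandas as pd

# Hypothetical report rows; the benchmark name and values are made up.
df = pd.DataFrame(
    {"ngpu": [8], "perf": [1234.5], "sem%": [0.4], "std%": [2.1]},
    index=["bench-example"],
)

columns = ["ngpu", "perf", "sem%", "std%"]  # default latex.columns in the patch
txt = df[columns].to_latex(escape=False)
txt = txt.replace("%", r"\%").replace("_", r"\_")  # escape LaTeX special characters

with open("report.tex", "w") as fp:  # stand-in for the latex.output option
    fp.write(txt)

Escaping after rendering, as the patch does, keeps headers such as "sem%" intact in the source DataFrame while still producing valid LaTeX in the written table.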