diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 54f84ae7..1639bb3a 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -533,12 +533,23 @@ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
         return args_dict
 
     def _validate_config(self) -> None:
-        # TODO: Add more mandatory validations
         assert self.model.name
         assert self.model.flavor
         assert self.model.tokenizer_path
 
-        ac_config = self.activation_checkpoint
+        pp_split_mode = self.experimental.pipeline_parallel_split_mode
+        if pp_split_mode not in ("manual", "tracer"):
+            raise ValueError(
+                f"Invalid split mode: {self.experimental.pipeline_parallel_split_mode}"
+            )
+        if pp_split_mode == "tracer" and self.model.norm_type == "fused_rmsnorm":
+            # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
+            # invocation stride in strict mode from `if dy.stride(-1) != 1:` in
+            # fused_rmsnorm
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
+            )
+
+        ac_config = self.activation_checkpoint
         if ac_config.mode not in ("full", "selective", "none"):
             raise ValueError(f"Invalid AC mode: {ac_config.mode}")
 
@@ -549,6 +560,11 @@ def _validate_config(self) -> None:
                     f"Selective layer AC expects a positive int as selective_ac_option but got {ac_freq}"
                 )
 
+        if self.training.compile and self.model.norm_type == "fused_rmsnorm":
+            raise NotImplementedError(
+                "fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
+            )
+
     def parse_args_from_command_line(
         self, args_list
     ) -> Tuple[argparse.Namespace, argparse.Namespace]:
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 2554d01c..96c1b3b0 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -9,9 +9,11 @@
 import copy
 from collections import defaultdict
-from typing import Dict, Tuple
+from typing import Tuple, Union
 
 import torch
+import torch.nn as nn
+from torch.distributed import DeviceMesh
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 
@@ -29,8 +31,13 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging_utils import logger
+from torchtitan.models.llama.model import ModelArgs
+from torchtitan.parallelisms import ParallelDims
 from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
 
+
+DeviceType = Union[int, str, torch.device]
+
 
 # for selective AC
 no_recompute_list = {
     torch.ops.aten.mm.default,
@@ -107,23 +114,27 @@ def get_tp_parallel_strategy(
 
 
 def pipeline_llama(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
-    if job_config.experimental.pipeline_parallel_split_mode == "manual":
+    split_mode = job_config.experimental.pipeline_parallel_split_mode
+    if split_mode == "manual":
         return pipeline_llama_manual(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
-    elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
+    elif split_mode == "tracer":
         return pipeline_llama_tracer(
             model, world_mesh, parallel_dims, job_config, device, model_config
         )
     else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
-        )
+        raise NotImplementedError(f"{split_mode} is not a valid split mode")
 
 
-def _llama_trace_input(job_config, model_config, device="meta"):
+def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
     """Get meta tensors with the right input shapes used for tracing"""
     tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
     tokens = torch.randint(
@@ -135,18 +146,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
 def _mixed_precision_dtype(
     job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
 ) -> torch.dtype:
-    """Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
+    """Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
     mp_arg = job_config.training.mixed_precision_param
     return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
 
 
 def pipeline_llama_manual(
-    whole_model,
-    world_mesh,
-    parallel_dims,
+    whole_model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
     job_config: JobConfig,
-    device,
-    model_config: Dict,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
     """
     This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage.
@@ -244,19 +255,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
 
 def pipeline_llama_tracer(
-    model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+    device: DeviceType,
+    model_config: ModelArgs,
 ):
-    if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
-        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
         raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
-        )
-
-    if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
-        raise NotImplementedError(
-            "pipeline tracer doesn't work with fsdp mixed precision currently. "
-            "To work around, edit fsdp mixed precision config to use fp32."
+            "Pipeline tracer does not work with FSDP mixed precision yet. "
+            "To work around, set mixed_precision_param to float32."
         )
 
     pp_mesh = world_mesh["pp"]
@@ -292,10 +301,13 @@ def pipeline_llama_tracer(
     return (stages, models)
 
 
-def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply tensor parallelism.
-    """
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    """Apply tensor parallelism."""
     tp_mesh = world_mesh["tp"]
 
     # Parallel styles for transformer block linear weights may be different for
@@ -374,10 +386,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model
 
 
-def apply_ac(model, job_config: JobConfig):
-    """
-    Apply activation checkpointing to the model.
-    """
+def apply_ac(model: nn.Module, job_config: JobConfig):
+    """Apply activation checkpointing to the model."""
     ac_config = job_config.activation_checkpoint
 
@@ -389,18 +399,10 @@ def apply_ac(model, job_config: JobConfig):
     return model
 
 
-def apply_compile(model, job_config: JobConfig):
-    """
-    Apply torch.compile to the model.
-    """
-
-    if job_config.model.norm_type == "fused_rmsnorm":
-        raise NotImplementedError(
-            "fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
-        )
+def apply_compile(model: nn.Module, job_config: JobConfig):
+    """Apply torch.compile to each transformer block."""
 
     for layer_id, transformer_block in model.layers.named_children():
-        # turn on per-transformer block compile after AC wrapping and before FSDP
         # TODO: dynamic shape have some issues so we turn it off for now.
         # TODO: inline inbuilt nn modules does not work yet, enable it to accelarate
         # compile time.
@@ -412,10 +414,13 @@ def apply_compile(model, job_config: JobConfig):
     return model
 
 
-def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
-    """
-    Apply data parallelism (FSDP2) to the model.
-    """
+def apply_dp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    """Apply data parallelism (FSDP2) to the model."""
     dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
 
@@ -448,7 +453,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
     return model
 
 
-def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
     """
     Apply tensor parallelism, activation checkpointing, torch.compile, and data
     parallelism to the model.