
Commit

Added type annotations and more stylistic changes
ghstack-source-id: 251bbd533da3bf7722358574f0ad232252be2204
Pull Request resolved: #449
awgu committed Jul 10, 2024
1 parent 8e6952c commit 70478e7
Showing 1 changed file with 69 additions and 42 deletions.
111 changes: 69 additions & 42 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -9,9 +9,11 @@

import copy
from collections import defaultdict
from typing import Dict, Tuple
from typing import Tuple, TYPE_CHECKING, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
@@ -29,8 +31,15 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank

if TYPE_CHECKING:
from torchtitan.parallelisms import ParallelDims


DeviceType = Union[int, str, torch.device]
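A brief aside, not part of the commit: the new TYPE_CHECKING block keeps the ParallelDims import out of the runtime import graph (the usual way to avoid a circular import), which is why the annotations below quote "ParallelDims" as a string. A minimal sketch of the pattern, with hypothetical module names:

# Sketch: imports under TYPE_CHECKING are visible only to type checkers, so any
# runtime annotation must refer to the name as a string ("forward reference").
from typing import TYPE_CHECKING, Union

import torch

if TYPE_CHECKING:
    from mypackage.parallelisms import ParallelDims  # hypothetical path

DeviceType = Union[int, str, torch.device]

def setup(parallel_dims: "ParallelDims", device: DeviceType) -> None:
    ...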

# for selective AC
no_recompute_list = {
torch.ops.aten.mm.default,
@@ -125,23 +134,30 @@ def get_tp_parallel_strategy(


def pipeline_llama(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
if job_config.experimental.pipeline_parallel_split_mode == "manual":
split_mode = job_config.experimental.pipeline_parallel_split_mode
valid_split_modes = ("manual", "tracer")
if split_mode not in valid_split_modes:
raise ValueError(
f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
)
if split_mode == "manual":
return pipeline_llama_manual(
model, world_mesh, parallel_dims, job_config, device, model_config
)
elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
elif split_mode == "tracer":
return pipeline_llama_tracer(
model, world_mesh, parallel_dims, job_config, device, model_config
)
else:
raise NotImplementedError(
f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
)


def _llama_trace_input(job_config, model_config, device="meta"):
def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
"""Get meta tensors with the right input shapes used for tracing"""
tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
tokens = torch.randint(
@@ -153,18 +169,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
def _mixed_precision_dtype(
job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
"""Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
"""Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
mp_arg = job_config.training.mixed_precision_param
return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


def pipeline_llama_manual(
whole_model,
world_mesh,
parallel_dims,
whole_model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
device,
model_config: Dict,
device: DeviceType,
model_config: ModelArgs,
):
"""
This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
@@ -262,19 +278,24 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):


def pipeline_llama_tracer(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
if job_config.model.norm_type == "fused_rmsnorm":
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
# coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
# invocation stride in strict mode from `if dy.stride(-1) != 1:` in
# fused_rmsnorm
raise NotImplementedError(
"fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
"fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
)

if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
raise NotImplementedError(
"pipeline tracer doesn't work with fsdp mixed precision currently. "
"To work around, edit fsdp mixed precision config to use fp32."
"Pipeline tracer does not work with FSDP mixed precision yet. "
"To work around, set mixed_precision_param to float32."
)

pp_mesh = world_mesh["pp"]
@@ -310,10 +331,13 @@ def pipeline_llama_tracer(
return (stages, models)


def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply tensor parallelism.
"""
def apply_tp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
"""Apply tensor parallelism."""

tp_mesh = world_mesh["tp"]
# Parallel styles for transformer block linear weights may be different for
@@ -392,10 +416,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
return model
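For context, not part of the commit: apply_tp builds on the DTensor tensor-parallel API; the actual plan lives in the collapsed portion of the function. A minimal sketch of a per-block plan, assuming the stock ColwiseParallel/RowwiseParallel styles and hypothetical submodule names, with `model` and `tp_mesh` as in the function above:

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Hypothetical plan: shard attention input projections column-wise and the
# output projection row-wise across the "tp" mesh dimension.
layer_plan = {
    "attention.wq": ColwiseParallel(),
    "attention.wk": ColwiseParallel(),
    "attention.wv": ColwiseParallel(),
    "attention.wo": RowwiseParallel(),
}
for transformer_block in model.layers.children():
    parallelize_module(transformer_block, tp_mesh, layer_plan)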


def apply_ac(model, job_config: JobConfig):
"""
Apply activation checkpointing to the model.
"""
def apply_ac(model: nn.Module, job_config: JobConfig):
"""Apply activation checkpointing to the model."""

ac_config = job_config.activation_checkpoint

@@ -407,18 +429,15 @@ def apply_ac(model, job_config: JobConfig):
return model
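An aside, not part of the commit: the collapsed body of apply_ac wraps each transformer block. A minimal sketch of full activation checkpointing under that assumption, using torch's checkpoint_wrapper:

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)

# Sketch: wrap every transformer block so its activations are recomputed during
# backward instead of being kept alive through the forward pass.
for layer_id, transformer_block in model.layers.named_children():
    model.layers.register_module(layer_id, checkpoint_wrapper(transformer_block))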


def apply_compile(model, job_config: JobConfig):
"""
Apply torch.compile to the model.
"""
def apply_compile(model: nn.Module, job_config: JobConfig):
"""Apply torch.compile to each transformer block."""

if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
"fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
)

for layer_id, transformer_block in model.layers.named_children():
# turn on per-transformer block compile after AC wrapping and before FSDP
# TODO: dynamic shapes have some issues so we turn it off for now.
# TODO: inline inbuilt nn modules does not work yet, enable it to accelerate
# compile time.
@@ -430,10 +449,13 @@ def apply_compile(model, job_config: JobConfig):
return model
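Not part of the commit: the collapsed loop body compiles one block at a time rather than the whole model, matching the comments above. A minimal sketch:

# Sketch: per-block torch.compile keeps AC wrapping outside the compiled region
# and lets FSDP still be applied per block afterwards; dynamic shapes stay
# disabled per the TODO above.
for layer_id, transformer_block in model.layers.named_children():
    compiled_block = torch.compile(transformer_block, dynamic=False)
    model.layers.register_module(layer_id, compiled_block)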


def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply data parallelism (FSDP2) to the model.
"""
def apply_dp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
"""Apply data parallelism (FSDP2) to the model."""

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -466,7 +488,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
return model
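For context, not part of the commit: the collapsed body of apply_dp uses the FSDP2 fully_shard API imported at the top of the file. A minimal sketch under that assumption, with the bfloat16/float32 dtypes as hypothetical choices rather than the configured ones:

# Sketch: shard each transformer block, then the root module last so that the
# remaining parameters (embeddings, norm, output) form the outermost group.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # hypothetical; normally read from job_config
    reduce_dtype=torch.float32,
)
for transformer_block in model.layers.children():
    fully_shard(transformer_block, mesh=dp_mesh, mp_policy=mp_policy)
fully_shard(model, mesh=dp_mesh, mp_policy=mp_policy)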


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
def parallelize_llama(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: "ParallelDims",
job_config: JobConfig,
):
"""
Apply tensor parallelism, activation checkpointing, torch.compile, and data
parallelism to the model.
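A closing aside, not part of the commit: parallelize_llama composes the helpers above in the order the docstring lists them, tensor parallelism first, then activation checkpointing, then torch.compile, then data parallelism. A minimal sketch of that order; config field names beyond those visible in the diff are assumptions:

# Sketch of the body of parallelize_llama (fragment, inside the function):
if parallel_dims.tp_enabled:  # assumed flag, mirroring dp_enabled above
    model = apply_tp(model, world_mesh, parallel_dims, job_config)
if job_config.activation_checkpoint.mode != "none":  # assumed field name
    model = apply_ac(model, job_config)
if job_config.training.compile:  # assumed field name
    model = apply_compile(model, job_config)
if parallel_dims.dp_enabled:
    model = apply_dp(model, world_mesh, parallel_dims, job_config)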
