
Commit 18b3653
Moved more checks to config manager plus more stylistic changes
ghstack-source-id: 32959badac2081ff6c0a91d54a7950adad7de720
Pull Request resolved: #449
awgu committed Jul 10, 2024
1 parent c181683 commit 18b3653
Showing 2 changed files with 82 additions and 55 deletions.
31 changes: 24 additions & 7 deletions torchtitan/config_manager.py
@@ -533,20 +533,37 @@ def _args_to_two_level_dict(self, args: argparse.Namespace) -> defaultdict:
return args_dict

def _validate_config(self) -> None:
# TODO: Add more mandatory validations
assert self.model.name
assert self.model.flavor
assert self.model.tokenizer_path

pp_split_mode = self.experimental.pipeline_parallel_split_mode
if pp_split_mode not in ("manual", "tracer"):
raise ValueError(
f"Invalid split mode: {self.experimental.pipeline_parallel_split_mode}"
)
if pp_split_mode == "tracer" and self.model.norm_type == "fused_rmsnorm":
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
# invocation stride in strict mode from `if dy.stride(-1) != 1:` in
# fused_rmsnorm
raise NotImplementedError(
"fused_rmsnorm is not compatible with Pipeline Tracer yet. Please use rmsnorm or layernorm."
)

ac_config = self.activation_checkpoint
assert (
ac_config.mode in ("full", "selective", "none")
), f"Unsupported AC mode: {ac_config.mode}"
if ac_config.mode not in ("full", "selective", "none"):
raise ValueError(f"Invalid AC mode: {ac_config.mode}")
if ac_config.mode == "selective" and ac_config.selective_ac_option.isdigit():
ac_freq = int(ac_config.selective_ac_option)
assert (
ac_freq > 0
), f"Selective layer AC expects a positive int as selective_ac_option but got {ac_freq}"
if ac_freq <= 0:
raise ValueError(
f"Selective layer AC expects a positive int as selective_ac_option but got {ac_freq}"
)

if self.training.compile and self.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm is not compatible with torch.compile yet. Please use rmsnorm or layernorm."
)

def parse_args_from_command_line(
self, args_list
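For reference, here is a minimal standalone sketch (not part of this commit) of the validation pattern the new _validate_config code follows: collect the checks in the config layer and raise ValueError/NotImplementedError instead of relying on assert. The SketchConfig class and its flat field names are invented for illustration; they stand in for the nested JobConfig sections used above.

from dataclasses import dataclass


# SketchConfig is an invented, flattened stand-in for the nested JobConfig sections
# referenced above (model, experimental, activation_checkpoint, training).
@dataclass
class SketchConfig:
    pipeline_parallel_split_mode: str = "manual"
    ac_mode: str = "selective"
    selective_ac_option: str = "2"
    compile: bool = False
    norm_type: str = "rmsnorm"

    def validate(self) -> None:
        # Unknown pipeline split modes are rejected up front, mirroring _validate_config.
        if self.pipeline_parallel_split_mode not in ("manual", "tracer"):
            raise ValueError(f"Invalid split mode: {self.pipeline_parallel_split_mode}")
        # ValueError instead of assert, so the check still runs under `python -O`.
        if self.ac_mode not in ("full", "selective", "none"):
            raise ValueError(f"Invalid AC mode: {self.ac_mode}")
        if self.ac_mode == "selective" and self.selective_ac_option.isdigit():
            ac_freq = int(self.selective_ac_option)
            if ac_freq <= 0:
                raise ValueError(
                    f"Selective layer AC expects a positive int as selective_ac_option "
                    f"but got {ac_freq}"
                )
        # Known-bad feature combinations fail at config time rather than mid-setup.
        if self.compile and self.norm_type == "fused_rmsnorm":
            raise NotImplementedError(
                "fused_rmsnorm is not compatible with torch.compile yet. "
                "Please use rmsnorm or layernorm."
            )


try:
    SketchConfig(ac_mode="bogus").validate()
except ValueError as e:
    print(f"rejected at config time: {e}")

Raising typed exceptions also surfaces an actionable message before any model construction or distributed initialization happens, which is the point of moving these checks into the config manager.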
106 changes: 58 additions & 48 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -9,9 +9,10 @@

import copy
from collections import defaultdict
from typing import Dict, Tuple
from typing import Tuple, Union

import torch
import torch.nn as nn

from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._tensor import Replicate, Shard
@@ -26,10 +27,16 @@
RowwiseParallel,
SequenceParallel,
)
from torch.distributed import DeviceMesh

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging_utils import logger
from torchtitan.parallelisms import ParallelDims
from torchtitan.parallelisms.pipelining_utils import stage_ids_this_rank
from torchtitan.models.llama.model import ModelArgs


DeviceType = Union[int, str, torch.device]

# for selective AC
no_recompute_list = {
@@ -107,23 +114,27 @@ def get_tp_parallel_strategy(


def pipeline_llama(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs
):
if job_config.experimental.pipeline_parallel_split_mode == "manual":
split_mode = job_config.experimental.pipeline_parallel_split_mode
if split_mode == "manual":
return pipeline_llama_manual(
model, world_mesh, parallel_dims, job_config, device, model_config
)
elif job_config.experimental.pipeline_parallel_split_mode == "tracer":
elif split_mode == "tracer":
return pipeline_llama_tracer(
model, world_mesh, parallel_dims, job_config, device, model_config
)
else:
raise NotImplementedError(
f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode"
)
raise NotImplementedError(f"{split_mode} is not a valid split mode")


def _llama_trace_input(job_config, model_config, device="meta"):
def _llama_trace_input(job_config: JobConfig, model_config: ModelArgs, device="meta"):
"""Get meta tensors with the right input shapes used for tracing"""
tokens_shape = (job_config.training.batch_size, job_config.training.seq_len)
tokens = torch.randint(
@@ -135,18 +146,18 @@ def _llama_trace_input(job_config, model_config, device="meta"):
def _mixed_precision_dtype(
job_config: JobConfig, parallel_dims, default: torch.dtype = torch.float32
) -> torch.dtype:
"""Get the mixed precision dtype if fsdp is enabled, otherwise return the default"""
"""Get the mixed precision dtype if FSDP is enabled, otherwise return the default"""
mp_arg = job_config.training.mixed_precision_param
return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default


def pipeline_llama_manual(
whole_model,
world_mesh,
parallel_dims,
whole_model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device,
model_config: Dict,
device: DeviceType,
model_config: ModelArgs,
):
"""
This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
@@ -244,19 +255,17 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal


def pipeline_llama_tracer(
model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
if job_config.model.norm_type == "fused_rmsnorm":
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
# coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
raise NotImplementedError(
"fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
)

if _mixed_precision_dtype(job_config, parallel_dims) == torch.bfloat16:
raise NotImplementedError(
"pipeline tracer doesn't work with fsdp mixed precision currently. "
"To work around, edit fsdp mixed precision config to use fp32."
"Pipeline tracer does not work with FSDP mixed precision yet. "
"To work around, set mixed_precision_param to float32."
)

pp_mesh = world_mesh["pp"]
@@ -292,10 +301,13 @@ def pipeline_llama_tracer(
return (stages, models)


def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply tensor parallelism.
"""
def apply_tp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
"""Apply tensor parallelism."""

tp_mesh = world_mesh["tp"]
# Parallel styles for transformer block linear weights may be different for
@@ -374,10 +386,8 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
return model


def apply_ac(model, job_config: JobConfig):
"""
Apply activation checkpointing to the model.
"""
def apply_ac(model: nn.Module, job_config: JobConfig):
"""Apply activation checkpointing to the model."""

ac_config = job_config.activation_checkpoint

@@ -389,18 +399,10 @@ def apply_ac(model, job_config: JobConfig):
return model


def apply_compile(model, job_config: JobConfig):
"""
Apply torch.compile to the model.
"""

if job_config.model.norm_type == "fused_rmsnorm":
raise NotImplementedError(
"fused_rmsnorm not yet compatible with torch.compile. Please use layernorm or rmsnorm."
)
def apply_compile(model: nn.Module, job_config: JobConfig):
"""Apply torch.compile to each transformer block."""

for layer_id, transformer_block in model.layers.named_children():
# turn on per-transformer block compile after AC wrapping and before FSDP
# TODO: dynamic shapes have some issues, so we turn them off for now.
# TODO: inline inbuilt nn modules does not work yet, enable it to accelerate
# compile time.
@@ -412,10 +414,13 @@ def apply_compile(model, job_config: JobConfig):
return model

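The loop body of apply_compile is collapsed in this view. As a rough, self-contained sketch of the per-block pattern the comment above describes (compile each transformer block individually, after AC wrapping and before FSDP), one might write something like the following. TinyBlock and TinyModel are invented toy modules, and this is not necessarily the exact body of apply_compile; only torch.compile and nn.Module.register_module are real PyTorch APIs.

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    # Invented toy stand-in for a transformer block; just enough to be compiled.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


class TinyModel(nn.Module):
    # model.layers is a ModuleDict of blocks, so named_children() iterates (name, block)
    # pairs the same way the loop above does.
    def __init__(self, n_layers: int = 2, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): TinyBlock(dim) for i in range(n_layers)})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = block(x)
        return x


def apply_per_block_compile(model: TinyModel) -> TinyModel:
    # Compile each block separately (dynamic=False, matching the "dynamic shapes off" TODO)
    # and register the compiled wrapper back under the same name.
    for layer_id, block in model.layers.named_children():
        compiled_block = torch.compile(block, dynamic=False)
        model.layers.register_module(layer_id, compiled_block)
    return model


model = apply_per_block_compile(TinyModel())
print(model(torch.randn(4, 16)).shape)  # torch.Size([4, 16])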

def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
"""
Apply data parallelism (FSDP2) to the model.
"""
def apply_dp(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
"""Apply data parallelism (FSDP2) to the model."""

dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
@@ -448,7 +453,12 @@ def apply_dp(model, world_mesh, parallel_dims, job_config: JobConfig):
return model


def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
def parallelize_llama(
model: nn.Module,
world_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
"""
Apply tensor parallelism, activation checkpointing, torch.compile, and data
parallelism to the model.
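The body of parallelize_llama is collapsed above. Going by its docstring and the ordering comment in apply_compile ("after AC wrapping and before FSDP"), a schematic sketch of the composition order might look like the following, with the real apply_tp/apply_ac/apply_compile/apply_dp replaced by print stubs and a simplified, invented stand-in for ParallelDims.

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class SketchParallelDims:
    # Invented stand-in exposing only the two flags this sketch needs.
    tp_enabled: bool = True
    dp_enabled: bool = True


def parallelize_sketch(
    model: nn.Module,
    parallel_dims: SketchParallelDims,
    ac_mode: str = "selective",
    compile_model: bool = True,
) -> nn.Module:
    # Ordering in the sketch: TP reshards parameters first, AC wraps blocks next,
    # per-block torch.compile runs after AC, and FSDP wraps everything last.
    if parallel_dims.tp_enabled:
        print("apply_tp")       # stand-in for apply_tp(model, world_mesh, parallel_dims, job_config)
    if ac_mode != "none":
        print("apply_ac")       # stand-in for apply_ac(model, job_config)
    if compile_model:
        print("apply_compile")  # stand-in for apply_compile(model, job_config)
    if parallel_dims.dp_enabled:
        print("apply_dp")       # stand-in for apply_dp(model, world_mesh, parallel_dims, job_config)
    return model


parallelize_sketch(nn.Linear(8, 8), SketchParallelDims())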
