[BE][4/n] split pipeline_llama into a separate file
ghstack-source-id: 5ebb4adf3152f413fa33a923c272c9aa3ce1f775
Pull Request resolved: #499
tianyu-l committed Aug 5, 2024
1 parent c44cca0 commit 8849580
Showing 9 changed files with 474 additions and 462 deletions.
README.md (6 changes: 4 additions & 2 deletions)
@@ -22,8 +22,10 @@ Our guiding principles when building `torchtitan`:

You may want to see how the model is defined or how parallelism techniques are applied. For a guided tour, see these files first:
* [train.py](https://github.com/pytorch/torchtitan/blob/main/train.py) - the main training loop and high-level setup code
* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data / Tensor / Pipeline Parallelisms to the model
* [torchtitan/parallelisms/parallelize_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/parallelize_llama.py) - helpers for applying Data Parallel, Tensor Parallel, activation checkpointing, and `torch.compile` to the model
* [torchtitan/parallelisms/pipeline_llama.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
* [torchtitan/checkpoint.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
* [torchtitan/float8.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/float8.py) - utils for applying Float8 techniques
* [torchtitan/models/llama/model.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)

## Pre-Release Updates:
@@ -48,7 +50,7 @@ We report our [Performance](docs/performance.md) verified on 64 A100 GPUs
### Coming soon

1. Async checkpointing
2. FP8 support
2. Float8 support
3. Context Parallel
4. 3D Pipeline Parallel
5. `torch.compile` support
estimation.py (8 changes: 4 additions & 4 deletions)
@@ -16,7 +16,7 @@

from torchtitan.config_manager import JobConfig
from torchtitan.datasets import build_tokenizer
from torchtitan.float8_linear import Float8Handler
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
@@ -124,9 +124,9 @@ def loss_fn(pred, labels):
with torch.device("meta"):
whole_model = model_cls.from_model_args(model_config)

# a no-op hander if fp8 is not enabled
# a no-op hander if float8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear base on fp8 config
# swap to Float8Linear based on float8 configs
float8_handler.convert_to_float8_training(whole_model)

# apply PT-D DP/TP parallelisms and activation checkpointing
@@ -190,7 +190,7 @@ def loss_fn(pred, labels):
lr_schedulers.step()
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# it issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_fp8_dynamic_scale_for_fsdp(model)
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
optimizers.zero_grad()
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
torchtitan/float8_linear.py → torchtitan/float8.py (10 changes: 5 additions & 5 deletions)
@@ -21,7 +21,7 @@
from torchtitan.parallelisms import ParallelDims


def is_sm90_or_later():
def _is_sm90_or_later():
# Float8 is only supported on H100+ GPUs
return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)

@@ -33,7 +33,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
float8_config = job_config.float8
if not float8_config.enable_float8_linear:
return
if not is_sm90_or_later():
if not _is_sm90_or_later():
logger.warning(
"Failed to swap to Float8Linear because SM90 or later is not available",
)
@@ -42,7 +42,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
except ImportError as e:
raise ImportError(
"torchao is not installed. Please install it to use fp8 linear layers."
"torchao is not installed. Please install it to use float8 linear layers."
) from e

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
@@ -64,7 +64,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

self.enabled = True

# for precompute_fp8_dynamic_scale_for_fsdp
# for precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
@@ -103,7 +103,7 @@ def convert_to_float8_training(self, model: nn.Module):
f"{self.config.enable_fsdp_float8_all_gather}"
)

def precompute_fp8_dynamic_scale_for_fsdp(self, model: nn.Module):
def precompute_float8_dynamic_scale_for_fsdp(self, model: nn.Module):
if not self.enabled:
return

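To make the renamed API easier to follow, here is a minimal usage sketch of `Float8Handler` after this change, mirroring the calls visible in the estimation.py hunks above. The wrapper function and its arguments are illustrative assumptions, not part of this commit.

```python
# Minimal sketch (assumed wiring, not part of this commit): construct the
# handler, swap eligible nn.Linear modules to Float8Linear, and return the
# handler so the training loop can call the renamed FSDP2 hook.
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.float8 import Float8Handler
from torchtitan.parallelisms import ParallelDims


def apply_float8(
    job_config: JobConfig, parallel_dims: ParallelDims, model: nn.Module
) -> Float8Handler:
    # a no-op handler if float8 is not enabled in the job config
    float8_handler = Float8Handler(job_config, parallel_dims)
    # swap to Float8Linear based on the float8 configs
    float8_handler.convert_to_float8_training(model)
    return float8_handler


# After each optimizer step, when FSDP2 float8 all-gather is enabled:
#   float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
```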
torchtitan/metrics.py (6 changes: 3 additions & 3 deletions)
@@ -127,16 +127,16 @@ def _get_metrics_rank(parallel_dims: ParallelDims) -> int:


def build_metric_logger(
config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
):
"""
parallel_dims is used to determine the rank to log metrics from if 'tb_config.rank_0_only=True'.
In that case, `_get_metrics_rank` will be used to calculate which rank acts as 'rank 0'. This is
intended to allow logging from the 0th rank within the last pipeline stage group, in case pipeline
parallelism is enabled, without forcing logging from all ranks to capture loss information.
"""
dump_dir = config.job.dump_folder
tb_config = config.metrics
dump_dir = job_config.job.dump_folder
tb_config = job_config.metrics
save_tb_folder = tb_config.save_tb_folder
# since we don't have run id, use current minute as the identifier
datetime_str = datetime.now().strftime("%Y%m%d-%H%M")
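As a quick illustration of the renamed parameter, here is a hedged sketch of calling `build_metric_logger`; the surrounding objects are assumed to come from the training setup and are not shown in this diff.

```python
# Minimal sketch (assumed call site): the first argument is now named
# job_config, matching the JobConfig type it already carried.
from typing import Optional

from torchtitan.config_manager import JobConfig
from torchtitan.metrics import build_metric_logger
from torchtitan.parallelisms import ParallelDims


def make_metric_logger(
    job_config: JobConfig, parallel_dims: ParallelDims, tag: Optional[str] = None
):
    # With rank_0_only enabled in the metrics config, _get_metrics_rank picks
    # the 0th rank of the last pipeline stage group as the logging rank.
    return build_metric_logger(job_config, parallel_dims, tag=tag)
```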
torchtitan/parallelisms/__init__.py (67 changes: 3 additions & 64 deletions)
@@ -4,12 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import init_device_mesh
from torchtitan.logging import logger
from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.parallelize_llama import parallelize_llama
from torchtitan.parallelisms.pipeline_llama import pipeline_llama
from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule


@@ -28,62 +26,3 @@
"llama2": pipeline_llama,
"llama3": pipeline_llama,
}


@dataclass
class ParallelDims:
    dp: int
    tp: int
    pp: int
    world_size: int
    enable_loss_parallel: bool
    dp_type: str

    def __post_init__(self):
        self.dp_type = self.dp_type.lower()
        self._validate()

    def _validate(self):
        dp, tp, pp = self.dp, self.tp, self.pp
        if dp == -1:
            self.dp = dp = self.world_size // (tp * pp)
        assert dp >= 1, dp
        assert tp >= 1, tp
        assert pp >= 1, pp
        assert (
            dp * tp * pp == self.world_size
        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
        assert self.dp_type in ("fsdp", "ddp")

    def build_mesh(self, device_type):
        dims = []
        names = []
        for d, name in zip(
            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
        ):
            if d > 1:
                dims.append(d)
                names.append(name)
        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
        names = tuple(names)
        return init_device_mesh(device_type, dims, mesh_dim_names=names)

    @property
    def dp_enabled(self):
        return self.dp > 1

    @property
    def tp_enabled(self):
        return self.tp > 1

    @property
    def pp_enabled(self):
        return self.pp > 1

    @property
    def loss_parallel_enabled(self):
        return self.tp > 1 and self.enable_loss_parallel

    @cached_property
    def model_parallel_size(self):
        return self.tp * self.pp
torchtitan/parallelisms/parallel_dims.py (70 changes: 70 additions & 0 deletions)
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from functools import cached_property

from torch.distributed.device_mesh import init_device_mesh
from torchtitan.logging import logger


@dataclass
class ParallelDims:
    dp: int
    tp: int
    pp: int
    world_size: int
    enable_loss_parallel: bool
    dp_type: str

    def __post_init__(self):
        self.dp_type = self.dp_type.lower()
        self._validate()

    def _validate(self):
        dp, tp, pp = self.dp, self.tp, self.pp
        if dp == -1:
            self.dp = dp = self.world_size // (tp * pp)
        assert dp >= 1, dp
        assert tp >= 1, tp
        assert pp >= 1, pp
        assert (
            dp * tp * pp == self.world_size
        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
        assert self.dp_type in ("fsdp", "ddp")

    def build_mesh(self, device_type):
        dims = []
        names = []
        for d, name in zip(
            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
        ):
            if d > 1:
                dims.append(d)
                names.append(name)
        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
        names = tuple(names)
        return init_device_mesh(device_type, dims, mesh_dim_names=names)

    @property
    def dp_enabled(self):
        return self.dp > 1

    @property
    def tp_enabled(self):
        return self.tp > 1

    @property
    def pp_enabled(self):
        return self.pp > 1

    @property
    def loss_parallel_enabled(self):
        return self.tp > 1 and self.enable_loss_parallel

    @cached_property
    def model_parallel_size(self):
        return self.tp * self.pp
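For context on the module split, here is a hedged usage sketch of `ParallelDims` imported from its new location. The concrete sizes are illustrative assumptions, and building the mesh requires an initialized `torch.distributed` process group spanning `world_size` ranks.

```python
# Minimal sketch (illustrative values): validate parallel degrees and build
# the device mesh from the relocated ParallelDims dataclass.
from torchtitan.parallelisms.parallel_dims import ParallelDims

parallel_dims = ParallelDims(
    dp=-1,  # -1 is inferred as world_size // (tp * pp), i.e. 2 here
    tp=2,
    pp=2,
    world_size=8,
    enable_loss_parallel=True,
    dp_type="fsdp",  # validated against ("fsdp", "ddp")
)

assert parallel_dims.model_parallel_size == 4  # tp * pp
assert parallel_dims.dp_enabled and parallel_dims.tp_enabled and parallel_dims.pp_enabled

# Builds a 3-D mesh with dims ("pp", "dp", "tp") since all three degrees are > 1;
# assumes torch.distributed has already been initialized across world_size ranks.
mesh = parallel_dims.build_mesh(device_type="cuda")
```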