Make pp split points optional
ghstack-source-id: ca24cbaab8944cb245931ff9bd6703896a0a91e9
Pull Request resolved: #604
H-Huang committed Oct 8, 2024
1 parent 334a8dd commit 8c46891
Showing 2 changed files with 29 additions and 10 deletions.
10 changes: 0 additions & 10 deletions test_runner.py
@@ -141,7 +141,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule FlexibleInterleaved1F1B",
],
],
@@ -155,7 +154,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1F1B",
"--training.data_parallel_shard_degree 1",
],
@@ -170,7 +168,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule Gpipe",
"--training.data_parallel_shard_degree 1",
],
@@ -185,7 +182,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1F1B",
"--training.data_parallel_shard_degree 2",
],
@@ -199,7 +195,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule Gpipe",
"--training.data_parallel_shard_degree 2",
],
@@ -213,7 +208,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.tensor_parallel_degree 2",
],
],
@@ -226,15 +220,13 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
[
"--training.steps 20",
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
],
@@ -248,7 +240,6 @@ def build_test_list():
[
[
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_shard_degree 2",
"--training.tensor_parallel_degree 2",
"--training.compile",
@@ -264,7 +255,6 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule Interleaved1F1B",
],
],
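These tests now rely on the split points that are auto-generated by the new logic in pipeline_llama.py (shown in the next file). As a rough sanity check, and assuming the debug model used by these tests has 8 layers (an assumption, not stated in this diff), the pp-degree-4 Interleaved1F1B case reproduces the split points the test previously passed explicitly:

```python
# Hypothetical sanity check, not part of this commit. Assumes n_layers == 8 for the
# test model; Interleaved1F1B is a multi-stage schedule, so 2 stages per pp rank.
num_layers = 8
pp_degree = 4                                    # --experimental.pipeline_parallel_degree 4
num_stages_per_rank = 2                          # multi-stage schedule (Interleaved1F1B)
total_stages = pp_degree * num_stages_per_rank   # 8
interval = num_layers // total_stages            # 1
splits = ["layers." + str(i * interval) for i in range(1, total_stages)]
assert splits == [f"layers.{i}" for i in range(1, 8)]  # layers.1 ... layers.7
```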
29 changes: 29 additions & 0 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -13,6 +13,11 @@
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
+from torch.distributed.pipelining.schedules import (
+    get_schedule_class,
+    PipelineScheduleMulti,
+    PipelineScheduleSingle,
+)

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
@@ -141,6 +146,30 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
)
return stage, model

+    # If split points are not specified, we split the model into equal chunks based on
+    # the number of pipeline stages.
+    if len(splits) == 0:
+        # We assume the number of stages per rank based on the schedule type.
+        schedule_class = get_schedule_class(
+            job_config.experimental.pipeline_parallel_schedule
+        )
+        if issubclass(schedule_class, PipelineScheduleSingle):
+            num_stages_per_rank = 1
+        elif issubclass(schedule_class, PipelineScheduleMulti):
+            num_stages_per_rank = 2
+        else:
+            raise ValueError(
+                f"Unsupported pipeline schedule: {job_config.experimental.pipeline_parallel_schedule}"
+            )
+        total_stages = parallel_dims.pp * num_stages_per_rank
+        num_layers = model_config.n_layers
+        if total_stages > num_layers:
+            raise ValueError("Total stages cannot be greater than the number of layers")
+        interval = num_layers // total_stages
+        # Generate split points
+        splits = ["layers." + str(i * interval) for i in range(1, total_stages)]
+        print(splits)
+
num_stages = len(splits) + 1
stage_idx = pp_rank

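For readers who want to try the new default behavior in isolation, here is a minimal sketch of the split-point computation added above, wrapped in a standalone helper. The helper name `compute_default_splits` and the example numbers are illustrative assumptions and do not appear in this commit:

```python
# Minimal sketch of the default split-point computation added in this commit.
# `compute_default_splits` is a hypothetical helper, used here only for illustration.
def compute_default_splits(num_layers: int, pp_degree: int, multi_stage: bool) -> list[str]:
    # Multi-stage schedules (e.g. Interleaved1F1B) place 2 stages on each pp rank;
    # single-stage schedules (e.g. 1F1B, Gpipe) place 1.
    num_stages_per_rank = 2 if multi_stage else 1
    total_stages = pp_degree * num_stages_per_rank
    if total_stages > num_layers:
        raise ValueError("Total stages cannot be greater than the number of layers")
    interval = num_layers // total_stages
    return ["layers." + str(i * interval) for i in range(1, total_stages)]

print(compute_default_splits(num_layers=16, pp_degree=2, multi_stage=False))  # ['layers.8']
print(compute_default_splits(num_layers=8, pp_degree=4, multi_stage=True))    # ['layers.1', ..., 'layers.7']
```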
