Retrieve schedules from get_schedule_class()
ghstack-source-id: f9cf14ce983933c0ef008960f7be24e42e6ddbf9
Pull Request resolved: #595
H-Huang committed Oct 8, 2024
1 parent 25ec560 commit 334a8dd
Showing 3 changed files with 19 additions and 32 deletions.
24 changes: 12 additions & 12 deletions test_runner.py
@@ -142,10 +142,10 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 4",
                     "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
-                    "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
+                    "--experimental.pipeline_parallel_schedule FlexibleInterleaved1F1B",
                 ],
             ],
-            "PP looped flexible 1f1b test",
+            "PP looped flexible 1F1B test",
             "pp_looped_flexible_1f1b",
             requires_seed_checkpoint=True,
             ngpu=4,
@@ -156,11 +156,11 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
-                    "--experimental.pipeline_parallel_schedule 1f1b",
+                    "--experimental.pipeline_parallel_schedule 1F1B",
                     "--training.data_parallel_shard_degree 1",
                 ],
             ],
-            "PP 1D test 1f1b",
+            "PP 1D test 1F1B",
             "pp_1f1b",
             requires_seed_checkpoint=True,
             ngpu=2,
@@ -171,11 +171,11 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
-                    "--experimental.pipeline_parallel_schedule gpipe",
+                    "--experimental.pipeline_parallel_schedule Gpipe",
                     "--training.data_parallel_shard_degree 1",
                 ],
             ],
-            "PP 1D test gpipe",
+            "PP 1D test Gpipe",
             "pp_gpipe",
             requires_seed_checkpoint=True,
             ngpu=2,
@@ -186,11 +186,11 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
-                    "--experimental.pipeline_parallel_schedule 1f1b",
+                    "--experimental.pipeline_parallel_schedule 1F1B",
                     "--training.data_parallel_shard_degree 2",
                 ],
             ],
-            "PP+DP 1f1b 2D test",
+            "PP+DP 1F1B 2D test",
             "pp_dp_1f1b",
             requires_seed_checkpoint=True,
         ),
@@ -200,11 +200,11 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
                     "--experimental.pipeline_parallel_split_points layers.4",
-                    "--experimental.pipeline_parallel_schedule gpipe",
+                    "--experimental.pipeline_parallel_schedule Gpipe",
                     "--training.data_parallel_shard_degree 2",
                 ],
             ],
-            "PP+DP gpipe 2D test",
+            "PP+DP Gpipe 2D test",
             "pp_dp_gpipe",
             requires_seed_checkpoint=True,
         ),
@@ -265,10 +265,10 @@ def build_test_list():
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 4",
                     "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
-                    "--experimental.pipeline_parallel_schedule interleaved_1f1b",
+                    "--experimental.pipeline_parallel_schedule Interleaved1F1B",
                 ],
             ],
-            "PP looped 1f1b test",
+            "PP looped 1F1B test",
             "pp_looped_1f1b",
             requires_seed_checkpoint=True,
             ngpu=4,
6 changes: 3 additions & 3 deletions torchtitan/config_manager.py
@@ -299,14 +299,14 @@ def __init__(self):
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
             type=str,
-            choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"],
-            default="1f1b",
+            choices=["1F1B", "Gpipe", "Interleaved1F1B", "FlexibleInterleaved1F1B"],
+            default="1F1B",
             help="""
                 Specify the Pipeline Parallel schedule to use.
                 The schedule must be compatible with the split points and stages_per_rank.
-                Looped schedules (e.g. interleaved_1f1b) require specifying pipeline_paralle_degree = number of ranks,
+                Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
                 and split_points = number of stages - 1""",
         )
         self.parser.add_argument(
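As a quick worked example of the constraint in the help text above (a sketch, not part of this commit): the looped test entries in test_runner.py use pipeline_parallel_degree 4 with 7 split points, which yields 8 stages, i.e. 2 stages per rank. The stages_per_rank value below is inferred from those test entries, not from the diff itself.

    # Sketch only: illustrates "split_points = number of stages - 1" from the help text,
    # using the values that appear in the looped test entries above.
    pipeline_parallel_degree = 4   # number of PP ranks
    stages_per_rank = 2            # looped schedules place more than one stage per rank
    num_stages = pipeline_parallel_degree * stages_per_rank  # 8
    split_points = [
        "layers.1", "layers.2", "layers.3", "layers.4",
        "layers.5", "layers.6", "layers.7",
    ]
    assert len(split_points) == num_stages - 1  # 7 split points -> 8 stages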
21 changes: 4 additions & 17 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -6,34 +6,21 @@
 from typing import Tuple

 from torch.distributed.pipelining import (
-    Schedule1F1B,
     ScheduleFlexibleInterleaved1F1B,
-    ScheduleGPipe,
     ScheduleInterleaved1F1B,
 )
+from torch.distributed.pipelining.schedules import get_schedule_class
 from torchtitan.logging import logger


 def build_pipeline_schedule(job_config, stages, loss_fn):
     looped_schedule = False

-    if job_config.experimental.pipeline_parallel_schedule == "1f1b":
-        schedule_class = Schedule1F1B
-    elif job_config.experimental.pipeline_parallel_schedule == "gpipe":
-        schedule_class = ScheduleGPipe
-    elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b":
-        schedule_class = ScheduleInterleaved1F1B
-        looped_schedule = True
-    elif (
+    schedule_class = get_schedule_class(
         job_config.experimental.pipeline_parallel_schedule
-        == "flexible_interleaved_1f1b"
-    ):
-        schedule_class = ScheduleFlexibleInterleaved1F1B
+    )
+    if schedule_class in [ScheduleInterleaved1F1B, ScheduleFlexibleInterleaved1F1B]:
         looped_schedule = True
-    else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_schedule} is not implemented"
-        )
     logger.info(
         f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
     )
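For context, a minimal sketch (not part of this commit) of how the new lookup is meant to be used: get_schedule_class() resolves a schedule name string to its schedule class, and the looped check now keys off the returned class rather than string comparisons. The name-to-class mappings asserted below are assumptions implied by the updated config choices, not guarantees from this diff.

    # Sketch only: assumes get_schedule_class() maps the new config strings
    # (e.g. "1F1B", "Interleaved1F1B") to the corresponding schedule classes.
    from torch.distributed.pipelining import (
        Schedule1F1B,
        ScheduleFlexibleInterleaved1F1B,
        ScheduleInterleaved1F1B,
    )
    from torch.distributed.pipelining.schedules import get_schedule_class

    def is_looped(schedule_name: str) -> bool:
        # Looped schedules run more than one stage per pipeline rank.
        schedule_class = get_schedule_class(schedule_name)
        return schedule_class in (ScheduleInterleaved1F1B, ScheduleFlexibleInterleaved1F1B)

    assert get_schedule_class("1F1B") is Schedule1F1B  # assumed mapping
    assert is_looped("Interleaved1F1B")                # looped schedule
    assert not is_looped("1F1B")                       # single stage per rank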
