Skip to content

Commit

Permalink
[PP] add flexible interleaved 1f1b schedule (#490)
Browse files Browse the repository at this point in the history
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #490

Fixes #483.

Test command: `python test_runner.py ./out --test pp_looped_flexible_1f1b`
  • Loading branch information
H-Huang authored Jul 30, 2024
1 parent d661ceb commit 0322d27
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
15 changes: 15 additions & 0 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,21 @@ def build_test_list():
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
],
],
"PP looped flexible 1f1b test",
"pp_looped_flexible_1f1b",
requires_seed_checkpoint=True,
ngpu=4,
),
OverrideDefinitions(
[
[
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def __init__(self):
self.parser.add_argument(
"--experimental.pipeline_parallel_schedule",
type=str,
choices=["1f1b", "gpipe", "interleaved_1f1b"],
choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"],
default="1f1b",
help="""
Specify the Pipeline Parallel schedule to use.
Expand Down
7 changes: 7 additions & 0 deletions torchtitan/parallelisms/pipelining_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from torch.distributed.pipelining import (
Schedule1F1B,
ScheduleFlexibleInterleaved1F1B,
ScheduleGPipe,
ScheduleInterleaved1F1B,
)
Expand All @@ -23,6 +24,12 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b":
schedule_class = ScheduleInterleaved1F1B
looped_schedule = True
elif (
job_config.experimental.pipeline_parallel_schedule
== "flexible_interleaved_1f1b"
):
schedule_class = ScheduleFlexibleInterleaved1F1B
looped_schedule = True
else:
raise NotImplementedError(
f"{job_config.experimental.pipeline_parallel_schedule} is not implemented"
Expand Down

0 comments on commit 0322d27

Please sign in to comment.