diff --git a/test_runner.py b/test_runner.py
index a8df397c..a1d3bf22 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -46,6 +46,21 @@ def build_test_list():
     """
     integration_tests_flavors = defaultdict(list)
     integration_tests_flavors["debug_model.toml"] = [
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable_checkpoint",
+                    "--experimental.pipeline_parallel_degree 4",
+                    "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
+                    "--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
+                    "--model.norm_type rmsnorm",  # fused_rmsnorm throws cuda context error with pp
+                ],
+            ],
+            "PP looped flexible 1f1b test",
+            "pp_looped_flexible_1f1b",
+            requires_seed_checkpoint=True,
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 9a086830..74215f1a 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -275,7 +275,7 @@ def __init__(self):
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
             type=str,
-            choices=["1f1b", "gpipe", "interleaved_1f1b"],
+            choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"],
             default="1f1b",
             help="""
                 Specify the Pipeline Parallel schedule to use.
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index e60b7f51..adf9eb09 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -7,6 +7,7 @@
 from torch.distributed.pipelining import (
     Schedule1F1B,
+    ScheduleFlexibleInterleaved1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
 )

@@ -23,6 +24,12 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
     elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b":
         schedule_class = ScheduleInterleaved1F1B
         looped_schedule = True
+    elif (
+        job_config.experimental.pipeline_parallel_schedule
+        == "flexible_interleaved_1f1b"
+    ):
+        schedule_class = ScheduleFlexibleInterleaved1F1B
+        looped_schedule = True
     else:
         raise NotImplementedError(
             f"{job_config.experimental.pipeline_parallel_schedule} is not implemented"
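
Note (not part of the patch): a minimal sketch of how the new branch ends up instantiating the schedule, assuming the (stages, n_microbatches, loss_fn) constructor shared by the other looped schedules in torch.distributed.pipelining; the helper name build_flexible_schedule and its arguments are illustrative stand-ins for the wiring that build_pipeline_schedule does downstream of the dispatch above.

from torch.distributed.pipelining import ScheduleFlexibleInterleaved1F1B

def build_flexible_schedule(stages, n_microbatches, loss_fn):
    # looped_schedule=True means `stages` is the full list of pipeline
    # stages owned by this rank, so it is passed as a list rather than
    # as a single stages[0], as with 1f1b/gpipe.
    return ScheduleFlexibleInterleaved1F1B(
        stages,
        n_microbatches=n_microbatches,
        loss_fn=loss_fn,
    )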