diff --git a/test_runner.py b/test_runner.py
index 6e2fb37a..8951e6ef 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -141,7 +141,6 @@ def build_test_list():
                 [
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 4",
-                    "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
                     "--experimental.pipeline_parallel_schedule FlexibleInterleaved1F1B",
                 ],
             ],
@@ -155,7 +154,6 @@ def build_test_list():
                 [
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
                     "--experimental.pipeline_parallel_schedule 1F1B",
                     "--training.data_parallel_shard_degree 1",
                 ],
@@ -170,7 +168,6 @@ def build_test_list():
                 [
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
                     "--experimental.pipeline_parallel_schedule Gpipe",
                     "--training.data_parallel_shard_degree 1",
                 ],
@@ -185,7 +182,6 @@ def build_test_list():
                 [
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
                     "--experimental.pipeline_parallel_schedule 1F1B",
                     "--training.data_parallel_shard_degree 2",
                 ],
@@ -199,7 +195,6 @@ def build_test_list():
                 [
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
                     "--experimental.pipeline_parallel_schedule Gpipe",
                     "--training.data_parallel_shard_degree 2",
                 ],
@@ -213,7 +208,6 @@ def build_test_list():
                 [
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
                     "--training.tensor_parallel_degree 2",
                 ],
             ],
@@ -226,7 +220,6 @@ def build_test_list():
                 [
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
                     "--training.data_parallel_shard_degree 2",
                     "--training.tensor_parallel_degree 2",
                 ],
@@ -234,7 +227,6 @@
                     "--training.steps 20",
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
                     "--training.data_parallel_shard_degree 2",
                     "--training.tensor_parallel_degree 2",
                 ],
@@ -248,7 +240,6 @@ def build_test_list():
             [
                 [
                     "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
                     "--training.data_parallel_shard_degree 2",
                     "--training.tensor_parallel_degree 2",
                     "--training.compile",
@@ -264,7 +255,6 @@ def build_test_list():
                 [
                     "--checkpoint.enable_checkpoint",
                     "--experimental.pipeline_parallel_degree 4",
-                    "--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
                     "--experimental.pipeline_parallel_schedule Interleaved1F1B",
                 ],
             ],
diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py
index 7e12aea6..fa40e548 100644
--- a/torchtitan/parallelisms/pipeline_llama.py
+++ b/torchtitan/parallelisms/pipeline_llama.py
@@ -13,6 +13,11 @@
 import torch.nn as nn
 from torch.distributed import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
+from torch.distributed.pipelining.schedules import (
+    get_schedule_class,
+    PipelineScheduleMulti,
+    PipelineScheduleSingle,
+)
 
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
@@ -141,6 +146,30 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
         )
         return stage, model
 
+    # if split points are not specified, we split the model into equal chunks based on
+    # the number of pipeline stages
+    if len(splits) == 0:
+        # we assume the number of stages per rank based on the schedule type
+        schedule_class = get_schedule_class(
+            job_config.experimental.pipeline_parallel_schedule
+        )
+        if issubclass(schedule_class, PipelineScheduleSingle):
+            num_stages_per_rank = 1
+        elif issubclass(schedule_class, PipelineScheduleMulti):
+            num_stages_per_rank = 2
+        else:
+            raise ValueError(
+                f"Unsupported pipeline schedule: {job_config.experimental.pipeline_parallel_schedule}"
+            )
+        total_stages = parallel_dims.pp * num_stages_per_rank
+        num_layers = model_config.n_layers
+        if total_stages > num_layers:
+            raise ValueError("Total stages cannot be greater than the number of layers")
+        interval = num_layers // total_stages
+        # Generate split points
+        splits = ["layers." + str(i * interval) for i in range(1, total_stages)]
+        logger.info(f"No 'pipeline_parallel_split_points' provided; using generated splits: {splits}")
+
     num_stages = len(splits) + 1
     stage_idx = pp_rank
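
For reference, here is a minimal standalone sketch of the split-point inference added above, so the removed test flags can be sanity-checked outside the trainer. The helper name `infer_split_points` and the hard-coded layer counts are illustrative only and not part of torchtitan; `get_schedule_class`, `PipelineScheduleSingle`, and `PipelineScheduleMulti` come from `torch.distributed.pipelining.schedules` as in the diff.

```python
# Minimal sketch, assuming a PyTorch build with torch.distributed.pipelining.
# `infer_split_points` is a hypothetical helper mirroring the logic added to
# pipeline_llama.py; it is not torchtitan API.
from torch.distributed.pipelining.schedules import (
    get_schedule_class,
    PipelineScheduleMulti,
    PipelineScheduleSingle,
)


def infer_split_points(schedule: str, pp_degree: int, num_layers: int) -> list[str]:
    schedule_class = get_schedule_class(schedule)
    if issubclass(schedule_class, PipelineScheduleSingle):
        num_stages_per_rank = 1  # one stage per pipeline rank
    elif issubclass(schedule_class, PipelineScheduleMulti):
        num_stages_per_rank = 2  # looped schedules place two stages per rank
    else:
        raise ValueError(f"Unsupported pipeline schedule: {schedule}")
    total_stages = pp_degree * num_stages_per_rank
    if total_stages > num_layers:
        raise ValueError("Total stages cannot be greater than the number of layers")
    interval = num_layers // total_stages
    return ["layers." + str(i * interval) for i in range(1, total_stages)]


# 8-layer debug model, pp degree 4, looped schedule -> 8 stages of one layer each,
# reproducing the layers.1,...,layers.7 splits the tests previously passed explicitly.
print(infer_split_points("Interleaved1F1B", pp_degree=4, num_layers=8))
# pp degree 2 with a single-stage schedule -> 2 stages, i.e. the old "layers.4" split.
print(infer_split_points("1F1B", pp_degree=2, num_layers=8))
```

This is why the `--experimental.pipeline_parallel_split_points` flags can be dropped from the integration tests: with the debug model's layer count, the generated splits match the values that were previously hard-coded.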