
Commit

updates here
[ghstack-poisoned]
H-Huang committed Jun 24, 2024
1 parent f261913 commit f0bc5fa
Showing 2 changed files with 16 additions and 0 deletions.
15 changes: 15 additions & 0 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -45,10 +45,25 @@ def build_pipeline_schedule(job_config, parallel_dims, stages, loss_fn):
     if n_microbatches is None:
         n_microbatches = job_config.experimental.pipeline_parallel_degree

+    if job_config.experimental.pipeline_parallel_schedule == "zb":
+        stage_index_to_group_rank = {
+            0: 0,
+            1: 1,
+            2: 2,
+            3: 3,
+            4: 3,
+            5: 2,
+            6: 1,
+            7: 0,
+        }
+    else:
+        stage_index_to_group_rank = None
+
     schedule = schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
         loss_fn=loss_fn,
+        stage_index_to_group_rank=stage_index_to_group_rank,
     )

     if zb_schedule:
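
Note: the hardcoded dict above corresponds to a V-shaped ("zero bubble" style) placement of 8 stages on 4 pipeline ranks, where rank r hosts stage r on the way down and stage 2 * pp - 1 - r on the way back up, so the first and last stages both land on rank 0. A minimal sketch of how that mapping could be computed for an arbitrary pipeline-parallel degree; the helper name v_style_stage_to_group_rank is illustrative and not part of this commit:

def v_style_stage_to_group_rank(pp_degree: int) -> dict[int, int]:
    # Two stages per rank, laid out in a "V": stage i and stage
    # (2 * pp_degree - 1 - i) share the same group rank.
    mapping = {}
    for rank in range(pp_degree):
        mapping[rank] = rank                      # downward pass
        mapping[2 * pp_degree - 1 - rank] = rank  # upward pass
    return mapping

# v_style_stage_to_group_rank(4) reproduces the mapping hardcoded above:
# {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 2, 6: 1, 7: 0}
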
1 change: 1 addition & 0 deletions train.py
@@ -362,6 +362,7 @@ def loss_fn(pred, labels):
         # pipeline parallel forward / backward inside step() call

         if job_config.experimental.pipeline_parallel_schedule == "zb":
+            is_last_stage = pp_mesh.get_local_rank() == 0
             with loss_parallel_ctx():
                 if pp_mesh.get_local_rank() == 0:
                     losses = []
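
Note: under the V-shaped placement added in pipelining_utils.py, rank 0 hosts both the first and the last stage, which is why the new is_last_stage check compares the local rank against 0 rather than against pp_mesh.size() - 1. A minimal sketch of how the loss-parallel block might continue, assuming the torch.distributed.pipelining convention that the rank owning the last stage passes target= and losses= to step(); pp_schedule, input_ids, and labels are illustrative names, not shown in this hunk:

if is_last_stage:
    # Rank 0 owns stage 0 and the final stage: it feeds the batch in and
    # collects the per-microbatch losses that the schedule hands back.
    losses = []
    pp_schedule.step(input_ids, target=labels, losses=losses)
    loss = torch.mean(torch.stack(losses))
else:
    # Intermediate ranks only relay activations and gradients.
    pp_schedule.step()
    loss = torch.tensor([-1.0])  # placeholder so downstream logging has a value
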
