Reordered TP parallel plan to follow execution order
ghstack-source-id: 4269f33e4cfd9c3d5d176a97be6544d6d87a1602
Pull Request resolved: #445
awgu committed Jul 10, 2024
1 parent 58b46cd commit 3f717c5
Showing 1 changed file with 6 additions and 6 deletions.
torchtitan/parallelisms/parallelize_llama.py (12 changes: 6 additions & 6 deletions)
@@ -332,7 +332,6 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
     Apply tensor parallelism.
     """
-
     tp_mesh = world_mesh["tp"]
     (
         row_parallel_strategy,
@@ -341,9 +340,10 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
-    # 1. Parallelize the first embedding and the last linear proj layer
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
     # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
+    # 3. Parallelize the final linear output layer
     model = parallelize_module(
         model,
         tp_mesh,
@@ -352,12 +352,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 input_layouts=Replicate(),
                 output_layouts=Shard(1),
             ),
+            "norm": SequenceParallel(),
             "output": col_parallel_strategy(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
-            "norm": SequenceParallel(),
         },
     )
 
@@ -367,6 +367,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
+            "attention_norm": SequenceParallel(),
             "attention": prepare_module_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
@@ -375,15 +376,14 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
             "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
+            "ffn_norm": SequenceParallel(),
             "feed_forward": prepare_module_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": col_parallel_strategy(),
             "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(),
         }
 
         # Adjust attention module to use the local number of heads
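For context, here is a minimal sketch (not the verbatim file) of the root-level plan as it reads after this reordering, with entries listed in execution order: token embedding, final norm, then the output projection. The tok_embeddings entry and the import paths are assumptions filled in from the surrounding torchtitan code rather than shown in the hunks above, and apply_tp_root_plan is a hypothetical wrapper name; in the repository this body sits inside apply_tp.

# Sketch of the reordered root-level TP plan. Import paths and the
# "tok_embeddings" entry are assumed; col_parallel_strategy and
# loss_parallel come from the surrounding apply_tp, as in the diff above.
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)


def apply_tp_root_plan(model, tp_mesh, col_parallel_strategy, loss_parallel):
    return parallelize_module(
        model,
        tp_mesh,
        {
            # 1. Parallelize the embedding and shard its outputs (which are
            #    the first transformer block's inputs)
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            # 2. Parallelize the root norm layer over the sequence dim
            "norm": SequenceParallel(),
            # 3. Parallelize the final linear output layer
            "output": col_parallel_strategy(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

The per-block layer_plan gets the same treatment: attention_norm now precedes the attention entries and ffn_norm precedes the feed_forward entries, matching the order in which a transformer block actually runs them. No parallel styles change; only the order in which the plan lists the modules does.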
