Renamed parallel styles for transformer block weights
ghstack-source-id: cf2b844b37061ea2a2ca1cf4900d2b5a816c3684
Pull Request resolved: #448
awgu committed Jul 10, 2024
1 parent 420f646 commit b79dd78
Showing 1 changed file with 15 additions and 13 deletions: torchtitan/parallelisms/parallelize_llama.py
@@ -298,10 +298,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
 
     tp_mesh = world_mesh["tp"]
+    # Parallel styles for transformer block linear weights may be different for
+    # float8 linears
     (
-        row_parallel_strategy,
-        col_parallel_strategy,
-        prepare_module_input,
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_weight_input,
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
@@ -318,7 +320,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 output_layouts=Shard(1),
             ),
             "norm": SequenceParallel(),
-            "output": col_parallel_strategy(
+            "output": colwise_parallel_weight(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
@@ -333,22 +335,22 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention_norm": SequenceParallel(),
-            "attention": prepare_module_input(
+            "attention": prepare_weight_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention.wq": colwise_parallel_weight(),
+            "attention.wk": colwise_parallel_weight(),
+            "attention.wv": colwise_parallel_weight(),
+            "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
+            "feed_forward": prepare_weight_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
-            "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
-            "feed_forward.w3": col_parallel_strategy(),
+            "feed_forward.w1": colwise_parallel_weight(),
+            "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel_weight(),
         }
 
         # Adjust attention module to use the local number of heads
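For context, the renamed styles come from get_tp_parallel_strategy(job_config), which, per the new comment, may return float8-specific parallel styles for the transformer block linear weights. Below is a minimal sketch of such a selector, not this commit's actual implementation: the config flag (enable_float8_linear) and the float8 import path and class names are assumptions used only for illustration, while the defaults are the standard DTensor styles from torch.distributed.tensor.parallel.

# Sketch only: shows how a selector like get_tp_parallel_strategy could choose the
# parallel styles for transformer block weights. The job_config flag and the float8
# import path/class names below are assumptions, not taken from this commit.
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def get_tp_parallel_strategy_sketch(job_config):
    """Return (rowwise_parallel_weight, colwise_parallel_weight, prepare_weight_input)."""
    if getattr(job_config, "enable_float8_linear", False):  # assumed config flag
        # Hypothetical float8 variants; the real names depend on the float8 library in use.
        from float8_experimental.float8_tensor_parallel import (  # assumed import path
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
    # Plain DTensor parallel styles for regular nn.Linear weights.
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput

Each layer_plan built in the diff is then applied per transformer block in code below this hunk (not shown here), presumably via parallelize_module(transformer_block, tp_mesh, layer_plan) from torch.distributed.tensor.parallel.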
