[AutoParallel] Add Sequence Parallel for Static LLaMA #7746

Merged
merged 5 commits into from
Jan 4, 2024
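The changes below enable sequence parallel for the static-graph auto-parallel LLaMA. They add a launch script, llm/llama/auto_parallel/run_auto_sp.sh, that runs pretraining with --sequence_parallel true on a dp2/mp2/pp2 (VPP) layout, and they extend paddlenlp/transformers/llama/modeling_auto.py so that, when config.sequence_parallel is set, decoder activations are kept in [seq_len, batch, hidden] layout and dist.reshard is used to move between sequence-parallel regions (sequence dimension sharded over "mp") and tensor-parallel regions (sequence dimension replicated) around the attention block, the MLP, the embedding output, and the LM head.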
74 changes: 74 additions & 0 deletions llm/llama/auto_parallel/run_auto_sp.sh
@@ -0,0 +1,74 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# just for debugging auto_parallel

set -x
unset CUDA_VISIBLE_DEVICES

task_name="llama_auto_dp2mp2pp2_vpp2_sp"
# rm -rf output/$task_name/ # ckpt is saved in 'output/'
rm -rf "output/$task_name""_log"

export PARALLEL_CROSS_ENTROPY=true
export FLAGS_call_stack_level=2
export PYTHONPATH=../../../:$PYTHONPATH
python -u -m paddle.distributed.launch \
--gpus "0,1,2,3,4,5,6,7" \
--log_dir "output/$task_name""_log" \
run_pretrain_auto.py \
--model_type "llama" \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir "output/$task_name" \
--split 949,50,1 \
--max_seq_length 2048 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 8 \
--use_flash_attention 0 \
--use_fused_rms_norm 0 \
--fp16 0 \
--fp16_opt_level "O2" \
--scale_loss 1024 \
--tensor_parallel_degree 2 \
--pipeline_parallel_degree 2 \
--virtual_pp_degree 2 \
--pipeline_schedule_mode "VPP" \
--sharding_parallel_degree 1 \
--sharding "stage2" \
--learning_rate 0.0001 \
--min_learning_rate 0.00001 \
--max_steps 10 \
--save_steps 5000 \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1 \
--dataloader_num_workers 1 \
--eval_steps 1000 \
--report_to "visualdl" \
--disable_tqdm true \
--continue_training 0 \
--recompute 1 \
--recompute_granularity full \
--do_train \
--do_eval \
--device "gpu" \
--data_impl "mmap" \
--parallel_mode "auto" \
--sequence_parallel true \

# --resume_from_checkpoint "output/llama_auto_serial/checkpoint-2" \
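A note on the layout implied by this script: with 8 GPUs and tensor_parallel_degree 2 × pipeline_parallel_degree 2, the remaining data-parallel degree is 8 / (2 × 2) = 2, which matches the dp2mp2pp2 in the task name. virtual_pp_degree 2 with the "VPP" schedule splits each pipeline stage into two virtual stages, and --sequence_parallel true additionally splits activations along the sequence dimension within each mp group of 2.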
110 changes: 108 additions & 2 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -21,6 +21,7 @@
from typing import Optional, Tuple

import paddle
import paddle.distributed as dist
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed import fleet
@@ -362,10 +363,24 @@
) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
# enter tp region
if self.config.sequence_parallel:
mesh = get_mesh(self.ipp)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(

hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Replicate()],
)
else:
hidden_states = dist.reshard(

hidden_states,
get_mesh(self.ipp),
[dist.Replicate()],
)

if self.fuse_attention_qkv:
target_shape = [0, 0, self.num_heads, 3 * self.head_dim]

fleet.auto.shard_tensor(self.qkv_proj.weight, *get_dist_attr([None, "mp"], self.ipp))

mix_layer = self.qkv_proj(hidden_states)
@@ -383,6 +398,11 @@
key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape)
value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape)

if self.config.sequence_parallel:
query_states = paddle.transpose(query_states, [1, 0, 2, 3])
key_states = paddle.transpose(key_states, [1, 0, 2, 3])
value_states = paddle.transpose(value_states, [1, 0, 2, 3])

kv_seq_len = key_states.shape[-3]

if past_key_value is not None:
@@ -459,6 +479,22 @@
fleet.auto.shard_tensor(self.o_proj.weight, *get_dist_attr(["mp", None], self.ipp))
attn_output = self.o_proj(attn_output)

# enter sp region
if self.config.sequence_parallel:
attn_output = paddle.transpose(attn_output, [1, 0, 2])
mesh = get_mesh(self.ipp)
if "dp" in mesh.dim_names:
attn_output = dist.reshard(

attn_output,
get_mesh(self.ipp),
[dist.Shard(1), dist.Shard(0)],
)
else:
attn_output = dist.reshard(

attn_output,
get_mesh(self.ipp),
[dist.Shard(0)],
)
if not output_attentions:
attn_weights = None
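For readers unfamiliar with the placement notation above, here is a minimal sketch (not part of this diff) of the tp/sp reshard pattern, assuming a hypothetical 2x2 ["dp", "mp"] mesh on 4 GPUs and the paddle.distributed reshard API used in this file; it would be run under python -m paddle.distributed.launch --gpus 0,1,2,3.

import paddle
import paddle.distributed as dist

# 2x2 mesh: mesh dim 0 is data parallel ("dp"), mesh dim 1 is model parallel ("mp").
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# Activations in [seq_len, batch, hidden] layout, as used when sequence_parallel is on.
x = paddle.randn([2048, 2, 4096])

# sp region: placement i applies to mesh dim i, so Shard(1) splits the batch dim
# over "dp" and Shard(0) splits the sequence dim over "mp".
x = dist.shard_tensor(x, mesh, [dist.Shard(1), dist.Shard(0)])

# enter a tp region (e.g. before qkv_proj): gather the sequence dim across "mp",
# keeping the batch dim sharded over "dp".
x = dist.reshard(x, mesh, [dist.Shard(1), dist.Replicate()])

# ... tensor-parallel matmuls would run here ...

# leave the tp region (e.g. after o_proj): re-split the sequence dim across "mp".
x = dist.reshard(x, mesh, [dist.Shard(1), dist.Shard(0)])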

@@ -565,7 +601,39 @@
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)

# enter tp region
if self.config.sequence_parallel:
mesh = get_mesh(self.ipp)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(

hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Replicate()],
)
else:
hidden_states = dist.reshard(

hidden_states,
get_mesh(self.ipp),
[dist.Replicate()],
)

hidden_states = self.mlp(hidden_states)
# enter sp region
if self.config.sequence_parallel:
mesh = get_mesh(self.ipp)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(

hidden_states,
get_mesh(self.ipp),
[dist.Shard(1), dist.Shard(0)],
)
else:
hidden_states = dist.reshard(

hidden_states,
get_mesh(self.ipp),
[dist.Shard(0)],
)
hidden_states = residual + hidden_states

outputs = (hidden_states,)
@@ -830,6 +898,24 @@
) # [bs, 1, seq_len, seq_len]

hidden_states = inputs_embeds
if self.config.sequence_parallel:

# [B, S, H] -> [S, B, H]
emb_transpose = fleet.auto.shard_op(paddle.transpose, get_mesh(0))
hidden_states = emb_transpose(hidden_states, [1, 0, 2])

# enter sp region
mesh = get_mesh(0)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(

hidden_states,
get_mesh(0),
[dist.Shard(1), dist.Shard(0)],
)
else:
hidden_states = dist.reshard(

hidden_states,
get_mesh(0),
[dist.Shard(0)],
)

# decoder layers
all_hidden_states = () if output_hidden_states else None
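The hunk above is where the sequence-parallel layout is established: right after the embedding, hidden_states is transposed from [batch, seq_len, hidden] to [seq_len, batch, hidden] (via fleet.auto.shard_op so the transpose is placed on the first pipeline mesh) and then resharded so the sequence dimension, now dim 0, can be split across "mp". The per-layer shard_tensor annotation in the next hunk switches from ["dp", None, None] to ["mp", "dp", None] to describe the same layout.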
@@ -838,14 +924,18 @@

for idx, (decoder_layer) in enumerate(self.layers):
ipp = decoder_layer.ipp
fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["dp", None, None], ipp))
if self.config.sequence_parallel:
fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["mp", "dp", None], ipp))

else:
fleet.auto.shard_tensor(hidden_states, *get_dist_attr(["dp", None, None], ipp))

decoder_layer = fleet.auto.shard_op(decoder_layer, get_mesh(ipp))

if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None

has_gradient = not hidden_states.stop_gradient

if (
self.enable_recompute
and idx not in self.no_recompute_layers
@@ -1107,6 +1197,22 @@
)

hidden_states = outputs[0] # [bs, seq_len, dim]
# enter tp region
if self.config.sequence_parallel:
mesh = get_mesh(-1)
if "dp" in mesh.dim_names:
hidden_states = dist.reshard(

hidden_states,
get_mesh(-1),
[dist.Shard(1), dist.Replicate()],
)
else:
hidden_states = dist.reshard(

hidden_states,
get_mesh(-1),
[dist.Replicate()],
)
hidden_states = paddle.transpose(hidden_states, [1, 0, 2])

# if labels is None, it means we need the full output instead of tensor_parallel_output
# tensor_parallel_output goes together with ParallelCrossEntropy
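These two context lines are presumably why run_auto_sp.sh exports PARALLEL_CROSS_ENTROPY=true: when labels are provided, the logits can stay sharded over "mp" (tensor_parallel_output) and the loss is computed with ParallelCrossEntropy instead of gathering the full vocabulary dimension first.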