Add RingFlashAttention for context parallel #8383

Merged: 7 commits, merged Jun 5, 2024
Changes from all commits
92 changes: 92 additions & 0 deletions csrc/generation/flash_attn_bwd.cc
@@ -0,0 +1,92 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include <iostream>
#include <vector>

using paddle::Tensor;

namespace paddle {
namespace experimental {

PADDLE_API void flash_attn_grad(const Tensor& q,
const Tensor& k,
const Tensor& v,
const Tensor& out,
const Tensor& softmax_lse,
const Tensor& seed_offset,
const paddle::optional<Tensor> &attn_mask,
const Tensor& out_grad,
float dropout,
bool causal, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad);

}  // namespace experimental
} // namespace paddle



std::vector<Tensor> SRFlashAttnBwd(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &out,
const Tensor &softmax_lse,
const Tensor &seed_offset,
const paddle::optional<Tensor> &attn_mask,
const Tensor &out_grad,
float dropout,
bool causal);


// Thin wrapper around paddle::experimental::flash_attn_grad that packs the three
// gradients (dq, dk, dv) into a vector, matching the custom-op output spec below.
std::vector<Tensor> SRFlashAttnBwd(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &out,
const Tensor &softmax_lse,
const Tensor &seed_offset,
const paddle::optional<Tensor> &attn_mask,
const Tensor &out_grad,
float dropout,
bool causal){
std::vector<Tensor> res(3);
paddle::experimental::flash_attn_grad(q, k, v, out, softmax_lse, seed_offset, attn_mask,
out_grad, dropout, causal, &res[0], &res[1],
&res[2]);
return res;
}



std::vector<paddle::DataType> SRFlashAttnBwdDtype(paddle::DataType q_dtype,
paddle::DataType k_dtype,
paddle::DataType v_dtype) {
return {q_dtype, k_dtype, v_dtype};

}


std::vector<std::vector<int64_t>> SRFlashAttnBwdInferShape(
std::vector<int64_t> q_shape, std::vector<int64_t> k_shape,
std::vector<int64_t> v_shape) {
return {q_shape, k_shape, v_shape};
}


PD_BUILD_OP(flash_attn_bwd)
.Inputs({"q", "k", "v", "out", "softmax_lse", "seed_offset", "attn_mask", "out_grad"})
.Outputs({"q_grad", "k_grad", "v_grad"})
.Attrs({"dropout: float", "causal: bool"})
.SetKernelFn(PD_KERNEL(SRFlashAttnBwd))
.SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnBwdInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnBwdDtype));
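For orientation, here is a minimal sketch of how the newly registered `flash_attn_bwd` op could be called from Python once the extension in `csrc` is built. The `paddlenlp_ops` module name and the tensor shapes are assumptions for illustration, not something this diff establishes.

```python
# Hypothetical usage sketch -- the `paddlenlp_ops` module name and shapes are assumptions.
import paddle

# from paddlenlp_ops import flash_attn_bwd  # available after building the csrc extension

# Tensors laid out as [batch, seq_len, num_heads, head_dim], matching flash attention
# (real kernels expect float16/bfloat16; float32 is used here only to keep the toy runnable).
q = paddle.randn([2, 1024, 16, 128], dtype="float32")
k = paddle.randn([2, 1024, 16, 128], dtype="float32")
v = paddle.randn([2, 1024, 16, 128], dtype="float32")

# `out`, `softmax_lse` and `seed_offset` would come from the forward flash-attention call,
# and `out_grad` is the gradient w.r.t. the attention output. With those in hand:
# dq, dk, dv = flash_attn_bwd(q, k, v, out, softmax_lse, seed_offset, attn_mask,
#                             out_grad, dropout=0.0, causal=True)
```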
1 change: 1 addition & 0 deletions csrc/setup_cuda.py
@@ -77,6 +77,7 @@ def get_gencode_flags():
"./generation/step.cu",
"./generation/quant_int8.cu",
"./generation/dequant_int8.cu",
"./generation/flash_attn_bwd.cc",
],
extra_compile_args={
"cxx": ["-O3"],
10 changes: 9 additions & 1 deletion docs/trainer.md
@@ -576,7 +576,15 @@ Trainer is a simple yet feature-complete Paddle training and evaluation module, and
the following configs are supported:
enable_allreduce_avg_in_gradinent_scale: replaces the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients in data_parallel, which improves performance. ONLY supported for auto mode now.
gradient_sync_after_accumulate: moves gradient sync operations from the backward pass into the optimizer step when gradient accumulation is enabled, which reduces the number of syncs and improves performance, but increases memory usage. ONLY supported for auto mode now.

--context_parallel_degree
Context parallelism is a parallel method that splits the training data along the sequence dimension.
It uses Ring FlashAttention to keep the attention result correct after the split: the complete attention scores are obtained through ring communication and iterative updates (see the NumPy sketch after this file's diff).
The default value is -1, which disables context parallelism.
(`int`, optional, defaults to `-1`)
(Note: this method requires changes to the model structure; currently LLaMA is supported.)
(Note: this method incurs significant communication overhead; it is recommended only for very long sequences, e.g. 1024k.)
--recompute
Whether to train with recompute, which saves GPU memory.
The forward pass is recomputed when obtaining gradients, reducing the memory held by intermediate activations.
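To make the "ring communication and iterative updates" above concrete, here is a small didactic NumPy sketch of the log-sum-exp merge that ring-style attention relies on: each step processes one KV block (one cp rank's shard) and folds its partial result into the running output. It illustrates the math only and is not the fused kernel this PR uses.

```python
import numpy as np

def block_attention(q, k, v):
    """Attention of q against one KV block; returns output and log-sum-exp per query."""
    scores = q @ k.T / np.sqrt(q.shape[-1])          # [sq, sk]
    lse = np.log(np.exp(scores).sum(axis=-1))        # [sq]
    out = np.exp(scores - lse[:, None]) @ v          # [sq, d]
    return out, lse

def merge(out_a, lse_a, out_b, lse_b):
    """Combine partial results over disjoint key blocks into one exact result."""
    lse = np.logaddexp(lse_a, lse_b)
    out = np.exp(lse_a - lse)[:, None] * out_a + np.exp(lse_b - lse)[:, None] * out_b
    return out, lse

rng = np.random.default_rng(0)
sq, sk, d, n_blocks = 8, 32, 16, 4                   # n_blocks plays the role of the cp degree
q = rng.standard_normal((sq, d))
k = rng.standard_normal((sk, d))
v = rng.standard_normal((sk, d))

# "Ring" loop: accumulate over KV blocks one at a time, as if they arrived from peer ranks.
out, lse = None, None
for kb, vb in zip(np.split(k, n_blocks), np.split(v, n_blocks)):
    o_i, lse_i = block_attention(q, kb, vb)
    out, lse = (o_i, lse_i) if out is None else merge(out, lse, o_i, lse_i)

ref, _ = block_attention(q, k, v)                    # full attention for comparison
assert np.allclose(out, ref, atol=1e-6)
```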
84 changes: 84 additions & 0 deletions llm/llama/run_trainer_tp2cp2.sh
@@ -0,0 +1,84 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


set -x
unset CUDA_VISIBLE_DEVICES

rm -rf log
rm -rf output

unset PADDLE_ELASTIC_JOB_ID
unset PADDLE_TRAINER_ENDPOINTS
unset DISTRIBUTED_TRAINER_ENDPOINTS
unset FLAGS_START_PORT
unset PADDLE_ELASTIC_TIMEOUT

# export FLAGS_embedding_deterministic=1
# export FLAGS_cudnn_deterministic=1
# export FLAGS_flash_attn_version=v1
# export USE_FAST_LN=0


max_seq_length=1024

max_steps=1000
log_dir=seq_${max_seq_length}_log
echo "log_dir:${log_dir}"
rm -rf $log_dir

export PYTHONPATH=../../:$PYTHONPATH
python -u -m paddle.distributed.launch \
--gpus "3,4,5,7" \
--log_dir "./$log_dir" \
run_pretrain.py \
--model_name_or_path "facebook/llama-7b" \
--tokenizer_name_or_path "facebook/llama-7b" \
--input_dir "./data" \
--output_dir "./output" \
--split 949,50,1 \
--max_seq_length $max_seq_length \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 4 \
--per_device_eval_batch_size 4 \
--bf16 \
--fp16_opt_level "O2" \
--use_flash_attention 1 \
--virtual_pp_degree 1 \
--pp_recompute_interval 1 \
--learning_rate 0.00001 \
--min_learning_rate 0.000001 \
--max_steps $max_steps \
--weight_decay 0.01 \
--warmup_ratio 0.01 \
--max_grad_norm 1.0 \
--logging_steps 1 \
--dataloader_num_workers 1 \
--eval_steps 1001 \
--disable_tqdm true \
--continue_training 0 \
--do_train \
--device "gpu" \
--enable_linear_fused_grad_add false \
--recompute_use_reentrant true \
--data_cache "./data_cache" \
--pipeline_parallel_degree 1 \
--context_parallel_degree 2 \
--tensor_parallel_degree 2 \
--sequence_parallel false \
--skip_profile_timer true \
--amp_master_grad \
--report_to "visualdl" \
--logging_dir "./visualdl_log" \
    --save_steps 2000000
4 changes: 4 additions & 0 deletions llm/run_pretrain.py
@@ -485,11 +485,15 @@ def main():
config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob

config.sep_parallel_degree = training_args.sep_parallel_degree
config.context_parallel_degree = training_args.context_parallel_degree
if config.sequence_parallel:
assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel."
assert (
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"
assert (
config.seq_length % config.context_parallel_degree == 0
), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}"

if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
try:
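A toy illustration of why the new assert above matters: each context-parallel rank must receive an equal slice of the sequence. The numbers are made up for illustration.

```python
# Toy numbers, not taken from the PR: the sequence must split evenly across cp ranks.
seq_length = 4096
context_parallel_degree = 2

assert seq_length % context_parallel_degree == 0
per_rank_seq_length = seq_length // context_parallel_degree
print(per_rank_seq_length)  # 2048 tokens of the sequence handled per cp rank
```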
9 changes: 8 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -81,6 +81,7 @@
from ..quantization.quantization_linear import QuantizationLinear
except:
QuantizationLinear = None
from ..transformers.context_parallel_utils import split_inputs_sequence_dim_load_balance
from ..transformers.model_utils import (
PretrainedModel,
_add_variant,
@@ -763,6 +764,8 @@
trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size
if self.args.sep_parallel_degree > 0:
trainable_numel = trainable_numel // self.args.sep_parallel_degree
if self.args.context_parallel_degree > 0:
trainable_numel = trainable_numel // self.args.context_parallel_degree

# the numel is approximate, because tensor parallel still holds its own bias and layer_norm weights without splitting them,
# so the trainable numel is a little larger than the real value.
logger.debug(f" Number of trainable parameters = {trainable_numel:,} (all devices, roughly)")
@@ -897,6 +900,8 @@
for step, inputs in enumerate(epoch_iterator):
if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
inputs = split_inputs_sequence_dim(inputs)
if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1:
inputs = split_inputs_sequence_dim_load_balance(inputs)

[Review comment, Collaborator] Hmm, so with cp enabled, is it effectively one data stream now corresponding to multiple full copies of the parameters?
[Reply, Collaborator] cp -> 2, tp -> 2, 4 cards: two copies of the parameters, one data stream. Forward and backward: two copies of the parameters -> grad? grad sum?

self.timers and self.timers("read-data").stop()
os.environ["TRAINER_GLOBAL_STEP"] = str(self.state.global_step)
self.callback_handler.on_load_data_end(args, self.state, self.control, inputs=inputs)
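As a rough sketch of what a load-balanced sequence split can look like for causal attention (the actual `split_inputs_sequence_dim_load_balance` helper in `paddlenlp/transformers/context_parallel_utils.py` may differ in detail): with a causal mask, chunks near the start of the sequence attend to far fewer keys than chunks near the end, so each cp rank is commonly given one chunk from the front plus its mirror chunk from the back to even out the work.

```python
import numpy as np

def load_balanced_split(seq: np.ndarray, cp_degree: int, cp_rank: int) -> np.ndarray:
    """Return the slice of `seq` (sequence on axis 0) assigned to `cp_rank`."""
    chunks = np.split(seq, 2 * cp_degree)            # two chunks per rank
    front = chunks[cp_rank]
    back = chunks[2 * cp_degree - 1 - cp_rank]       # mirror chunk from the tail
    return np.concatenate([front, back])

tokens = np.arange(16)                               # toy sequence of length 16
for rank in range(2):                                # context_parallel_degree = 2
    print(rank, load_balanced_split(tokens, cp_degree=2, cp_rank=rank))
# rank 0 -> chunks [0..3] and [12..15]; rank 1 -> chunks [4..7] and [8..11]
```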
@@ -1760,6 +1765,7 @@
in_sharding_parallel_mode = self.sharding is not None
in_tensor_parallel_mode = self.args.tensor_parallel_degree > 1
in_sep_parallel_mode = self.args.sep_parallel_degree > 1
in_cp_parallel_mode = self.args.context_parallel_degree > 1

# Multi-gpu training
if (
@@ -1770,6 +1776,7 @@
or in_sharding_parallel_mode
or in_tensor_parallel_mode
or in_sep_parallel_mode
or in_cp_parallel_mode
)
):
model = paddle.DataParallel(model)
@@ -1897,7 +1904,7 @@
if (
not in_pipeline_parallel_mode
and not in_sharding_parallel_mode
and (in_tensor_parallel_mode or in_sep_parallel_mode)
and (in_tensor_parallel_mode or in_sep_parallel_mode or in_cp_parallel_mode)
):
if self.args.amp_master_grad:
mix_precision_utils.MixPrecisionLayer(model, dtype=self.amp_dtype) # return value has no use
43 changes: 39 additions & 4 deletions paddlenlp/trainer/training_args.py
@@ -230,6 +230,10 @@
The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to
data parallel, sharding stage1, tensor parallel and pipeline parallel strategy.
)
context_parallel_degree (`int`, *optional*, defaults to `-1`)(
Context parallelism is a parallel method that segments training data in the sequence dimension.
This method uses Ring FlashAttention to ensure the correctness of the Attention result after segmentation. The complete attention score is obtained through ring communication and iterative updates.
)
data_parallel_config (`str`, *optional*)(
Some additional configs which affect data parallel performance, we provide some option to config it.
following config is support:
@@ -583,6 +587,15 @@
)
},
)
context_parallel_degree: int = field(
default=-1,
metadata={
"help": (
"The paddle context parallel strategy. It can reduce the GPU memory of activation to 1/cp, and it is orthogonal to "
"data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. "
)
},
)
data_parallel_config: str = field(
default="",
metadata={
@@ -918,16 +931,24 @@
if world_size > 1:
tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
sep_parallel_degree = max(self.sep_parallel_degree, 1)
context_parallel_degree = max(self.context_parallel_degree, 1)

[Review comment, Collaborator] One more question: are context parallel and sequence parallel mutually exclusive? Do we need to add a check, or can they be used together?

pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)

assert (
world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0
), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}."

assert not (
sep_parallel_degree > 1 and context_parallel_degree > 1
), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}."


if self.sharding_parallel_degree == -1:
if len(self.sharding) > 0:
self.sharding_parallel_degree = world_size // (
tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
tensor_parallel_degree
* sep_parallel_degree
* context_parallel_degree
* pipeline_parallel_degree
)

sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
@@ -936,27 +957,34 @@
self.sharding = []

self.data_parallel_degree = world_size // (
sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree
sharding_parallel_degree
* tensor_parallel_degree
* sep_parallel_degree
* context_parallel_degree
* pipeline_parallel_degree
)

if (
sharding_parallel_degree > 1
or tensor_parallel_degree > 1
or pipeline_parallel_degree > 1
or self.sep_parallel_degree > 1
or self.context_parallel_degree > 1
):
self.use_hybrid_parallel = True
self.sharding_parallel_degree = sharding_parallel_degree
self.tensor_parallel_degree = tensor_parallel_degree
self.pipeline_parallel_degree = pipeline_parallel_degree
self.sep_parallel_degree = sep_parallel_degree
self.context_parallel_degree = context_parallel_degree

if not self.use_hybrid_parallel:
self.sharding = []
self.sharding_parallel_degree = -1
self.tensor_parallel_degree = -1
self.pipeline_parallel_degree = -1
self.sep_parallel_degree = -1
self.context_parallel_degree = -1


if self.hybrid_parallel_topo_order is None:
self.hybrid_parallel_topo_order = "pp_first"
@@ -1157,7 +1185,9 @@
"mp_degree": self.tensor_parallel_degree,
"pp_degree": self.pipeline_parallel_degree,
"sharding_degree": self.sharding_parallel_degree,
"sep_degree": self.sep_parallel_degree,
"sep_degree": self.sep_parallel_degree
if self.sep_parallel_degree > 1
else self.context_parallel_degree,

"order": order,
}
else:
@@ -1241,6 +1271,7 @@
elif self.enable_auto_parallel:
self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
self.sep_parallel_degree = max(self.sep_parallel_degree, 1)
self.context_parallel_degree = max(self.context_parallel_degree, 1)
self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1)

assert (
@@ -1250,7 +1281,10 @@
if self.sharding_parallel_degree == -1:
if len(self.sharding) > 0:
self.sharding_parallel_degree = world_size // (
self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree
self.tensor_parallel_degree
* self.sep_parallel_degree
* self.context_parallel_degree
* self.pipeline_parallel_degree
)

[Review comment, Collaborator] Has checkpoint saving been taken into account? Do extra communication groups need to be created?

self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1)
@@ -1262,6 +1296,7 @@
self.sharding_parallel_degree
* self.tensor_parallel_degree
* self.sep_parallel_degree
* self.context_parallel_degree
* self.pipeline_parallel_degree
)

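To tie the degree bookkeeping above together, here is a small worked example of how the device count factorizes once context parallel joins the hybrid topology. The numbers are illustrative; per the new assert, sep and cp are mutually exclusive, so only one of them exceeds 1.

```python
# Illustrative numbers only: world_size = dp * sharding * tp * (sep or cp) * pp.
world_size = 32
tensor_parallel_degree = 2
context_parallel_degree = 2    # sep_parallel_degree stays at 1
sep_parallel_degree = 1
pipeline_parallel_degree = 2
sharding_parallel_degree = 2

data_parallel_degree = world_size // (
    sharding_parallel_degree
    * tensor_parallel_degree
    * sep_parallel_degree
    * context_parallel_degree
    * pipeline_parallel_degree
)
assert data_parallel_degree == 2   # 32 / (2 * 2 * 1 * 2 * 2)
```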