-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add RingFlashAttention for context parallel #8383
Merged
wawltor
merged 7 commits into
PaddlePaddle:develop
from
zhangyuqin1998:ring_flash_attention
Jun 5, 2024
Merged
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
63b2be8
Add RingFlashAttention for context parallel
zhangyuqin1998 ab562b7
update, using sep_group
zhangyuqin1998 94943a8
using sep group
zhangyuqin1998 812a13e
fix
zhangyuqin1998 16eaedd
fix
zhangyuqin1998 e7c4b1e
update
zhangyuqin1998 26b7059
fix
zhangyuqin1998 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/extension.h" | ||
#include <iostream> | ||
#include <vector> | ||
|
||
using paddle::Tensor; | ||
|
||
namespace paddle { | ||
namespace experimental { | ||
|
||
PADDLE_API void flash_attn_grad(const Tensor& q, | ||
const Tensor& k, | ||
const Tensor& v, | ||
const Tensor& out, | ||
const Tensor& softmax_lse, | ||
const Tensor& seed_offset, | ||
const paddle::optional<Tensor> &attn_mask, | ||
const Tensor& out_grad, | ||
float dropout, | ||
bool causal, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad); | ||
|
||
} | ||
} // namespace paddle | ||
|
||
|
||
|
||
std::vector<Tensor> SRFlashAttnBwd(const Tensor &q, | ||
const Tensor &k, | ||
const Tensor &v, | ||
const Tensor &out, | ||
const Tensor &softmax_lse, | ||
const Tensor &seed_offset, | ||
const paddle::optional<Tensor> &attn_mask, | ||
const Tensor &out_grad, | ||
float dropout, | ||
bool causal); | ||
|
||
|
||
std::vector<Tensor> SRFlashAttnBwd(const Tensor &q, | ||
const Tensor &k, | ||
const Tensor &v, | ||
const Tensor &out, | ||
const Tensor &softmax_lse, | ||
const Tensor &seed_offset, | ||
const paddle::optional<Tensor> &attn_mask, | ||
const Tensor &out_grad, | ||
float dropout, | ||
bool causal){ | ||
std::vector<Tensor> res(3); | ||
paddle::experimental::flash_attn_grad(q, k, v, out, softmax_lse, seed_offset, attn_mask, | ||
out_grad, dropout, causal, &res[0], &res[1], | ||
&res[2]); | ||
return res; | ||
} | ||
|
||
|
||
|
||
std::vector<paddle::DataType> SRFlashAttnBwdDtype(paddle::DataType q_dtype, | ||
paddle::DataType k_dtype, | ||
paddle::DataType v_dtype) { | ||
return {q_dtype, k_dtype, v_dtype}; | ||
|
||
} | ||
|
||
|
||
std::vector<std::vector<int64_t>> SRFlashAttnBwdInferShape( | ||
std::vector<int64_t> q_shape, std::vector<int64_t> k_shape, | ||
std::vector<int64_t> v_shape) { | ||
return {q_shape, k_shape, v_shape}; | ||
} | ||
|
||
|
||
PD_BUILD_OP(flash_attn_bwd) | ||
.Inputs({"q", "k", "v", "out", "softmax_lse", "seed_offset", "attn_mask", "out_grad"}) | ||
.Outputs({"q_grad", "k_grad", "v_grad"}) | ||
.Attrs({"dropout: float", "causal: bool"}) | ||
.SetKernelFn(PD_KERNEL(SRFlashAttnBwd)) | ||
.SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnBwdInferShape)) | ||
.SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnBwdDtype)); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
set -x | ||
unset CUDA_VISIBLE_DEVICES | ||
|
||
rm -rf log | ||
rm -rf output | ||
|
||
unset PADDLE_ELASTIC_JOB_ID | ||
unset PADDLE_TRAINER_ENDPOINTS | ||
unset DISTRIBUTED_TRAINER_ENDPOINTS | ||
unset FLAGS_START_PORT | ||
unset PADDLE_ELASTIC_TIMEOUT | ||
|
||
# export FLAGS_embedding_deterministic=1 | ||
# export FLAGS_cudnn_deterministic=1 | ||
# export FLAGS_flash_attn_version=v1 | ||
# export USE_FAST_LN=0 | ||
|
||
|
||
max_seq_length=1024 | ||
|
||
max_steps=1000 | ||
log_dir=seq_${max_seq_length}_log | ||
echo "log_dir:${log_dir}" | ||
rm -rf $log_dir | ||
|
||
export PYTHONPATH=../../:$PYTHONPATH | ||
python -u -m paddle.distributed.launch \ | ||
--gpus "3,4,5,7" \ | ||
--log_dir "./$log_dir" \ | ||
run_pretrain.py \ | ||
--model_name_or_path "facebook/llama-7b" \ | ||
--tokenizer_name_or_path "facebook/llama-7b" \ | ||
--input_dir "./data" \ | ||
--output_dir "./output" \ | ||
--split 949,50,1 \ | ||
--max_seq_length $max_seq_length \ | ||
--per_device_train_batch_size 1 \ | ||
--gradient_accumulation_steps 4 \ | ||
--per_device_eval_batch_size 4 \ | ||
--bf16 \ | ||
--fp16_opt_level "O2" \ | ||
--use_flash_attention 1 \ | ||
--virtual_pp_degree 1 \ | ||
--pp_recompute_interval 1 \ | ||
--learning_rate 0.00001 \ | ||
--min_learning_rate 0.000001 \ | ||
--max_steps $max_steps \ | ||
--weight_decay 0.01 \ | ||
--warmup_ratio 0.01 \ | ||
--max_grad_norm 1.0 \ | ||
--logging_steps 1 \ | ||
--dataloader_num_workers 1 \ | ||
--eval_steps 1001 \ | ||
--disable_tqdm true \ | ||
--continue_training 0 \ | ||
--do_train \ | ||
--device "gpu" \ | ||
--enable_linear_fused_grad_add false \ | ||
--recompute_use_reentrant true \ | ||
--data_cache "./data_cache" \ | ||
--pipeline_parallel_degree 1 \ | ||
--context_parallel_degree 2 \ | ||
--tensor_parallel_degree 2 \ | ||
--sequence_parallel false \ | ||
--skip_profile_timer true \ | ||
--amp_master_grad \ | ||
--report_to "visualdl" \ | ||
--logging_dir "./visualdl_log" \ | ||
--save_steps 2000000 \ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -230,6 +230,10 @@ | |
The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to | ||
data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. | ||
) | ||
context_parallel_degree (`int`, *optional*, defaults to `-1`)( | ||
Context parallelism is a parallel method that segments training data in the sequence dimension. | ||
This method uses Ring FlashAttention to ensure the correctness of the Attention result after segmentation. The complete attention score is obtained through ring communication and iterative updates. | ||
) | ||
data_parallel_config (`str`, *optional*)( | ||
Some additional configs which affect data parallel performance, we provide some option to config it. | ||
following config is support: | ||
|
@@ -583,6 +587,15 @@ | |
) | ||
}, | ||
) | ||
context_parallel_degree: int = field( | ||
default=-1, | ||
metadata={ | ||
"help": ( | ||
"The paddle context parallel strategy. It can reduce the GPU memory of activation to 1/cp, and it is orthogonal to " | ||
"data parallel, sharding stage1, tensor parallel and pipeline parallel strategy. " | ||
) | ||
}, | ||
) | ||
data_parallel_config: str = field( | ||
default="", | ||
metadata={ | ||
|
@@ -918,16 +931,24 @@ | |
if world_size > 1: | ||
tensor_parallel_degree = max(self.tensor_parallel_degree, 1) | ||
sep_parallel_degree = max(self.sep_parallel_degree, 1) | ||
context_parallel_degree = max(self.context_parallel_degree, 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我再问一下,context parellel 和 seq parallel是不是互斥的,需不需要加一个判断?还是可以一起用 |
||
pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) | ||
|
||
assert ( | ||
world_size % (self.tensor_parallel_degree * self.pipeline_parallel_degree) == 0 | ||
), f"Total world_size:{world_size} shoule be devided by tensor_parallel_degree: {self.tensor_parallel_degree} and pipeline_parallel_degree: {self.pipeline_parallel_degree}." | ||
|
||
assert not ( | ||
sep_parallel_degree > 1 and context_parallel_degree > 1 | ||
), f"sep parallel and context parallel cannot be used together, sep_parallel_degree:{sep_parallel_degree}, context_parallel_degree:{context_parallel_degree}." | ||
|
||
if self.sharding_parallel_degree == -1: | ||
if len(self.sharding) > 0: | ||
self.sharding_parallel_degree = world_size // ( | ||
tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree | ||
tensor_parallel_degree | ||
* sep_parallel_degree | ||
* context_parallel_degree | ||
* pipeline_parallel_degree | ||
) | ||
|
||
sharding_parallel_degree = max(self.sharding_parallel_degree, 1) | ||
|
@@ -936,27 +957,34 @@ | |
self.sharding = [] | ||
|
||
self.data_parallel_degree = world_size // ( | ||
sharding_parallel_degree * tensor_parallel_degree * sep_parallel_degree * pipeline_parallel_degree | ||
sharding_parallel_degree | ||
* tensor_parallel_degree | ||
* sep_parallel_degree | ||
* context_parallel_degree | ||
* pipeline_parallel_degree | ||
) | ||
|
||
if ( | ||
sharding_parallel_degree > 1 | ||
or tensor_parallel_degree > 1 | ||
or pipeline_parallel_degree > 1 | ||
or self.sep_parallel_degree > 1 | ||
or self.context_parallel_degree > 1 | ||
): | ||
self.use_hybrid_parallel = True | ||
self.sharding_parallel_degree = sharding_parallel_degree | ||
self.tensor_parallel_degree = tensor_parallel_degree | ||
self.pipeline_parallel_degree = pipeline_parallel_degree | ||
self.sep_parallel_degree = sep_parallel_degree | ||
self.context_parallel_degree = context_parallel_degree | ||
|
||
if not self.use_hybrid_parallel: | ||
self.sharding = [] | ||
self.sharding_parallel_degree = -1 | ||
self.tensor_parallel_degree = -1 | ||
self.pipeline_parallel_degree = -1 | ||
self.sep_parallel_degree = -1 | ||
self.context_parallel_degree = -1 | ||
|
||
if self.hybrid_parallel_topo_order is None: | ||
self.hybrid_parallel_topo_order = "pp_first" | ||
|
@@ -1157,7 +1185,9 @@ | |
"mp_degree": self.tensor_parallel_degree, | ||
"pp_degree": self.pipeline_parallel_degree, | ||
"sharding_degree": self.sharding_parallel_degree, | ||
"sep_degree": self.sep_parallel_degree, | ||
"sep_degree": self.sep_parallel_degree | ||
if self.sep_parallel_degree > 1 | ||
else self.context_parallel_degree, | ||
"order": order, | ||
} | ||
else: | ||
|
@@ -1241,6 +1271,7 @@ | |
elif self.enable_auto_parallel: | ||
self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) | ||
self.sep_parallel_degree = max(self.sep_parallel_degree, 1) | ||
self.context_parallel_degree = max(self.context_parallel_degree, 1) | ||
self.pipeline_parallel_degree = max(self.pipeline_parallel_degree, 1) | ||
|
||
assert ( | ||
|
@@ -1250,7 +1281,10 @@ | |
if self.sharding_parallel_degree == -1: | ||
if len(self.sharding) > 0: | ||
self.sharding_parallel_degree = world_size // ( | ||
self.tensor_parallel_degree * self.sep_parallel_degree * self.pipeline_parallel_degree | ||
self.tensor_parallel_degree | ||
* self.sep_parallel_degree | ||
* self.context_parallel_degree | ||
* self.pipeline_parallel_degree | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 保存相关的考虑了吗?通信组需要额外建吗? |
||
) | ||
|
||
self.sharding_parallel_degree = max(self.sharding_parallel_degree, 1) | ||
|
@@ -1262,6 +1296,7 @@ | |
self.sharding_parallel_degree | ||
* self.tensor_parallel_degree | ||
* self.sep_parallel_degree | ||
* self.context_parallel_degree | ||
* self.pipeline_parallel_degree | ||
) | ||
|
||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
em,是不是 开了 cp 的话,相当于是 一路数据流,现在对应多份完整参数了?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cp->2
tp->2
4card。 两份参数,1路数据流
for- back
两份参数 -> grad ? grad sum?