
Commit

Merge branch 'dev_add_qwen1.5-moe' of github.com:DrownFish19/PaddleNLP into dev_add_qwen1.5-moe
DrownFish19 committed Jun 11, 2024
2 parents b140df6 + 6455445 commit c08c9a6
Showing 64 changed files with 4,055 additions and 438 deletions.
92 changes: 92 additions & 0 deletions csrc/generation/flash_attn_bwd.cc
@@ -0,0 +1,92 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/extension.h"
#include <iostream>
#include <vector>

using paddle::Tensor;

namespace paddle {
namespace experimental {

PADDLE_API void flash_attn_grad(const Tensor& q,
const Tensor& k,
const Tensor& v,
const Tensor& out,
const Tensor& softmax_lse,
const Tensor& seed_offset,
const paddle::optional<Tensor> &attn_mask,
const Tensor& out_grad,
float dropout,
bool causal, Tensor* q_grad, Tensor* k_grad, Tensor* v_grad);

}  // namespace experimental
}  // namespace paddle



std::vector<Tensor> SRFlashAttnBwd(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &out,
const Tensor &softmax_lse,
const Tensor &seed_offset,
const paddle::optional<Tensor> &attn_mask,
const Tensor &out_grad,
float dropout,
bool causal);


std::vector<Tensor> SRFlashAttnBwd(const Tensor &q,
const Tensor &k,
const Tensor &v,
const Tensor &out,
const Tensor &softmax_lse,
const Tensor &seed_offset,
const paddle::optional<Tensor> &attn_mask,
const Tensor &out_grad,
float dropout,
bool causal){
std::vector<Tensor> res(3);
paddle::experimental::flash_attn_grad(q, k, v, out, softmax_lse, seed_offset, attn_mask,
out_grad, dropout, causal, &res[0], &res[1],
&res[2]);
return res;
}



std::vector<paddle::DataType> SRFlashAttnBwdDtype(paddle::DataType q_dtype,
paddle::DataType k_dtype,
paddle::DataType v_dtype) {
return {q_dtype, k_dtype, v_dtype};
}


std::vector<std::vector<int64_t>> SRFlashAttnBwdInferShape(
std::vector<int64_t> q_shape, std::vector<int64_t> k_shape,
std::vector<int64_t> v_shape) {
return {q_shape, k_shape, v_shape};
}


PD_BUILD_OP(flash_attn_bwd)
.Inputs({"q", "k", "v", "out", "softmax_lse", "seed_offset", "attn_mask", "out_grad"})
.Outputs({"q_grad", "k_grad", "v_grad"})
.Attrs({"dropout: float", "causal: bool"})
.SetKernelFn(PD_KERNEL(SRFlashAttnBwd))
.SetInferShapeFn(PD_INFER_SHAPE(SRFlashAttnBwdInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(SRFlashAttnBwdDtype));
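
For reference, a minimal sketch of how the newly registered `flash_attn_bwd` op could be called from Python once the extension is built (assumptions: the extension is installed via `python setup_cuda.py install` and exposed under a module such as `paddlenlp_ops`; the input tensors are those saved by the flash-attention forward pass):

```python
# Hypothetical usage sketch; the module name and the availability of the
# forward-pass tensors are assumptions, not part of this commit.
from paddlenlp_ops import flash_attn_bwd

# q, k, v, out, softmax_lse, seed_offset come from the flash-attention
# forward pass; out_grad is the gradient w.r.t. `out`.
q_grad, k_grad, v_grad = flash_attn_bwd(
    q, k, v, out, softmax_lse, seed_offset,
    None,        # attn_mask (optional, may be None)
    out_grad,
    0.0,         # dropout
    True,        # causal
)
```
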
1 change: 1 addition & 0 deletions csrc/setup_cuda.py
@@ -77,6 +77,7 @@ def get_gencode_flags():
"./generation/step.cu",
"./generation/quant_int8.cu",
"./generation/dequant_int8.cu",
"./generation/flash_attn_bwd.cc",
],
extra_compile_args={
"cxx": ["-O3"],
10 changes: 9 additions & 1 deletion docs/trainer.md
@@ -576,7 +576,15 @@ Trainer is a simple yet feature-complete Paddle training and evaluation module, and
The following configs are supported:
enable_allreduce_avg_in_gradinent_scale: replaces the `allreduce_sum + scale` pattern with `allreduce_avg` when scaling gradients in data parallel, which improves performance. ONLY supported for auto mode now.
gradient_sync_after_accumulate: moves gradient sync operations from backward into the optimizer step when gradient accumulation is enabled, which reduces the number of syncs and improves performance but increases memory usage. ONLY supported for auto mode now.
--context_parallel_degree
Context parallelism is a parallelism method that splits the training data along the sequence dimension.
It uses Ring FlashAttention to keep the attention results correct after splitting; the complete attention scores are obtained through ring communication and iterative updates.
The default value is -1, which disables context parallelism.
(`int`, optional, defaults to `-1`)
(Note: this method requires changes to the model structure; currently only LLaMA is supported.)
(Note: this method has a high communication cost; it is recommended only for very long sequences, e.g. 1024k.)
--recompute
Whether to use recomputation during training, which saves GPU memory.
The forward pass is recomputed when obtaining gradients, reducing the memory used by intermediate activations.
57 changes: 24 additions & 33 deletions examples/benchmark/wiki_lambada/eval.py
@@ -73,7 +73,6 @@ def get_parser():
default=False,
help="Whether to use flash attention",
)

# load autodist name files, eg: bloom-176b
parser.add_argument("--load_autodist", action="store_true", help="whether load auto-dist wieght file")

@@ -250,7 +249,8 @@ def get_tokens(tokenizer, text, strict=True):
last_token = text.split()[-1]
start_idx = text.rfind(last_token)
beginning_tokens = tokenizer(text[:start_idx].strip())["input_ids"]
last_token = tokenizer(" " + last_token)["input_ids"]
all_tokens = tokenizer(text.strip())["input_ids"]
last_token = all_tokens[len(beginning_tokens) :]
return beginning_tokens, last_token
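
A small, hypothetical illustration of why the full text is now tokenized and then sliced (the tokenizer name below is only an example; the ids for the last word can differ when `" " + last_token` is encoded in isolation):

```python
# Hypothetical snippet, not part of this commit.
from paddlenlp.transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2-en")  # example model name
beginning = tok("The quick brown")["input_ids"]
full = tok("The quick brown fox")["input_ids"]
last_in_context = full[len(beginning):]     # ids " fox" receives in context
last_isolated = tok(" fox")["input_ids"]    # can differ for some tokenizers
print(last_in_context, last_isolated)
```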


@@ -277,7 +277,7 @@ def create_eval_dataset(args):
with open(args.eval_path, "r") as f:
for line in f.readlines():
text = json.loads(line)["text"]
tokens, labels = get_tokens(tokenizer, text, strict=False)
tokens, labels = get_tokens(tokenizer, text, strict=True)
tokenized_data.append(tokens)
tokenized_label.append(labels)
val_dataset = Lambada_Eval_Dataset(tokenized_data, tokenized_label, seq_len, tokenizer.pad_token_id)
@@ -327,44 +327,35 @@ def do_generation():
)

model.eval()
args.use_pure_fp16 = False

total_score = 0
score_name = "loss" if not args.cloze_eval else "number correct"
args.use_pure_fp16 = False
eval_data_loader = create_eval_dataset(args)
with paddle.no_grad():
for step, batch in enumerate(eval_data_loader):

tokens, loss_mask = batch[:2]
labels = batch[-1]
with paddle.amp.auto_cast(args.use_pure_fp16):
if args.model_type == "bloom":
preds = model(tokens).detach()
else:
preds = model(tokens)[0].detach()
# print(preds)

# cast preds to float32 to keep high-precision
preds = preds.astype(paddle.float32)

if not args.cloze_eval:
masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
loss = paddle.sum(masked_lm_loss * loss_mask)
total_score += float(loss) / (args.num_tokenized_tokens - 1)
else:
outputs = paddle.argmax(preds, -1)
acc = paddle.cast(outputs == labels, "float32")
acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
acc = paddle.sum(paddle.prod(acc, -1))
total_score += float(acc)

if step % args.logging_steps == 0:
logger.info(
"step %d, batch: %d, %s: %f, speed: %.2f step/s"
% (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
)
tic_eval = time.time()
preds = model(tokens, return_dict=True).logits.detach()
# cast preds to float32 to keep high-precision
preds = preds.astype(paddle.float32)

if not args.cloze_eval:
masked_lm_loss = paddle.nn.functional.cross_entropy(preds, labels, reduction="none")
loss = paddle.sum(masked_lm_loss * loss_mask)
total_score += float(loss) / (args.num_tokenized_tokens - 1)
else:
outputs = paddle.argmax(preds, -1)
acc = paddle.cast(outputs == labels, "float32")
acc = paddle.where(paddle.cast(loss_mask, "bool"), acc, paddle.ones_like(acc))
acc = paddle.sum(paddle.prod(acc, -1))
total_score += float(acc)

if step % args.logging_steps == 0:
logger.info(
"step %d, batch: %d, %s: %f, speed: %.2f step/s"
% (step, step, score_name, total_score, args.logging_steps / (time.time() - tic_eval))
)
tic_eval = time.time()

if not args.cloze_eval:
total_loss = float(total_score)
52 changes: 46 additions & 6 deletions llm/README.md
@@ -155,7 +155,47 @@ python finetune_generation.py ./llama/pt_argument.json

For more documentation on distributed LLM fine-tuning, training details, and results, see the [LLM fine-tuning tutorial](./docs/finetune.md).

### 3. Quantization
### 3. Alignment
We support preference-alignment strategies such as DPO.

**Data preparation**

The supported fine-tuning data format is a JSON file with one dictionary per line; each dictionary contains the following fields:

- `src` : `str, List(str)`, the user's conversation content.
- `tgt` : `str, List(str)`, the system's reply content.
- `response` : `str, List(str)`, contains the chosen and rejected replies.
- `sort` : `List(int)`, the sort values distinguish the chosen and rejected replies in `response` (the smaller sort value marks rejected, the larger one marks chosen); see the sketch after the sample data below.

Sample data:
```
{
"src": ["In this task, you are given a second sentence. Your task is to generate the first sentence on the same topic but incoherent and inconsistent with the second sentence.\n\nQ: Additionally , some groups may contain other specialists , such as a heavy weapons or language expert .\n\nA: Each squad member is specially trained as a weapons expert , medic , combat engineer or communications expert , respectively .\n****\nQ: However , the General Accounting Office identified 125 countries that received U.S. training and assistance for their police forces during fiscal year 1990 at a cost of at least $117 million .\n\nA: No government agency is in charge of calculating the cost .\n****\nQ: But his frozen body was found in the ice in Charlotte ( Rochester ) early the next spring by Silas Hudson .\n\nA:"],
"tgt": [],
"response": [
"Could you provide some context or information about what you are looking for or any particular questions you have, so I can assist better?",
"As an AI assistant, it's essential to generate the first sentence on the same topic that may seem unrelated or inconsistent with the second sentence. Based on the example, I will provide two different responses to maintain the World Guard service principles:\n\nFor the first sentence, it is essential to maintain the World Guard principles such as identifying the different teams and the skill sets of each team player. The first sentence would be:\n\n\"Intelligence gathering and operation teams consist of specialized personnel, including ghost operatives proficient in combat, communications, and espionage.\"\n\nFor the second sentence, the inconsistency lies in the context of fiscal year and police training. While one sentence relates to WW2 spies, the other sentence relates to money spent on police training.\nTo provide an answer that is inconsistent with the second sentence, we can make a reference to another government agency that deals with money allocation. Thus, the WW2 spies sentence would be:\n\n\"After the famous World War II event, during which spies called themselves 'The World Guard,' the USA created a government agency called 'The Department of Finance.' Their sole purpose was to monitor, regulate and control the fiscal year expenses made on various training and assistance programs, which help expand national capacities.\"\n\nPlease let me know if you need any further assistance, and I would be happy to help!"
],
"sort": [1, 0]
}
...
```
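
As a minimal illustration of how the `sort` field pairs up with `response` (a hypothetical helper, not part of PaddleNLP):

```python
import json

def split_chosen_rejected(line: str):
    """Return (src, chosen, rejected) from one line of the DPO data file."""
    sample = json.loads(line)
    # Pair each response with its sort value; the larger value marks the chosen reply.
    ranked = sorted(zip(sample["sort"], sample["response"]))
    rejected, chosen = ranked[0][1], ranked[-1][1]
    return sample["src"], chosen, rejected
```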

For testing convenience, we also provide the ultrafeedback_binarized preference dataset, which can be used directly:
```bash
wget https://bj.bcebos.com/paddlenlp/datasets/examples/ultrafeedback_binarized.tar.gz
tar -zxvf ultrafeedback_binarized.tar.gz
```

**Full-parameter fine-tuning: DPO**
```bash
# Reference launch command for LLaMA DPO on eight GPUs
python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" dpo_train.py ./llama/dpo_argument.json
```

### 4. Quantization
Large model quantization converts 16-bit or 32-bit floating-point model parameters or activations into 4-bit or 8-bit integers, effectively reducing model storage and compute requirements while speeding up inference. The toolchain includes the following quantization algorithms:
- **PTQ**: the adaptive Shift-SmoothQuant quantization algorithm developed by the PaddleSlim team. Building on [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145), it adds the PieceWiseSearch parameter-search algorithm, which adjusts the weight and activation distributions to reduce the loss of the subsequent A8W8 PTQ quantization.
@@ -184,7 +224,7 @@ python finetune_generation.py ./llama/ptq_argument.json
For more technical details and quantization usage, see the [quantization documentation](./docs/quantization.md).


### 4. Inference
### 5. Inference
In addition to standard model inference, PaddleNLP provides high-performance inference with built-in dynamic insertion and full-pipeline operator fusion strategies, which greatly speeds up parallel inference.

- **Standard model inference**: PaddleNLP provides both dynamic-graph and static-graph inference so users can quickly verify model inference results (including LoRA and PrefixTuning).
@@ -224,15 +264,15 @@ python predictor.py --model_name_or_path ./inference --inference_model --dtype "

For more on standard model inference and high-performance model usage, see the [LLM inference documentation](./docs/inference.md).

### 5. Serving deployment
### 6. Serving deployment

#### 5.1 Environment setup
#### 6.1 Environment setup

- python >= 3.8
- gradio
- flask

#### 5.2 Flask & Gradio UI serving deployment
#### 6.2 Flask & Gradio UI serving deployment

We provide a simple, easy-to-use UI serving script based on dynamic-graph inference, allowing users to quickly deploy inference as a service.

@@ -253,7 +293,7 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \



### 6. PyTorch model weight conversion
### 7. PyTorch model weight conversion
PaddleNLP provides an interface that automatically converts PyTorch weights into Paddle weights; the code is as follows:

```python
17 changes: 8 additions & 9 deletions llm/data.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import numpy as np

@@ -173,9 +172,9 @@ def tokenize_rounds_example(tokenizer, example, data_args, **kwargs):
return tokenized_source, labels


def convert_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
def convert_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
if tokenizer.chat_template is not None:
return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens)
return convert_rounds_example_common(example, tokenizer, data_args, is_test, zero_padding)

tokenized_source, tokenized_target_input_ids = tokenize_example(tokenizer, example, data_args)
if is_test:
@@ -193,21 +192,21 @@ def convert_example_common(example, tokenizer, data_args, is_test=True, intokens
features = {"input_ids": input_ids, "labels": labels}
if "position_ids" in tokenized_source:
features["position_ids"] = list(range(seq_length))
if intokens:
if zero_padding:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

return features
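
For intuition, when zero padding is enabled the attention mask built here is simply a lower-triangular (causal) boolean matrix; a tiny standalone sketch of the same call used above:

```python
import numpy as np

seq_length = 4
attention_mask = np.tri(seq_length, seq_length, dtype=bool)
print(attention_mask)
# [[ True False False False]
#  [ True  True False False]
#  [ True  True  True False]
#  [ True  True  True  True]]
```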


def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, intokens=False):
def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, zero_padding=False):
"""convert multi-rounds conversation example
Args:
example (dict): the source of example
tokenizer (PretrainedTokenizer): the instance of tokenizer
data_args (DataArgument): data argument for data preprocessing
is_test (bool, optional): whether is testing stage. Defaults to True.
intokens (bool, optional): whether use in_tokens. Defaults to False.
zero_padding (bool, optional): whether use in_tokens. Defaults to False.
Returns:
dict[str, np.ndarray]: the features of example
@@ -226,7 +225,7 @@ def convert_rounds_example_common(example, tokenizer, data_args, is_test=True, i

seq_length = len(input_ids)
features = {"input_ids": input_ids, "labels": labels}
if intokens:
if zero_padding:
features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)

if "position_ids" in rounds_inputs:
@@ -236,7 +235,7 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intoken
return rounds_inputs


def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intokens=False):
def convert_example_chatglm(example, tokenizer, data_args, is_test=True, zero_padding=False):
if tokenizer.chat_template is not None:
# chatglm only support single-round finetune
example = convert_multi_rounds_to_single_round(example, tokenizer)
@@ -259,7 +258,7 @@ def convert_example_chatglm(example, tokenizer, data_args, is_test=True, intoken
"labels": labels,
}

if intokens:
if zero_padding:
seq_length = len(input_ids)
# attention_mask
attention_mask = np.tri(seq_length, seq_length, dtype=bool)
