Add model MegatronBert #1678
@@ -0,0 +1,103 @@
# MegatronBert with PaddleNLP

[Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/pdf/1909.08053.pdf)

**Model overview:**
Recent work on language modeling has shown that training large transformer models advances the state of the art in natural language processing. However, very large models can be difficult to train because of memory constraints. In this work, the authors present techniques for training large transformer models and implement a simple, efficient model-parallel approach that makes it possible to train transformer models with billions of parameters.

This project is an open-source implementation of MegatronBert on Paddle 2.x.
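As a quick sketch of the new API (assuming this PR exposes `MegatronBertModel` and `MegatronBertTokenizer` under `paddlenlp.transformers`, following the usual PaddleNLP pattern), loading the pretrained weights looks roughly like:

```python
import paddle
from paddlenlp.transformers import MegatronBertModel, MegatronBertTokenizer

# Load the converted pretrained weights and the matching tokenizer.
tokenizer = MegatronBertTokenizer.from_pretrained("megatronbert-cased")
model = MegatronBertModel.from_pretrained("megatronbert-cased")

# Encode one sentence and run a forward pass.
inputs = tokenizer("Welcome to use PaddleNLP!")
inputs = {k: paddle.to_tensor([v]) for k, v in inputs.items()}
sequence_output, pooled_output = model(**inputs)
```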
## Quick Start

### Fine-tuning on downstream tasks

#### 1. SQuAD1.1 & SQuAD2.0

SQuAD1.1 dataset
```shell
python -m paddle.distributed.launch run_squad.py \
    --do_train \
    --do_predict \
    --batch_size=8 \
    --model_name_or_path=megatronbert-cased \
    --learning_rate=1e-5 \
    --output_dir=output/ \
    --device=gpu \
    --num_train_epochs=2
```
The arguments are as follows:
- `model_name_or_path`: the model to use; `megatronbert-cased` and `megatronbert-uncased` are currently supported.
- `batch_size`: the number of samples per card in each iteration.
- `learning_rate`: the base learning rate; it is multiplied by the value produced by the learning rate scheduler to obtain the current learning rate (see the sketch after this list).
- `output_dir`: the directory where the model is saved.
- `device`: the device to use; can be cpu or gpu (defaults to gpu). For multi-GPU training, set it to gpu and set the CUDA_VISIBLE_DEVICES environment variable to the GPU ids to use.
- `num_train_epochs`: the number of training epochs.
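As a rough illustration of how the base learning rate combines with a warmup scheduler (this assumes PaddleNLP's `LinearDecayWithWarmup`, which PaddleNLP fine-tuning scripts commonly use; the step count below is illustrative):

```python
from paddlenlp.transformers import LinearDecayWithWarmup

base_lr = 1e-5             # the --learning_rate argument
num_training_steps = 1000  # illustrative value
warmup_proportion = 0.06   # fraction of steps used for linear warmup

# Linear warmup followed by linear decay; the effective learning rate at each
# step is the base learning rate scaled by the scheduler.
lr_scheduler = LinearDecayWithWarmup(base_lr, num_training_steps, warmup_proportion)
print(lr_scheduler.get_lr())  # effective learning rate at the current step
```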

After training, the script evaluates the model on the dev set; once training finishes you will see results like the following:
```text
{
  "exact": 88.78902554399243,
  "f1": 94.4082803514958,
  "total": 10570,
  "HasAns_exact": 88.78902554399244,
  "HasAns_f1": 94.4082803514958,
  "HasAns_total": 10570
}
```

SQuAD2.0 dataset
```shell
python -m paddle.distributed.launch run_squad.py \
    --do_train \
    --version_2_with_negative \
    --do_predict \
    --batch_size=8 \
    --model_name_or_path=megatronbert-cased \
    --learning_rate=1e-5 \
    --output_dir=output/ \
    --device=gpu \
    --num_train_epochs=2
```

The additional argument is:
- `version_2_with_negative`: whether to use the SQuAD2.0 dataset.

After training, the script evaluates the model on the dev set; once training finishes you will see results like the following:
```text
{
  "exact": 85.85867093405206,
  "f1": 88.70579950475263,
  "total": 11873,
  "HasAns_exact": 82.47300944669365,
  "HasAns_f1": 88.17543143048748,
  "HasAns_total": 5928,
  "NoAns_exact": 89.23465096719933,
  "NoAns_f1": 89.23465096719933,
  "NoAns_total": 5945,
  "best_exact": 85.99343047250063,
  "best_exact_thresh": -1.6154582500457764,
  "best_f1": 88.75296534320918,
  "best_f1_thresh": -0.20494508743286133
}
```
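For context, `best_exact_thresh` and `best_f1_thresh` relate to the null-answer decision that `--null_score_diff_threshold` controls during evaluation. A minimal sketch of that rule, with hypothetical variable and function names:

```python
def choose_prediction(null_score: float, best_span_text: str,
                      best_span_score: float, threshold: float = 0.0) -> str:
    """Predict "no answer" when the null score beats the best span score
    by more than the threshold; otherwise return the best span text."""
    if null_score - best_span_score > threshold:
        return ""
    return best_span_text
```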

#### 2. MNLI dataset

```shell
python -m paddle.distributed.launch run_glue.py \
    --task_name=mnli \
    --output_dir=output/ \
    --model_name_or_path=megatronbert-cased \
    --learning_rate=1e-5 \
    --device=gpu \
    --num_train_epochs=2
```
After training, the script evaluates the model; the two accuracy values correspond to the matched and mismatched evaluation sets. Once training finishes you will see results like the following:
```text
eval loss: 0.186327, acc: 0.8992358634742741, eval loss: 0.332409, acc: 0.8968673718470301, eval done total : 118.65499472618103 s
```
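After fine-tuning, the checkpoint written to `output_dir` can be loaded back for inference. A rough sketch, assuming the PR also provides a `MegatronBertForSequenceClassification` head following the usual PaddleNLP naming (the checkpoint path below is illustrative):

```python
import paddle
from paddlenlp.transformers import (MegatronBertForSequenceClassification,
                                    MegatronBertTokenizer)

ckpt_dir = "output/"  # illustrative path written via --output_dir
tokenizer = MegatronBertTokenizer.from_pretrained(ckpt_dir)
model = MegatronBertForSequenceClassification.from_pretrained(ckpt_dir)
model.eval()

premise = "A soccer game with multiple males playing."
hypothesis = "Some men are playing a sport."
inputs = tokenizer(premise, text_pair=hypothesis)
inputs = {k: paddle.to_tensor([v]) for k, v in inputs.items()}
logits = model(**inputs)
print(paddle.argmax(logits, axis=-1))  # predicted MNLI label id
```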

# Reference

* [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/pdf/1909.08053.pdf)
@@ -0,0 +1,150 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--train_file",
        type=str,
        required=False,
        default=None,
        help="Train data path.")
    parser.add_argument(
        "--predict_file",
        type=str,
        required=False,
        default=None,
        help="Predict data path.")
    parser.add_argument(
        "--model_type",
        default="megatronbert",
        type=str,
        help="Type of pre-trained model.")
    parser.add_argument(
        "--model_name_or_path",
        default="megatronbert-cased",
        type=str,
        help="Path to pre-trained model or shortcut name of model.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="The output directory where the model predictions and checkpoints will be written. "
        "Default as `outputs`")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument(
        "--batch_size",
        default=8,
        type=int,
        help="Batch size per GPU/CPU for training.")
    parser.add_argument(
        "--learning_rate",
        default=1e-5,
        type=float,
        help="The initial learning rate for Adam.")
    parser.add_argument(
        "--weight_decay",
        default=0.01,
        type=float,
        help="Weight decay if we apply some.")
    parser.add_argument(
        "--adam_epsilon",
        default=1e-8,
        type=float,
        help="Epsilon for Adam optimizer.")
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs",
        default=2,
        type=int,
        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.06,
        type=float,
        help="Proportion of training steps to perform linear learning rate warmup for."
    )
    parser.add_argument(
        "--logging_steps",
        type=int,
        default=100,
        help="Log every X updates steps.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=5000,
        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        '--device',
        choices=['cpu', 'gpu'],
        default="gpu",
        help="Select which device to train model, defaults to gpu.")
    parser.add_argument(
        "--doc_stride",
        type=int,
        default=128,
        help="When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--n_best_size",
        type=int,
        default=20,
        help="The total number of n-best predictions to generate in the nbest_predictions.json output file."
    )
    parser.add_argument(
        "--null_score_diff_threshold",
        type=float,
        default=0.0,
        help="If null_score - best_non_null is greater than the threshold predict null."
    )
    parser.add_argument(
        "--max_query_length", type=int, default=64, help="Max query length.")
    parser.add_argument(
        "--max_answer_length", type=int, default=30, help="Max answer length.")
    parser.add_argument(
        "--do_lower_case",
        action='store_false',
        help="Whether to lower case the input text. Should be True for uncased models and False for cased models."
    )
    parser.add_argument(
        "--verbose", action='store_true', help="Whether to output verbose log.")
    parser.add_argument(
        "--version_2_with_negative",
        action='store_true',
        help="If true, the SQuAD examples contain some that do not have an answer. If using squad v2.0, it should be set true."
    )
    parser.add_argument(
        "--do_train", action='store_true', help="Whether to train the model.")
    parser.add_argument(
        "--do_predict", action='store_true', help="Whether to predict.")
    args = parser.parse_args()
    return args
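For orientation, the run_squad.py training loop that consumes these arguments is not shown in this excerpt; a minimal, purely illustrative sketch of how the parser would typically be used:

```python
if __name__ == "__main__":
    import paddle

    args = parse_args()
    paddle.set_device(args.device)  # 'cpu' or 'gpu'
    paddle.seed(args.seed)          # make initialization reproducible
    print(f"Fine-tuning {args.model_name_or_path} for {args.num_train_epochs} "
          f"epoch(s), batch size {args.batch_size}, lr {args.learning_rate}")
```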