adding first version of bart code release (#902)
Summary: This is the first version of the BART code / model release. It still requires a lot of cleanup, instructions, and making sure results are reproducible before we can release it. Pull Request resolved: fairinternal/fairseq-py#902 Differential Revision: D18389535 fbshipit-source-id: 77f16800307ce831bd29538fdd34800793210f46
1 parent e98bf7e · commit a92bcda · showing 18 changed files with 1,360 additions and 39 deletions.

# Fine-tuning BART on GLUE tasks

### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```

### 2) Preprocess GLUE task data (same as RoBERTa):
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` to preprocess all the GLUE tasks.

### 3) Fine-tuning on a GLUE task:
Example fine-tuning command for the `RTE` task:
```bash
TOTAL_NUM_UPDATES=2036  # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=61       # 6 percent of the number of updates
LR=1e-05                # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16        # Batch size.
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1 python train.py RTE-bin/ \
    --restore-file $BART_PATH \
    --max-sentences $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --add-prev-output-tokens \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 \
    --arch bart_large \
    --criterion sentence_prediction \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```

For each GLUE task, you will need to use the following command-line arguments:

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107

For `STS-B`, additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.

**Note:**

a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=32/64/128` depending on the task.

b) The above command-line arguments and hyperparameters were tested on an Nvidia V100 GPU with 32GB of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--max-sentences`.
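
As a rough illustration of how these quantities relate, here is a small sketch (illustrative only; `num_train_examples` is a placeholder, and the values in the table above were computed by the authors for their setup, so they need not match this formula exactly):

```python
import math

# Illustrative arithmetic only (not part of the release).
num_train_examples = 10000   # placeholder; use your task's training-set size
max_sentences = 32           # per-GPU batch size (--max-sentences)
update_freq = 1              # gradient accumulation steps (--update-freq)
num_gpus = 1
max_epoch = 10

# Effective batch size: reducing --max-sentences while increasing --update-freq
# keeps this product (and hence the schedule below) unchanged.
effective_bsz = max_sentences * update_freq * num_gpus

updates_per_epoch = math.ceil(num_train_examples / effective_bsz)
total_num_update = updates_per_epoch * max_epoch       # --total-num-update
warmup_updates = int(0.06 * total_num_update)          # 6% warmup (--warmup-updates)

print(effective_bsz, total_num_update, warmup_updates)  # 32 3130 187
```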

### Inference on GLUE task
After training the model as described in the previous step, you can run inference with checkpoints in the `checkpoints/` directory using the following Python code snippet:
```python
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='RTE-bin'
)

# Map a predicted class index back to its label string in the task dictionary.
label_fn = lambda label: bart.task.label_dictionary.string(
    [label + bart.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/RTE/dev.tsv') as fin:
    fin.readline()  # skip the header line
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[1], tokens[2], tokens[3]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
        prediction_label = label_fn(prediction)
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```

# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension

[https://arxiv.org/pdf/1910.13461.pdf]

## Introduction

BART is a sequence-to-sequence model trained with a denoising pretraining objective. We show that this pretraining objective is more generic: BART matches [RoBERTa](../roberta) results on SQuAD and GLUE and achieves state-of-the-art results on summarization (XSum, CNN dataset), long-form generative question answering (ELI5) and dialog response generation (ConvAI2). See the associated paper for more details.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)

## Results

**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2

**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
`bart.large` | 88.8/94.6 | 86.1/89.2

**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
_(dev set, no additional data used)_

Model | R1 | R2 | RL
---|---|---|---
`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
`bart.large` | 44.16 | 21.28 | 40.90

## Example usage

##### Load BART from torch.hub (PyTorch >= 1.1):
```python
import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.large')
bart.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Load BART (for PyTorch 1.0 or custom models):
```bash
# Download bart.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
tar -xzvf bart.large.tar.gz
```

```python
# Load the model in fairseq
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
bart.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = bart.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
bart.decode(tokens)  # 'Hello world!'
```

##### Extract features from BART:
```python
# Extract the last layer's features
last_layer_features = bart.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layers' features from the decoder (layer 0 is the embedding layer)
all_layers = bart.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```

##### Use BART for sentence-pair classification tasks:
```python
# Download BART already finetuned for MNLI
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()  # disable dropout for evaluation

# Encode a pair of sentences and make a prediction
tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
bart.predict('mnli', tokens).argmax()  # 0: contradiction

# Encode another pair of sentences
tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
bart.predict('mnli', tokens).argmax()  # 2: entailment
```

##### Register a new (randomly initialized) classification head:
```python
bart.register_classification_head('new_task', num_classes=3)
logprobs = bart.predict('new_task', tokens)
```
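
Since the new head is randomly initialized, it has to be trained before its predictions mean anything. Below is a minimal, illustrative sketch of a single training step on the new head, not the official finetuning recipe (see the [GLUE finetuning instructions](README.glue.md) for that). It assumes `predict` returns log-probabilities, as in the MNLI examples above, that gradients flow through it, and uses a placeholder sentence pair and target label:

```python
import torch
import torch.nn.functional as F

bart.train()  # re-enable dropout for training
optimizer = torch.optim.Adam(bart.parameters(), lr=1e-5)

# Placeholder inputs: a sentence pair and a target class index in [0, num_classes).
# If the model has been moved to GPU, move `target` to the same device.
tokens = bart.encode('First sentence.', 'Second sentence.')
target = torch.tensor([1])

optimizer.zero_grad()
logprobs = bart.predict('new_task', tokens)  # assumed shape: (1, num_classes) log-probabilities
loss = F.nll_loss(logprobs, target)          # NLL loss expects log-probabilities
loss.backward()
optimizer.step()

bart.eval()  # back to evaluation mode
```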

##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens

bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()

batch_of_pairs = [
    ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
    ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
]

batch = collate_tokens(
    [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

logprobs = bart.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2])
```
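
To turn the predicted indices into readable labels, you can map them through the MNLI label mapping used in the evaluation snippet further below:

```python
# Map predicted class indices back to MNLI label strings (same mapping as in
# the evaluation snippet below).
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
print([label_map[i] for i in logprobs.argmax(dim=1).tolist()])
# ['contradiction', 'entailment']
```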

##### Using the GPU:
```python
bart.cuda()
bart.predict('new_task', tokens)
```

#### Evaluating the `bart.large.mnli` model:

Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set:
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('mnli', tokens).argmax().item()
        prediction_label = label_map[prediction]
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9010
```

## Finetuning

- [Finetuning on GLUE](README.glue.md)

## Citation

```bibtex
@article{lewis2019bart,
  title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
           Language Generation, Translation, and Comprehension},
  author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
            Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
            and Luke Zettlemoyer},
  journal = {arXiv preprint arXiv:1910.13461},
  year = {2019},
}
```

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class AppendTokenDataset(BaseWrapperDataset):

    def __init__(self, dataset, token=None):
        super().__init__(dataset)
        self.token = token
        if token is not None:
            self._sizes = np.array(dataset.sizes) + 1
        else:
            self._sizes = dataset.sizes

    def __getitem__(self, idx):
        item = self.dataset[idx]
        if self.token is not None:
            item = torch.cat([item, item.new([self.token])])
        return item

    @property
    def sizes(self):
        return self._sizes

    def num_tokens(self, index):
        n = self.dataset.num_tokens(index)
        if self.token is not None:
            n += 1
        return n

    def size(self, index):
        n = self.dataset.size(index)
        if self.token is not None:
            n += 1
        return n
```
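
This wrapper appends a single token to every example of an underlying dataset and increases the reported sizes by one. A minimal usage sketch follows (illustrative only; `base_dataset` and `dictionary` are hypothetical stand-ins for an existing fairseq dataset and dictionary, and the class is assumed to be exported from `fairseq.data` like the other wrapper datasets):

```python
# Illustrative sketch: append the EOS token to every example of an existing
# fairseq dataset. `base_dataset` and `dictionary` are hypothetical stand-ins.
from fairseq.data import AppendTokenDataset  # assumed export location

dataset = AppendTokenDataset(base_dataset, token=dictionary.eos())
item = dataset[0]                # original example with dictionary.eos() appended
length = dataset.num_tokens(0)   # original length + 1
```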