Skip to content

Commit

Permalink
adding first version of bart code release (#902)
Browse files Browse the repository at this point in the history
Summary:
This is the first version of BART code / model release.

It still requires lot of clean up, instructions, making sure results are reproducible before we can release it.
Pull Request resolved: fairinternal/fairseq-py#902

Differential Revision: D18389535

fbshipit-source-id: 77f16800307ce831bd29538fdd34800793210f46
  • Loading branch information
Naman Goyal authored and facebook-github-bot committed Nov 9, 2019
1 parent e98bf7e commit a92bcda
Show file tree
Hide file tree
Showing 18 changed files with 1,360 additions and 39 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ modeling and other text generation tasks.

### What's New:

- November 2019: [BART model and code released](examples/bart/README.md)
- November 2019: [XLM-R models and code released](examples/xlmr/README.md)
- September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
- August 2019: [WMT'19 models released](examples/wmt19/README.md)
Expand Down
99 changes: 99 additions & 0 deletions examples/bart/README.glue.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Fine-tuning BART on GLUE tasks

### 1) Download the data from GLUE website (https://gluebenchmark.com/tasks) using following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```

### 2) Preprocess GLUE task data (same as RoBERTa):
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` for preprocessing all the glue tasks.

### 3) Fine-tuning on GLUE task:
Example fine-tuning cmd for `RTE` task
```bash
TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=61 # 6 percent of the number of updates
LR=1e-05 # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16 # Batch size.
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1 python train.py RTE-bin/ \
--restore-file $BART_PATH \
--max-sentences $MAX_SENTENCES \
--max-tokens 4400 \
--task sentence_prediction \
--add-prev-output-tokens \
--layernorm-embedding \
--share-all-embeddings \
--share-decoder-input-output-embed \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 \
--arch bart_large \
--criterion sentence_prediction \
--num-classes $NUM_CLASSES \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--find-unused-parameters \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```

For each of the GLUE task, you will need to use following cmd-line arguments:

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107

For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.

**Note:**

a) `--total-num-updates` is used by `--polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=32/64/128` depending on the task.

b) Above cmd-args and hyperparams are tested on Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory resources available to you, you can use increase `--update-freq` and reduce `--max-sentences`.

### Inference on GLUE task
After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:

```python
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
'checkpoints/',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='RTE-bin'
)

label_fn = lambda label: bart.task.label_dictionary.string(
[label + bart.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/RTE/dev.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[1], tokens[2], tokens[3]
tokens = bart.encode(sent1, sent2)
prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
prediction_label = label_fn(prediction)
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```
169 changes: 169 additions & 0 deletions examples/bart/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension

[https://arxiv.org/pdf/1910.13461.pdf]

## Introduction

BART is sequence-to-sequence model trained with denoising as pretraining objective. We show that this pretraining objective is more generic and show that we can match [RoBERTa](../roberta) Results on SQuAD and GLUE and gain state-of-the-art results on summarization (XSum, CNN dataset), long form generative question answering (ELI5) and dialog response genration (ConvAI2). See the associated paper for more details.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)

## Results

**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2

**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
`bart.large` | 88.8/94.6 | 86.1/89.2

**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
_(dev set, no additional data used)_

Model | R1 | R2 | RL
---|---|---|---
`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
`bart.large` | 44.16 | 21.28 | 40.90

## Example usage

##### Load BART from torch.hub (PyTorch >= 1.1):
```python
import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.large')
bart.eval() # disable dropout (or leave in train mode to finetune)
```

##### Load BART (for PyTorch 1.0 or custom models):
```python
# Download bart.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
tar -xzvf bart.large.tar.gz

# Load the model in fairseq
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
bart.eval() # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = bart.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
bart.decode(tokens) # 'Hello world!'
```

##### Extract features from BART:
```python
# Extract the last layer's features
last_layer_features = bart.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layer's features from decoder (layer 0 is the embedding layer)
all_layers = bart.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```

##### Use BART for sentence-pair classification tasks:
```python
# Download BART already finetuned for MNLI
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval() # disable dropout for evaluation

# Encode a pair of sentences and make a prediction
tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
bart.predict('mnli', tokens).argmax() # 0: contradiction

# Encode another pair of sentences
tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
bart.predict('mnli', tokens).argmax() # 2: entailment
```

##### Register a new (randomly initialized) classification head:
```python
bart.register_classification_head('new_task', num_classes=3)
logprobs = bart.predict('new_task', tokens)
```

##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens

bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()

batch_of_pairs = [
['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
]

batch = collate_tokens(
[bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

logprobs = bart.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2])
```

##### Using the GPU:
```python
bart.cuda()
bart.predict('new_task', tokens)
```

#### Evaluating the `bart.large.mnli` model:

Example python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
fin.readline()
for index, line in enumerate(fin):
tokens = line.strip().split('\t')
sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
tokens = bart.encode(sent1, sent2)
prediction = bart.predict('mnli', tokens).argmax().item()
prediction_label = label_map[prediction]
ncorrect += int(prediction_label == target)
nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9010
```

## Finetuning

- [Finetuning on GLUE](README.glue.md)

## Citation

```bibtex
@article{lewis2019bart,
title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
Language Generation, Translation, and Comprehension},
author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
and Luke Zettlemoyer },
journal={arXiv preprint arXiv:1910.13461},
year = {2019},
}
```
2 changes: 1 addition & 1 deletion examples/roberta/README.glue.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ roberta = RobertaModel.from_pretrained(
)

label_fn = lambda label: roberta.task.label_dictionary.string(
[label + roberta.task.target_dictionary.nspecial]
[label + roberta.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
Expand Down
9 changes: 7 additions & 2 deletions fairseq/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

from .base_wrapper_dataset import BaseWrapperDataset

from .append_token_dataset import AppendTokenDataset
from .audio.raw_audio_dataset import FileAudioDataset
from .backtranslation_dataset import BacktranslationDataset
from .colorize_dataset import ColorizeDataset
from .concat_dataset import ConcatDataset
from .concat_sentences_dataset import ConcatSentencesDataset
from .denoising_dataset import DenoisingDataset
from .id_dataset import IdDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
from .language_pair_dataset import LanguagePairDataset
Expand All @@ -33,6 +35,7 @@
from .raw_label_dataset import RawLabelDataset
from .replace_dataset import ReplaceDataset
from .resampling_dataset import ResamplingDataset
from .roll_dataset import RollDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
from .sharded_dataset import ShardedDataset
from .sort_dataset import SortDataset
Expand All @@ -42,7 +45,6 @@
from .transform_eos_dataset import TransformEosDataset
from .transform_eos_lang_pair_dataset import TransformEosLangPairDataset
from .truncate_dataset import TruncateDataset
from .resampling_dataset import ResamplingDataset

from .iterators import (
CountingIterator,
Expand All @@ -52,12 +54,14 @@
)

__all__ = [
'AppendTokenDataset',
'BacktranslationDataset',
'BaseWrapperDataset',
'ColorizeDataset',
'ConcatDataset',
'ConcatSentencesDataset',
'CountingIterator',
'DenoisingDataset',
'Dictionary',
'EpochBatchIterator',
'FairseqDataset',
Expand All @@ -83,9 +87,10 @@
'PrependDataset',
'PrependTokenDataset',
'ReplaceDataset',
'RollDataset',
'FileAudioDataset',
'RawLabelDataset',
'ResamplingDataset'
'ResamplingDataset',
'RightPadDataset',
'RoundRobinZipDatasets',
'ShardedDataset',
Expand Down
42 changes: 42 additions & 0 deletions fairseq/data/append_token_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch

from . import BaseWrapperDataset


class AppendTokenDataset(BaseWrapperDataset):

def __init__(self, dataset, token=None):
super().__init__(dataset)
self.token = token
if token is not None:
self._sizes = np.array(dataset.sizes) + 1
else:
self._sizes = dataset.sizes

def __getitem__(self, idx):
item = self.dataset[idx]
if self.token is not None:
item = torch.cat([item, item.new([self.token])])
return item

@property
def sizes(self):
return self._sizes

def num_tokens(self, index):
n = self.dataset.num_tokens(index)
if self.token is not None:
n += 1
return n

def size(self, index):
n = self.dataset.size(index)
if self.token is not None:
n += 1
return n
Loading

0 comments on commit a92bcda

Please sign in to comment.