This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Add T5 for generation/summarization (#241)
* add T5 CNN / DM config

* updates

* changelog

* fix

* fix config

* clean up

* Apply suggestions from code review
epwalsh committed Apr 22, 2021
1 parent 5012f23 commit 7a6ee0c
Showing 9 changed files with 306 additions and 11 deletions.
13 changes: 8 additions & 5 deletions CHANGELOG.md
@@ -8,17 +8,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `T5` model for generation.
- Added a classmethod constructor on `Seq2SeqPredictor`: `.pretrained_t5_for_generation()`.
- Added a parameter called `source_prefix` to `CNNDailyMailDatasetReader`. This is useful with T5, for example, where `source_prefix` should be set to "summarize: ".
- Tests for `VqaMeasure`.
- Distributed tests for `ConllCorefScores` and `SrlEvalScorer` metrics.

### Fixed

- `VqaMeasure` now calculates correctly in the distributed case.
- `ConllCorefScores` now calculates correctly in the distributed case.
- `SrlEvalScorer` raises an appropriate error if run in the distributed setting.

### Changed

- Updated `registered_predictor_name` to `null` in model cards for the models where it was the same as the default predictor.
3 changes: 1 addition & 2 deletions Makefile
@@ -20,7 +20,6 @@ DOCKER_RUN_CMD = docker run --rm \
-v $$HOME/.cache/huggingface:/root/.cache/huggingface \
-v $$HOME/nltk_data:/root/nltk_data

# TODO: change this back to master branch
ALLENNLP_COMMIT_SHA = $(shell git ls-remote https://github.com/allenai/allennlp main | cut -f 1)

ifeq ($(shell uname),Darwin)
@@ -58,7 +57,7 @@ format :

.PHONY : typecheck
typecheck :
mypy allennlp_models tests --ignore-missing-imports --no-strict-optional --no-site-packages
mypy allennlp_models tests --ignore-missing-imports --no-strict-optional --no-site-packages --cache-dir=/dev/null

.PHONY : test
test :
12 changes: 11 additions & 1 deletion allennlp_models/generation/dataset_readers/cnn_dm.py
@@ -46,6 +46,9 @@ class CNNDailyMailDatasetReader(DatasetReader):
        Maximum number of tokens in source sequence.
    target_max_tokens : `int`, optional
        Maximum number of tokens in target sequence.
    source_prefix : `str`, optional
        An optional prefix to prepend to source strings. For example, with a T5 model,
        you may want to set `source_prefix` to "summarize: ".
    """

    def __init__(
@@ -56,6 +59,7 @@ def __init__(
        target_token_indexers: Dict[str, TokenIndexer] = None,
        source_max_tokens: Optional[int] = None,
        target_max_tokens: Optional[int] = None,
        source_prefix: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(
@@ -67,6 +71,7 @@ def __init__(
        self._target_token_indexers = target_token_indexers or self._source_token_indexers
        self._source_max_tokens = source_max_tokens
        self._target_max_tokens = target_max_tokens
        self._source_prefix = source_prefix

    @staticmethod
    def _hashhex(url):
@@ -155,7 +160,12 @@ def _read(self, file_path: str):
    def text_to_instance(
        self, source_sequence: str, target_sequence: str = None
    ) -> Instance:  # type: ignore
        tokenized_source = self._source_tokenizer.tokenize(source_sequence)
        if self._source_prefix is not None:
            tokenized_source = self._source_tokenizer.tokenize(
                self._source_prefix + source_sequence
            )
        else:
            tokenized_source = self._source_tokenizer.tokenize(source_sequence)
        if self._source_max_tokens is not None and len(tokenized_source) > self._source_max_tokens:
            tokenized_source = tokenized_source[: self._source_max_tokens]

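As a quick illustration of the new `source_prefix` parameter, here is a rough sketch of constructing the reader directly from Python. The argument values are illustrative and mirror the training config further down rather than anything prescribed by this commit, and the import path is assumed from the package layout shown in this diff.

```python
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp_models.generation.dataset_readers import CNNDailyMailDatasetReader

model_name = "t5-small"
reader = CNNDailyMailDatasetReader(
    source_tokenizer=PretrainedTransformerTokenizer(model_name),
    source_token_indexers={"tokens": PretrainedTransformerIndexer(model_name)},
    source_max_tokens=512,
    target_max_tokens=54,
    # Prepended to every source string before tokenization (see text_to_instance above).
    source_prefix="summarize: ",
)

# The prefix is applied inside text_to_instance, so the raw article text is passed as-is.
instance = reader.text_to_instance(
    "The tower is 324 metres tall and was completed in 1889.",  # made-up example text
    "The Eiffel Tower is 324 metres tall.",
)
```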
1 change: 1 addition & 0 deletions allennlp_models/generation/models/__init__.py
@@ -2,3 +2,4 @@
from allennlp_models.generation.models.copynet_seq2seq import CopyNetSeq2Seq
from allennlp_models.generation.models.simple_seq2seq import SimpleSeq2Seq
from allennlp_models.generation.models.bart import Bart
from allennlp_models.generation.models.t5 import T5
119 changes: 119 additions & 0 deletions allennlp_models/generation/models/t5.py
@@ -0,0 +1,119 @@
from typing import Optional, Dict, Any

from overrides import overrides
import torch

from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.models.model import Model
from allennlp.modules.transformer.t5 import T5 as T5Module, T5Output, IntT, BoolT
from allennlp.training.metrics import ROUGE, BLEU


@Model.register("t5")
class T5(Model):
    def __init__(self, vocab: Vocabulary, model_name: str, **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._model_name = model_name
        # We only instantiate this when we need it.
        self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
        self.t5 = T5Module.from_pretrained_module(model_name)

        exclude_indices = {
            self.t5.pad_token_id,
            self.t5.decoder_start_token_id,
            self.t5.eos_token_id,
        }
        self._metrics = [
            ROUGE(exclude_indices=exclude_indices),
            BLEU(exclude_indices=exclude_indices),
        ]

    @property
    def tokenizer(self) -> PretrainedTransformerTokenizer:
        if self._tokenizer is None:
            self._tokenizer = PretrainedTransformerTokenizer(self._model_name)
        return self._tokenizer

    def forward(  # type: ignore
        self, source_tokens: TextFieldTensors, target_tokens: Optional[TextFieldTensors] = None
    ) -> Dict[str, torch.Tensor]:
"""
Performs the forward step of T5.
# Parameters
source_tokens : `TextFieldTensors`, required
The source tokens for the encoder. We assume they are stored under the `tokens` key/namespace.
target_tokens : `TextFieldTensors`, optional (default = `None`)
The target tokens for the decoder. We assume they are also stored under the `tokens` key/namespace.
If no target tokens are given during training / validation, the source tokens are shifted
to the right by 1.
# Returns
`Dict[str, torch.Tensor]`
Contains the `loss` when `target_tokens` is provided.
And during prediction, includes `predictions` and `predicted_log_probs` from beam search.
"""
        input_ids, attention_mask = (
            source_tokens["tokens"]["token_ids"],
            source_tokens["tokens"]["mask"],
        )
        labels: Optional[IntT] = None
        decoder_attention_mask: Optional[BoolT] = None
        if target_tokens is not None:
            labels, decoder_attention_mask = (
                target_tokens["tokens"]["token_ids"],  # type: ignore[assignment]
                target_tokens["tokens"]["mask"],  # type: ignore[assignment]
            )
        elif self.training:
            raise ValueError("'target_tokens' required during training")

        output: T5Output = self.t5(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        output_dict: Dict[str, torch.Tensor] = {}

        if self.training:
            assert output.loss is not None
            output_dict["loss"] = output.loss
        else:
            # Shape: (batch_size, beam_size, num_tokens)
            assert output.predictions is not None
            # Shape: (batch_size, beam_size)
            assert output.predicted_log_probs is not None
            # Shape: (batch_size, num_tokens)
            output_dict["predictions"] = output.predictions[:, 0, :]
            # Shape: (batch_size, )
            output_dict["predicted_log_probs"] = output.predicted_log_probs[:, 0]

            if labels is not None:
                assert output.loss is not None
                output_dict["loss"] = output.loss

                for metric in self._metrics:
                    metric(output_dict["predictions"], labels)  # type: ignore[call-arg]

        return output_dict

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        predictions = output_dict["predictions"]
        output_dict["predicted_text"] = self.tokenizer.tokenizer.batch_decode(
            predictions, skip_special_tokens=True  # type: ignore[attr-defined]
        )
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            for metric in self._metrics:
                metrics.update(metric.get_metric(reset=reset))
        return metrics
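To make the new model's contract concrete, here is an untested sketch of calling `forward` and `make_output_human_readable` by hand. The token ids and mask below are made up purely for illustration; real inputs would come from the dataset reader and data loader, and constructing the model downloads the pretrained "t5-small" weights.

```python
import torch
from allennlp.data import Vocabulary
from allennlp_models.generation.models import T5

# Any T5 checkpoint name should work; "t5-small" keeps the download manageable.
model = T5(Vocabulary.from_pretrained_transformer("t5-small"), "t5-small").eval()

# `source_tokens` follows the TextFieldTensors layout forward() expects:
# {"tokens": {"token_ids": ..., "mask": ...}}. These ids are illustrative only.
token_ids = torch.tensor([[21603, 10, 37, 7293, 19, 1]])
source_tokens = {"tokens": {"token_ids": token_ids, "mask": token_ids != 0}}

with torch.no_grad():
    output = model(source_tokens)  # no targets, so only beam search predictions are returned
    readable = model.make_output_human_readable(output)

print(readable["predicted_text"])
```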
57 changes: 54 additions & 3 deletions allennlp_models/generation/predictors/seq2seq.py
@@ -9,9 +9,12 @@
class Seq2SeqPredictor(Predictor):
"""
Predictor for sequence to sequence models, including
[`ComposedSeq2Seq`](../models/encoder_decoders/composed_seq2seq.md) and
[`SimpleSeq2Seq`](../models/encoder_decoders/simple_seq2seq.md) and
[`CopyNetSeq2Seq`](../models/encoder_decoders/copynet_seq2seq.md).
- [`ComposedSeq2Seq`](../models/composed_seq2seq.md),
- [`SimpleSeq2Seq`](../models/simple_seq2seq.md),
- [`CopyNetSeq2Seq`](../models/copynet_seq2seq.md),
- [`Bart`](../models/bart.md), and
- [`T5`](../models/t5.md).
"""

    def predict(self, source: str) -> JsonDict:
@@ -24,3 +27,51 @@ def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        source = json_dict["source"]
        return self._dataset_reader.text_to_instance(source)

    @classmethod
    def pretrained_t5_for_generation(cls, model_name: str = "t5-base") -> "Seq2SeqPredictor":
        """
        A helper method for creating a predictor for a pretrained T5 generation model.

        # Examples

        ```python
        from allennlp_models.generation.predictors import Seq2SeqPredictor
        ARTICLE_TO_SUMMARIZE = '''
        summarize: The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building,
        and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side.
        During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest
        man-made structure in the world, a title it held for 41 years until the Chrysler Building in
        New York City was finished in 1930. It was the first structure to reach a height of 300 metres.
        Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller
        than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is
        the second tallest free-standing structure in France after the Millau Viaduct.
        '''.strip().replace(
            "\n", " "
        )
        predictor = Seq2SeqPredictor.pretrained_t5_for_generation("t5-small")
        predictor.predict(ARTICLE_TO_SUMMARIZE)
        ```
        """
        from allennlp.data import Vocabulary
        from allennlp.data.tokenizers import PretrainedTransformerTokenizer
        from allennlp.data.token_indexers import PretrainedTransformerIndexer
        from allennlp_models.generation.dataset_readers import Seq2SeqDatasetReader
        from allennlp_models.generation.models import T5

        tokenizer, token_indexer = (
            PretrainedTransformerTokenizer(model_name),
            PretrainedTransformerIndexer(model_name),
        )
        reader = Seq2SeqDatasetReader(
            source_tokenizer=tokenizer,
            source_token_indexers={"tokens": token_indexer},
            source_add_start_token=False,
            source_add_end_token=False,
            target_add_start_token=False,
            target_add_end_token=False,
        )
        vocab = Vocabulary.from_pretrained_transformer(model_name)
        model = T5(vocab, model_name)
        return cls(model, reader)
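Building on the docstring example, the predictor's JSON output mirrors `make_output_human_readable`, so under the same assumptions the generated text can be read back roughly like this:

```python
from allennlp_models.generation.predictors import Seq2SeqPredictor

predictor = Seq2SeqPredictor.pretrained_t5_for_generation("t5-small")
result = predictor.predict("summarize: The tower is 324 metres (1,063 ft) tall ...")

# "predicted_text" is added by T5.make_output_human_readable; "predictions" holds raw token ids.
print(result["predicted_text"])
```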
46 changes: 46 additions & 0 deletions test_fixtures/generation/t5/experiment.jsonnet
@@ -0,0 +1,46 @@
local model_name = "patrickvonplaten/t5-tiny-random";
local data_base_url = "test_fixtures/generation/bart/data/";

{
"train_data_path": data_base_url + "/url_lists/all_train.txt",
"validation_data_path": data_base_url + "/url_lists/all_val.txt",
"dataset_reader": {
"type": "cnn_dm",
"source_tokenizer": {
"type": "pretrained_transformer",
"model_name": model_name
},
"source_token_indexers": {
"tokens": {
"type": "pretrained_transformer",
"model_name": model_name,
"namespace": "tokens"
}
},
"source_max_tokens": 512,
"target_max_tokens": 54,
},
"model": {
"type": "t5",
"model_name": model_name
},
"data_loader": {
"batch_size": 2,
"shuffle": true
},
"trainer": {
"num_epochs": 1,
"optimizer": {
"type": "huggingface_adamw",
"lr": 3e-5,
"betas": [0.9, 0.999],
"eps": 1e-8,
"correct_bias": true
},
"learning_rate_scheduler": {
"type": "polynomial_decay",
},
"grad_norm": 1.0,
"enable_default_callbacks": false
}
}
17 changes: 17 additions & 0 deletions tests/generation/models/t5_test.py
@@ -0,0 +1,17 @@
from allennlp.common.testing import ModelTestCase

from tests import FIXTURES_ROOT

from allennlp_models import generation # noqa: F401


class T5Test(ModelTestCase):
    def setup_method(self):
        super().setup_method()
        self.set_up_model(
            FIXTURES_ROOT / "generation" / "t5" / "experiment.jsonnet",
            FIXTURES_ROOT / "generation" / "bart" / "data" / "url_lists" / "all_train.txt",
        )

    def test_model_can_train_save_load_predict(self):
        self.ensure_model_can_train_save_and_load(self.param_file, tolerance=1e-2)
49 changes: 49 additions & 0 deletions training_config/generation/t5_cnn_dm.jsonnet
@@ -0,0 +1,49 @@
local model_name = "t5-small"; // TODO: change to large model
local data_base_url = "https://storage.googleapis.com/allennlp-public-data/cnndm-combined-data-2020.07.13.tar.gz";
local train_data = data_base_url + "!cnndm-combined-data-2020.07.13/url_lists/all_train.txt";
local dev_data = data_base_url + "!cnndm-combined-data-2020.07.13/url_lists/all_val.txt";

{
"train_data_path": train_data,
"validation_data_path": dev_data,
"dataset_reader": {
"type": "cnn_dm",
"source_tokenizer": {
"type": "pretrained_transformer",
"model_name": model_name,
},
"source_token_indexers": {
"tokens": {
"type": "pretrained_transformer",
"model_name": model_name,
"namespace": "tokens",
}
},
"source_max_tokens": 512,
"target_max_tokens": 54,
"source_prefix": "summarize: ",
"max_instances": 1000 // DEBUG setting
},
"model": {
"type": "t5",
"model_name": model_name,
},
"data_loader": {
"batch_size": 4,
"shuffle": true,
},
"trainer": {
"num_epochs": 3,
"optimizer": {
"type": "huggingface_adamw",
"lr": 3e-5,
"betas": [0.9, 0.999],
"eps": 1e-8,
"correct_bias": true,
},
"learning_rate_scheduler": {
"type": "polynomial_decay",
},
"grad_norm": 1.0,
}
}
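For reference, a rough sketch of launching this config from Python rather than via the `allennlp train` CLI. It assumes AllenNLP's `train_model_from_file` helper; the serialization directory is arbitrary, and the full CNN/DailyMail archive is downloaded on first use.

```python
from allennlp.commands.train import train_model_from_file

# Importing the package registers the "cnn_dm" reader and "t5" model, mirroring the test above.
from allennlp_models import generation  # noqa: F401

train_model_from_file(
    "training_config/generation/t5_cnn_dm.jsonnet",
    serialization_dir="/tmp/t5_cnn_dm",
)
```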
