Add T5 for generation/summarization (#241)
* add T5 CNN / DM config
* updates
* changelog
* fix
* fix config
* clean up
* Apply suggestions from code review
Showing 9 changed files with 306 additions and 11 deletions. Four of the changed files are reproduced below.
New file (119 added lines): the `T5` model wrapper, registered under the name `"t5"`.
```python
from typing import Optional, Dict, Any

from overrides import overrides
import torch

from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.models.model import Model
from allennlp.modules.transformer.t5 import T5 as T5Module, T5Output, IntT, BoolT
from allennlp.training.metrics import ROUGE, BLEU


@Model.register("t5")
class T5(Model):
    def __init__(self, vocab: Vocabulary, model_name: str, **kwargs) -> None:
        super().__init__(vocab, **kwargs)
        self._model_name = model_name
        # We only instantiate this when we need it.
        self._tokenizer: Optional[PretrainedTransformerTokenizer] = None
        self.t5 = T5Module.from_pretrained_module(model_name)

        exclude_indices = {
            self.t5.pad_token_id,
            self.t5.decoder_start_token_id,
            self.t5.eos_token_id,
        }
        self._metrics = [
            ROUGE(exclude_indices=exclude_indices),
            BLEU(exclude_indices=exclude_indices),
        ]

    @property
    def tokenizer(self) -> PretrainedTransformerTokenizer:
        if self._tokenizer is None:
            self._tokenizer = PretrainedTransformerTokenizer(self._model_name)
        return self._tokenizer

    def forward(  # type: ignore
        self, source_tokens: TextFieldTensors, target_tokens: Optional[TextFieldTensors] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Performs the forward step of T5.

        # Parameters

        source_tokens : `TextFieldTensors`, required
            The source tokens for the encoder. We assume they are stored under the
            `tokens` key/namespace.

        target_tokens : `TextFieldTensors`, optional (default = `None`)
            The target tokens for the decoder. We assume they are also stored under the
            `tokens` key/namespace. They are required during training; at validation or
            prediction time they are optional and, when given, are only used to compute
            the loss and metrics.

        # Returns

        `Dict[str, torch.Tensor]`
            Contains the `loss` when `target_tokens` is provided, and, during prediction,
            the `predictions` and `predicted_log_probs` from beam search.
        """
        input_ids, attention_mask = (
            source_tokens["tokens"]["token_ids"],
            source_tokens["tokens"]["mask"],
        )
        labels: Optional[IntT] = None
        decoder_attention_mask: Optional[BoolT] = None
        if target_tokens is not None:
            labels, decoder_attention_mask = (
                target_tokens["tokens"]["token_ids"],  # type: ignore[assignment]
                target_tokens["tokens"]["mask"],  # type: ignore[assignment]
            )
        elif self.training:
            raise ValueError("'target_tokens' required during training")

        output: T5Output = self.t5(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        output_dict: Dict[str, torch.Tensor] = {}

        if self.training:
            assert output.loss is not None
            output_dict["loss"] = output.loss
        else:
            # Shape: (batch_size, beam_size, num_tokens)
            assert output.predictions is not None
            # Shape: (batch_size, beam_size)
            assert output.predicted_log_probs is not None

            # Keep only the top-scoring beam.
            # Shape: (batch_size, num_tokens)
            output_dict["predictions"] = output.predictions[:, 0, :]
            # Shape: (batch_size,)
            output_dict["predicted_log_probs"] = output.predicted_log_probs[:, 0]

            if labels is not None:
                assert output.loss is not None
                output_dict["loss"] = output.loss

                for metric in self._metrics:
                    metric(output_dict["predictions"], labels)  # type: ignore[call-arg]

        return output_dict

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        predictions = output_dict["predictions"]
        # Decode token ids back to text with the underlying HuggingFace tokenizer.
        output_dict["predicted_text"] = self.tokenizer.tokenizer.batch_decode(
            predictions, skip_special_tokens=True  # type: ignore[attr-defined]
        )
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if not self.training:
            for metric in self._metrics:
                metrics.update(metric.get_metric(reset=reset))
        return metrics
```
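For orientation, here is a minimal prediction sketch. It assumes the `T5` class defined above is in scope, that `allennlp` is installed, and that the `t5-small` weights can be downloaded; none of this setup is part of the commit itself.

```python
import torch
from allennlp.data import Vocabulary

# An empty vocabulary suffices: tokenization goes through the HuggingFace
# tokenizer, so the AllenNLP vocab is not actually consulted here.
model = T5(Vocabulary(), "t5-small").eval()  # eval() -> beam-search branch of forward()

# Encode a source string with the underlying HuggingFace tokenizer and wrap it
# in the TextFieldTensors layout that forward() expects.
encoded = model.tokenizer.tokenizer(
    "summarize: The quick brown fox jumped over the lazy dog.",
    return_tensors="pt",
)
source_tokens = {
    "tokens": {
        "token_ids": encoded["input_ids"],
        "mask": encoded["attention_mask"].bool(),
    }
}

with torch.no_grad():
    output_dict = model(source_tokens)

print(model.make_output_human_readable(output_dict)["predicted_text"])
```

Because `target_tokens` is omitted and the model is in eval mode, only `predictions` and `predicted_log_probs` come back; passing targets would additionally produce `loss` and update the ROUGE/BLEU metrics.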
New file (46 added lines): the T5 test fixture config, used by the test below as `generation/t5/experiment.jsonnet` under the fixtures root; it reuses the existing BART fixture data.
local model_name = "patrickvonplaten/t5-tiny-random"; | ||
local data_base_url = "test_fixtures/generation/bart/data/"; | ||
|
||
{ | ||
"train_data_path": data_base_url + "/url_lists/all_train.txt", | ||
"validation_data_path": data_base_url + "/url_lists/all_val.txt", | ||
"dataset_reader": { | ||
"type": "cnn_dm", | ||
"source_tokenizer": { | ||
"type": "pretrained_transformer", | ||
"model_name": model_name | ||
}, | ||
"source_token_indexers": { | ||
"tokens": { | ||
"type": "pretrained_transformer", | ||
"model_name": model_name, | ||
"namespace": "tokens" | ||
} | ||
}, | ||
"source_max_tokens": 512, | ||
"target_max_tokens": 54, | ||
}, | ||
"model": { | ||
"type": "t5", | ||
"model_name": model_name | ||
}, | ||
"data_loader": { | ||
"batch_size": 2, | ||
"shuffle": true | ||
}, | ||
"trainer": { | ||
"num_epochs": 1, | ||
"optimizer": { | ||
"type": "huggingface_adamw", | ||
"lr": 3e-5, | ||
"betas": [0.9, 0.999], | ||
"eps": 1e-8, | ||
"correct_bias": true | ||
}, | ||
"learning_rate_scheduler": { | ||
"type": "polynomial_decay", | ||
}, | ||
"grad_norm": 1.0, | ||
"enable_default_callbacks": false | ||
} | ||
} |
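For completeness, a sketch of driving this fixture config programmatically rather than through the test harness. It assumes `allennlp` and `allennlp-models` are installed and that the process runs from the repository root so the relative fixture paths resolve; the serialization directory is an arbitrary choice.

```python
import allennlp_models.generation  # noqa: F401 -- registers the "t5" model and "cnn_dm" reader
from allennlp.commands.train import train_model_from_file

train_model_from_file(
    parameter_filename="test_fixtures/generation/t5/experiment.jsonnet",
    serialization_dir="/tmp/t5_fixture_run",  # arbitrary output location
)
```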
New file (17 added lines): a model test that trains the fixture config end to end.
```python
from allennlp.common.testing import ModelTestCase

from tests import FIXTURES_ROOT

from allennlp_models import generation  # noqa: F401


class T5Test(ModelTestCase):
    def setup_method(self):
        super().setup_method()
        self.set_up_model(
            FIXTURES_ROOT / "generation" / "t5" / "experiment.jsonnet",
            FIXTURES_ROOT / "generation" / "bart" / "data" / "url_lists" / "all_train.txt",
        )

    def test_model_can_train_save_load_predict(self):
        self.ensure_model_can_train_save_and_load(self.param_file, tolerance=1e-2)
```
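The test exercises the full train/save/load/predict round trip through `ModelTestCase`. To run just this module, something like the following should work (the file path is a guess at where the test lives in the repository):

```python
import pytest

# Path is an assumption; adjust it to wherever this test file actually sits.
pytest.main(["tests/generation/models/t5_test.py", "-v"])
```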
New file (49 added lines): the CNN/DailyMail summarization training config.
local model_name = "t5-small"; // TODO: change to large model | ||
local data_base_url = "https://storage.googleapis.com/allennlp-public-data/cnndm-combined-data-2020.07.13.tar.gz"; | ||
local train_data = data_base_url + "!cnndm-combined-data-2020.07.13/url_lists/all_train.txt"; | ||
local dev_data = data_base_url + "!cnndm-combined-data-2020.07.13/url_lists/all_val.txt"; | ||
|
||
{ | ||
"train_data_path": train_data, | ||
"validation_data_path": dev_data, | ||
"dataset_reader": { | ||
"type": "cnn_dm", | ||
"source_tokenizer": { | ||
"type": "pretrained_transformer", | ||
"model_name": model_name, | ||
}, | ||
"source_token_indexers": { | ||
"tokens": { | ||
"type": "pretrained_transformer", | ||
"model_name": model_name, | ||
"namespace": "tokens", | ||
} | ||
}, | ||
"source_max_tokens": 512, | ||
"target_max_tokens": 54, | ||
"source_prefix": "summarize: ", | ||
"max_instances": 1000 // DEBUG setting | ||
}, | ||
"model": { | ||
"type": "t5", | ||
"model_name": model_name, | ||
}, | ||
"data_loader": { | ||
"batch_size": 4, | ||
"shuffle": true, | ||
}, | ||
"trainer": { | ||
"num_epochs": 3, | ||
"optimizer": { | ||
"type": "huggingface_adamw", | ||
"lr": 3e-5, | ||
"betas": [0.9, 0.999], | ||
"eps": 1e-8, | ||
"correct_bias": true, | ||
}, | ||
"learning_rate_scheduler": { | ||
"type": "polynomial_decay", | ||
}, | ||
"grad_norm": 1.0, | ||
} | ||
} |
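Once a run of this config has produced a `model.tar.gz`, summaries can be generated straight from the archive. A sketch, assuming the generic `seq2seq` predictor shipped with allennlp-models applies to this model and using a made-up archive path:

```python
import allennlp_models.generation  # noqa: F401 -- registers the model, reader, and predictor
from allennlp.predictors import Predictor

# The archive path is hypothetical; point it at your actual training output.
predictor = Predictor.from_path("/tmp/t5_cnn_dm/model.tar.gz", predictor_name="seq2seq")

# The reader's `source_prefix` is expected to prepend "summarize: " itself,
# so the raw article text is passed here unprefixed.
result = predictor.predict(source="<article text to condense>")
print(result["predicted_text"])
```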