Commit
* Adds a transformer-based classification model, plus a Tango config that runs it
* Fix for Piqa
* Changelog
* Formatting
* Remix the dataset we use
* Actually use the validation split
* Rename the classification model
Showing 8 changed files with 248 additions and 1 deletion.
allennlp_models/classification/__init__.py

```diff
@@ -1,3 +1,4 @@
 # flake8: noqa: F403
 from allennlp_models.classification.models import *
 from allennlp_models.classification.dataset_readers import *
+from allennlp_models.classification.tango import *
```
allennlp_models/classification/models/__init__.py

```diff
@@ -1,3 +1,6 @@
 from allennlp_models.classification.models.biattentive_classification_network import (
     BiattentiveClassificationNetwork,
 )
+from allennlp_models.classification.models.transformer_classification_tt import (
+    TransformerClassificationTT,
+)
```
allennlp_models/classification/models/transformer_classification_tt.py (103 additions, 0 deletions)

```python
import logging
from typing import Dict, Optional

import torch
from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules.transformer import TransformerEmbeddings, TransformerStack, TransformerPooler
from allennlp.training.metrics import CategoricalAccuracy, FBetaMeasure
from torch.nn import Dropout

logger = logging.getLogger(__name__)


@Model.register("transformer_classification_tt")
class TransformerClassificationTT(Model):
    """
    This class implements a classification model patterned after the model proposed in
    [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al)]
    (https://api.semanticscholar.org/CorpusID:198953378).

    Parameters
    ----------
    vocab : ``Vocabulary``
    transformer_model : ``str``, optional (default=``"roberta-large"``)
        This model chooses the embedder according to this setting. You probably want to make
        sure this matches the setting in the reader.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model: str = "roberta-large",
        num_labels: Optional[int] = None,
        label_namespace: str = "labels",
        override_weights_file: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        transformer_kwargs = {
            "model_name": transformer_model,
            "weights_path": override_weights_file,
        }
        self.embeddings = TransformerEmbeddings.from_pretrained_module(**transformer_kwargs)
        self.transformer_stack = TransformerStack.from_pretrained_module(**transformer_kwargs)
        self.pooler = TransformerPooler.from_pretrained_module(**transformer_kwargs)
        self.pooler_dropout = Dropout(p=0.1)

        self.label_tokens = vocab.get_index_to_token_vocabulary(label_namespace)
        if num_labels is None:
            num_labels = len(self.label_tokens)
        # Initialize the classification head the way BERT-style pretraining code does:
        # normally distributed weights with std 0.02, zero bias.
        self.linear_layer = torch.nn.Linear(self.pooler.get_output_dim(), num_labels)
        self.linear_layer.weight.data.normal_(mean=0.0, std=0.02)
        self.linear_layer.bias.data.zero_()

        self.loss = torch.nn.CrossEntropyLoss()
        self.acc = CategoricalAccuracy()
        self.f1 = FBetaMeasure()

    def forward(  # type: ignore
        self,
        text: Dict[str, torch.Tensor],
        label: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``
            From a ``TransformerTextField``. Contains the text to be classified.
        label : ``Optional[torch.LongTensor]``
            From a ``LabelField``, specifies the true class of the instance.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. This is only returned when ``label`` is not ``None``.
        logits : ``torch.FloatTensor``
            The logits for every possible label.
        """
        embedded_alternatives = self.embeddings(**text)
        embedded_alternatives = self.transformer_stack(
            embedded_alternatives, text["attention_mask"]
        )
        embedded_alternatives = self.pooler(embedded_alternatives.final_hidden_states)
        embedded_alternatives = self.pooler_dropout(embedded_alternatives)
        logits = self.linear_layer(embedded_alternatives)

        result = {"logits": logits, "answers": logits.argmax(1)}

        if label is not None:
            result["loss"] = self.loss(logits, label)
            self.acc(logits, label)
            self.f1(logits, label)

        return result

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        result = {"acc": self.acc.get_metric(reset)}
        # FBetaMeasure reports per-class lists; flatten them into "<label>-<metric>" keys.
        for metric_name, metrics_per_class in self.f1.get_metric(reset).items():
            for class_index, value in enumerate(metrics_per_class):
                result[f"{self.label_tokens[class_index]}-{metric_name}"] = value
        return result
```
allennlp_models/classification/tango/__init__.py (1 addition)

```python
from allennlp_models.classification.tango.imdb import ImdbInstances
```
allennlp_models/classification/tango/imdb.py (65 additions)

```python
from typing import Sequence, Dict

import datasets
from allennlp.common import cached_transformers
from allennlp.data import Vocabulary, Instance
from allennlp.data.fields import TransformerTextField, LabelField
from allennlp.tango.dataset import DatasetDict
from allennlp.tango.sqlite_format import SqliteDictFormat
from allennlp.tango.step import Step


@Step.register("imdb_instances")
class ImdbInstances(Step):
    DETERMINISTIC = True
    VERSION = "003"
    CACHEABLE = True

    FORMAT = SqliteDictFormat

    def run(
        self,
        tokenizer_name: str,
        max_length: int = 512,
    ) -> DatasetDict:
        tokenizer = cached_transformers.get_tokenizer(tokenizer_name)
        # The code below assumes padding positions get token type id 0.
        assert tokenizer.pad_token_type_id == 0

        def clean_text(s: str) -> str:
            return s.replace("<br />", "\n")

        # This is structured the way it is because we want to call `batch_encode_plus`
        # with all the strings at once, rather than encoding them one by one.
        results: Dict[str, Sequence[Instance]] = {}
        for split_name, instances in datasets.load_dataset("imdb").items():
            tokenized_texts = tokenizer.batch_encode_plus(
                [clean_text(instance["text"]) for instance in instances],
                add_special_tokens=True,
                truncation=True,
                max_length=max_length,
                return_token_type_ids=True,
                return_attention_mask=False,
            )

            results[split_name] = [
                Instance(
                    {
                        "text": TransformerTextField(
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            padding_token_id=tokenizer.pad_token_id,
                        ),
                        "label": LabelField(instance["label"], skip_indexing=True),
                    }
                )
                for instance, input_ids, token_type_ids in zip(
                    instances, tokenized_texts["input_ids"], tokenized_texts["token_type_ids"]
                )
            ]

        # Build the vocabulary: the transformer's own vocab, plus the two IMDB labels.
        vocab = Vocabulary.empty()
        vocab.add_transformer_vocab(tokenizer, "tokens")
        vocab.add_tokens_to_namespace(["neg", "pos"], "labels")

        return DatasetDict(results, vocab)
```
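An aside, not in the commit: the batched-tokenization pattern in `run()` can be reproduced with plain `transformers` to see what it yields. The toy reviews here are hypothetical.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

texts = ["A quietly devastating film.<br />Loved it.", "Two hours I want back."]
cleaned = [t.replace("<br />", "\n") for t in texts]

# One call over the whole list lets the fast (Rust-backed) tokenizer work in batch,
# instead of paying per-call overhead text by text.
encoded = tokenizer.batch_encode_plus(
    cleaned,
    add_special_tokens=True,
    truncation=True,
    max_length=512,
    return_token_type_ids=True,
    return_attention_mask=False,
)

# Per-text token id lists, ready to zip back with the original examples.
for text, input_ids in zip(cleaned, encoded["input_ids"]):
    print(len(input_ids), repr(text[:30]))
```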
The Tango training config (71 additions):

```jsonnet
local debug = true;

local transformer_model = if debug then "roberta-base" else "roberta-large";

{
    "steps": {
        "original_dataset": {
            "type": "imdb_instances",
            "tokenizer_name": transformer_model
        },
        "remixed_dataset": {
            // IMDB has no validation split, so carve one out of the training data.
            "type": "dataset_remix",
            "input": { "ref": "original_dataset" },
            "new_splits": {
                "train": "train[:20000]",
                "validation": "train[20000:]",
                "test": "test"
            },
            "keep_old_splits": false,
            "shuffle_before": true,
        },
        "trained_model": {
            "type": "training",
            "dataset": { "ref": "remixed_dataset" },
            "training_split": "train",
            "validation_split": "validation",
            [if !debug then "data_loader"]: {
                "batch_size": 32
            },
            [if debug then "data_loader"]: {
                "type": "max_batches",
                "max_batches_per_epoch": 7,
                "inner": {
                    "batch_size": 5
                }
            },
            "model": {
                "type": "transformer_classification_tt",
                "transformer_model": transformer_model,
            },
            "optimizer": {
                "type": "huggingface_adamw",
                "weight_decay": 0.01,
                // No weight decay on biases and layer-norm weights.
                "parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]],
                "lr": 1e-5,
                "eps": 1e-8,
                "correct_bias": true
            },
            "learning_rate_scheduler": {
                "type": "linear_with_warmup",
                "warmup_steps": 100
            },
            "num_epochs": if debug then 3 else 20,
            "patience": 3,
        },
        "evaluation": {
            "type": "evaluation",
            // The remix keeps the original test split, so evaluate on that.
            "dataset": { "ref": "remixed_dataset" },
            "model": { "ref": "trained_model" },
            "split": "test",
            [if debug then "data_loader"]: {
                "type": "max_batches",
                "max_batches_per_epoch": 7,
                "inner": {
                    "batch_size": 5
                }
            },
        }
    }
}
```
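An aside, not in the commit: IMDB ships only `train` (25,000 examples) and `test` (25,000 examples) splits, so the `dataset_remix` step above carves a validation set out of `train`. A toy, list-based Python illustration of that split arithmetic:

```python
import random

# Stand-ins for the 25,000-example IMDB train and test splits.
train = list(range(25000))
test = list(range(25000, 50000))

random.shuffle(train)  # corresponds to "shuffle_before": true

new_splits = {
    "train": train[:20000],       # "train[:20000]"
    "validation": train[20000:],  # "train[20000:]"
    "test": test,                 # passed through unchanged
}

print({name: len(split) for name, split in new_splits.items()})
# {'train': 20000, 'validation': 5000, 'test': 25000}
```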