IMDB Model (#297)
* Adds a transformer based classification model, plus a Tango config that runs it

* Fix for Piqa

* Changelog

* Formatting

* Remix the dataset we use

* Actually use the validation split

* Rename the classification model
dirkgr committed Aug 28, 2021
1 parent 8d5b21f commit 54de9d6
Showing 8 changed files with 248 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
beam search and other options.
- Added a configuration file for fine-tuning `t5-11b` on CCN-DM (requires at least 8 GPUs).
- Added a configuration to train on the PIQA dataset with AllenNLP Tango.
- Added a transformer classification model.
- Added a configuration to train on the IMDB dataset with AllenNLP Tango.

### Fixed

1 change: 1 addition & 0 deletions allennlp_models/classification/__init__.py
@@ -1,3 +1,4 @@
# flake8: noqa: F403
from allennlp_models.classification.models import *
from allennlp_models.classification.dataset_readers import *
from allennlp_models.classification.tango import *
3 changes: 3 additions & 0 deletions allennlp_models/classification/models/__init__.py
@@ -1,3 +1,6 @@
from allennlp_models.classification.models.biattentive_classification_network import (
    BiattentiveClassificationNetwork,
)
from allennlp_models.classification.models.transformer_classification_tt import (
    TransformerClassificationTT,
)
103 changes: 103 additions & 0 deletions allennlp_models/classification/models/transformer_classification_tt.py
@@ -0,0 +1,103 @@
import logging
from typing import Dict, Optional

import torch
from allennlp.data import Vocabulary
from allennlp.models import Model
from allennlp.modules.transformer import TransformerEmbeddings, TransformerStack, TransformerPooler
from allennlp.training.metrics import CategoricalAccuracy, FBetaMeasure
from torch.nn import Dropout

logger = logging.getLogger(__name__)


@Model.register("transformer_classification_tt")
class TransformerClassificationTT(Model):
    """
    This class implements a classification model patterned after the model proposed in
    [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al)]
    (https://api.semanticscholar.org/CorpusID:198953378).

    Parameters
    ----------
    vocab : ``Vocabulary``
    transformer_model : ``str``, optional (default=``"roberta-large"``)
        This model chooses the embedder according to this setting. You probably want to make
        sure this matches the setting in the reader.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        transformer_model: str = "roberta-large",
        num_labels: Optional[int] = None,
        label_namespace: str = "labels",
        override_weights_file: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)
        transformer_kwargs = {
            "model_name": transformer_model,
            "weights_path": override_weights_file,
        }
        self.embeddings = TransformerEmbeddings.from_pretrained_module(**transformer_kwargs)
        self.transformer_stack = TransformerStack.from_pretrained_module(**transformer_kwargs)
        self.pooler = TransformerPooler.from_pretrained_module(**transformer_kwargs)
        self.pooler_dropout = Dropout(p=0.1)

        self.label_tokens = vocab.get_index_to_token_vocabulary(label_namespace)
        if num_labels is None:
            num_labels = len(self.label_tokens)
        self.linear_layer = torch.nn.Linear(self.pooler.get_output_dim(), num_labels)
        self.linear_layer.weight.data.normal_(mean=0.0, std=0.02)
        self.linear_layer.bias.data.zero_()

        self.loss = torch.nn.CrossEntropyLoss()
        self.acc = CategoricalAccuracy()
        self.f1 = FBetaMeasure()

    def forward(  # type: ignore
        self,
        text: Dict[str, torch.Tensor],
        label: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        text : ``Dict[str, torch.LongTensor]``
            From a ``TransformerTextField``. Contains the text to be classified.
        label : ``Optional[torch.LongTensor]``
            From a ``LabelField``, specifies the true class of the instance.

        Returns
        -------
        An output dictionary consisting of:
        loss : ``torch.FloatTensor``, optional
            A scalar loss to be optimised. This is only returned when ``label`` is not ``None``.
        logits : ``torch.FloatTensor``
            The logits for every possible label.
        answers : ``torch.LongTensor``
            The predicted label index for every instance (the argmax of the logits).
        """
        embedded_text = self.embeddings(**text)
        embedded_text = self.transformer_stack(embedded_text, text["attention_mask"])
        embedded_text = self.pooler(embedded_text.final_hidden_states)
        embedded_text = self.pooler_dropout(embedded_text)
        logits = self.linear_layer(embedded_text)

        result = {"logits": logits, "answers": logits.argmax(1)}

        if label is not None:
            result["loss"] = self.loss(logits, label)
            self.acc(logits, label)
            self.f1(logits, label)

        return result

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        result = {"acc": self.acc.get_metric(reset)}
        for metric_name, metrics_per_class in self.f1.get_metric(reset).items():
            for class_index, value in enumerate(metrics_per_class):
                result[f"{self.label_tokens[class_index]}-{metric_name}"] = value
        return result
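
For orientation, here is a minimal usage sketch of the model above (not part of this commit). It assumes `roberta-base` weights can be downloaded and that the batched `TransformerTextField` tensors carry `input_ids`, `token_type_ids`, and `attention_mask`; the token ids below are random and only exercise the shapes.

import torch
from allennlp.data import Vocabulary
from allennlp_models.classification.models.transformer_classification_tt import (
    TransformerClassificationTT,
)

vocab = Vocabulary.empty()
vocab.add_tokens_to_namespace(["neg", "pos"], "labels")
model = TransformerClassificationTT(vocab, transformer_model="roberta-base")

# Fake batch of 2 sequences of length 8; real inputs come from a TransformerTextField.
text = {
    "input_ids": torch.randint(5, 1000, (2, 8)),
    "token_type_ids": torch.zeros(2, 8, dtype=torch.long),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
}
output = model(text=text, label=torch.tensor([0, 1]))
print(output["logits"].shape)   # (2, 2): one score per label in the "labels" namespace
print(output["answers"])        # predicted label indices
print(output["loss"])           # cross-entropy loss, present because labels were given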
1 change: 1 addition & 0 deletions allennlp_models/classification/tango/__init__.py
@@ -0,0 +1 @@
from allennlp_models.classification.tango.imdb import ImdbInstances
65 changes: 65 additions & 0 deletions allennlp_models/classification/tango/imdb.py
@@ -0,0 +1,65 @@
from typing import Sequence, Dict

import datasets
from allennlp.common import cached_transformers
from allennlp.data import Vocabulary, Instance
from allennlp.data.fields import TransformerTextField, LabelField
from allennlp.tango.dataset import DatasetDict
from allennlp.tango.sqlite_format import SqliteDictFormat
from allennlp.tango.step import Step


@Step.register("imdb_instances")
class ImdbInstances(Step):
    DETERMINISTIC = True
    VERSION = "003"
    CACHEABLE = True

    FORMAT = SqliteDictFormat

    def run(
        self,
        tokenizer_name: str,
        max_length: int = 512,
    ) -> DatasetDict:
        tokenizer = cached_transformers.get_tokenizer(tokenizer_name)
        assert tokenizer.pad_token_type_id == 0

        def clean_text(s: str) -> str:
            return s.replace("<br />", "\n")

        # This thing is so complicated because we want to call `batch_encode_plus` with all
        # the strings at once.
        results: Dict[str, Sequence[Instance]] = {}
        for split_name, instances in datasets.load_dataset("imdb").items():
            tokenized_texts = tokenizer.batch_encode_plus(
                [clean_text(instance["text"]) for instance in instances],
                add_special_tokens=True,
                truncation=True,
                max_length=max_length,
                return_token_type_ids=True,
                return_attention_mask=False,
            )

            results[split_name] = [
                Instance(
                    {
                        "text": TransformerTextField(
                            input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            padding_token_id=tokenizer.pad_token_id,
                        ),
                        "label": LabelField(instance["label"], skip_indexing=True),
                    }
                )
                for instance, input_ids, token_type_ids in zip(
                    instances, tokenized_texts["input_ids"], tokenized_texts["token_type_ids"]
                )
            ]

        # make vocab
        vocab = Vocabulary.empty()
        vocab.add_transformer_vocab(tokenizer, "tokens")
        vocab.add_tokens_to_namespace(["neg", "pos"], "labels")

        return DatasetDict(results, vocab)
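
As a concrete illustration (not part of this commit), this is roughly what `run` does for two hand-written reviews, using the same tokenizer and field calls as the step above:

from allennlp.common import cached_transformers
from allennlp.data import Instance
from allennlp.data.fields import TransformerTextField, LabelField

tokenizer = cached_transformers.get_tokenizer("roberta-base")
reviews = ["A wonderful film.<br />I loved every minute.", "Dull and far too long."]
labels = [1, 0]  # 1 = "pos", 0 = "neg", matching the order in the "labels" namespace above

encoded = tokenizer.batch_encode_plus(
    [r.replace("<br />", "\n") for r in reviews],  # same cleanup as clean_text
    add_special_tokens=True,
    truncation=True,
    max_length=512,
    return_token_type_ids=True,
    return_attention_mask=False,
)

instances = [
    Instance(
        {
            "text": TransformerTextField(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                padding_token_id=tokenizer.pad_token_id,
            ),
            "label": LabelField(label, skip_indexing=True),
        }
    )
    for input_ids, token_type_ids, label in zip(
        encoded["input_ids"], encoded["token_type_ids"], labels
    )
]
print(len(instances), encoded["input_ids"][0][:8])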
3 changes: 2 additions & 1 deletion allennlp_models/mc/tango/piqa.py
@@ -12,7 +12,7 @@
@Step.register("piqa_instances")
class PiqaInstances(Step):
    DETERMINISTIC = True
-   VERSION = "003"
+   VERSION = "004"
    CACHEABLE = True

    def run(
@@ -63,6 +63,7 @@ def run(
                        torch.tensor(
                            tokenized_alts["token_type_ids"][alt_index], dtype=torch.int32
                        ),
+                       padding_token_id=tokenizer.pad_token_id,
                    )
                    for alt_index in [2 * i, 2 * i + 1]
                ]
71 changes: 71 additions & 0 deletions training_config/tango/imdb.jsonnet
@@ -0,0 +1,71 @@

local debug = true;

local transformer_model = if debug then "roberta-base" else "roberta-large";

{
    "steps": {
        "original_dataset": {
            "type": "imdb_instances",
            "tokenizer_name": transformer_model
        },
        "remixed_dataset": {
            "type": "dataset_remix",
            "input": { "ref": "original_dataset" },
            "new_splits": {
                "train": "train[:20000]",
                "validation": "train[20000:]",
                "test": "test"
            },
            "keep_old_splits": false,
            "shuffle_before": true,
        },
        "trained_model": {
            "type": "training",
            "dataset": { "ref": "remixed_dataset" },
            "training_split": "train",
            "validation_split": "validation",
            [if !debug then "data_loader"]: {
                "batch_size": 32
            },
            [if debug then "data_loader"]: {
                "type": "max_batches",
                "max_batches_per_epoch": 7,
                "inner": {
                    "batch_size": 5
                }
            },
            "model": {
                "type": "transformer_classification_tt",
                "transformer_model": transformer_model,
            },
            "optimizer": {
                "type": "huggingface_adamw",
                "weight_decay": 0.01,
                "parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]],
                "lr": 1e-5,
                "eps": 1e-8,
                "correct_bias": true
            },
            "learning_rate_scheduler": {
                "type": "linear_with_warmup",
                "warmup_steps": 100
            },
            "num_epochs": if debug then 3 else 20,
            "patience": 3,
        },
        "evaluation": {
            "type": "evaluation",
            "dataset": { "ref": "remixed_dataset" },
            "model": { "ref": "trained_model" },
            "split": "test",
            [if debug then "data_loader"]: {
                "type": "max_batches",
                "max_batches_per_epoch": 7,
                "inner": {
                    "batch_size": 5
                }
            },
        }
    }
}
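
The "dataset_remix" step in this config carves the 25,000-review IMDB training split into 20,000 training and 5,000 validation examples after shuffling, and keeps the original test split. A rough Python equivalent of that remix (not part of this commit, and ignoring how Tango seeds the shuffle), written directly against Hugging Face `datasets`:

import datasets

imdb = datasets.load_dataset("imdb")
shuffled_train = imdb["train"].shuffle(seed=42)  # "shuffle_before": true; the seed here is arbitrary
splits = {
    "train": shuffled_train.select(range(20000)),               # "train[:20000]"
    "validation": shuffled_train.select(range(20000, 25000)),   # "train[20000:]"
    "test": imdb["test"],                                        # "test"
}
print({name: len(split) for name, split in splits.items()})
# {'train': 20000, 'validation': 5000, 'test': 25000}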
