diff --git a/CHANGELOG.md b/CHANGELOG.md
index fd3350e36..3eb0861eb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added script that produces the coref training data.
- Added tests for using `allennlp predict` on multitask models.
-
+- Added reader and training config for RoBERTa on SuperGLUE's Recognizing Textual Entailment task
## [v2.2.0](https://github.com/allenai/allennlp-models/releases/tag/v2.2.0) - 2021-03-26
diff --git a/README.md b/README.md
index 94b220734..4aa46ccfc 100644
--- a/README.md
+++ b/README.md
@@ -156,6 +156,7 @@ Here is a list of pre-trained models currently available.
- [`pair-classification-decomposable-attention-elmo`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-decomposable-attention-elmo.json) - The decomposable attention model (Parikh et al, 2017) combined with ELMo embeddings trained on SNLI.
- [`pair-classification-esim`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-esim.json) - Enhanced LSTM trained on SNLI.
- [`pair-classification-roberta-mnli`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-roberta-mnli.json) - RoBERTa finetuned on MNLI.
+- [`pair-classification-roberta-rte`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-roberta-rte.json) - A pair classification model patterned after the proposed model in Devlin et al, fine-tuned on the SuperGLUE RTE corpus
- [`pair-classification-roberta-snli`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-roberta-snli.json) - RoBERTa finetuned on SNLI.
- [`rc-bidaf-elmo`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/rc-bidaf-elmo.json) - BiDAF model with ELMo embeddings instead of GloVe.
- [`rc-bidaf`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/rc-bidaf.json) - BiDAF model with GloVe embeddings.
diff --git a/allennlp_models/modelcards/pair-classification-roberta-rte.json b/allennlp_models/modelcards/pair-classification-roberta-rte.json
new file mode 100644
index 000000000..b33cf021a
--- /dev/null
+++ b/allennlp_models/modelcards/pair-classification-roberta-rte.json
@@ -0,0 +1,70 @@
+{
+ "id": "pair-classification-roberta-rte",
+ "registered_model_name": "roberta-rte",
+ "registered_predictor_name": null,
+ "display_name": "RoBERTa RTE",
+ "task_id": "pair_classification",
+ "model_details": {
+ "description": "The model implements a pair classification model patterned after the proposed model in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al, 2018)](https://api.semanticscholar.org/CorpusID:52967399), fine-tuned on the MultiNLI corpus. It predicts labels with a linear layer on top of word piece embeddings.",
+ "short_description": "A pair classification model patterned after the proposed model in Devlin et al, fine-tuned on the SuperGLUE RTE corpus",
+ "developed_by": "Devlin et al",
+ "contributed_by": "Jacob Morrison",
+ "date": "2021-04-09",
+ "version": "1",
+ "model_type": "RoBERTa",
+ "paper": {
+ "citation": "\n@article{Liu2019RoBERTaAR,\ntitle={RoBERTa: A Robustly Optimized BERT Pretraining Approach},\nauthor={Y. Liu and Myle Ott and Naman Goyal and Jingfei Du and Mandar Joshi and Danqi Chen and Omer Levy and M. Lewis and L. Zettlemoyer and V. Stoyanov},\njournal={ArXiv},\nyear={2019},\nvolume={abs/1907.11692}}\n",
+ "title": "RoBERTa: A Robustly Optimized BERT Pretraining Approach",
+ "url": "https://api.semanticscholar.org/CorpusID:198953378"
+ },
+ "license": null,
+ "contact": "allennlp-contact@allenai.org"
+ },
+ "intended_use": {
+ "primary_uses": null,
+ "primary_users": null,
+ "out_of_scope_use_cases": null
+ },
+ "factors": {
+ "relevant_factors": null,
+ "evaluation_factors": null
+ },
+ "metrics": {
+ "model_performance_measures": "Accuracy",
+ "decision_thresholds": null,
+ "variation_approaches": null
+ },
+ "evaluation_data": {
+ "dataset": {
+ "name": "SuperGLUE Recognizing Textual Entailment validation set",
+ "url": "https://super.gluebenchmark.com/tasks",
+ "processed_url": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/val.jsonl"
+ },
+ "motivation": null,
+ "preprocessing": null
+ },
+ "training_data": {
+ "dataset": {
+ "name": "SuperGLUE Recognizing Textual Entailment training set",
+ "url": "https://super.gluebenchmark.com/tasks",
+ "processed_url": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/train.jsonl"
+ },
+ "motivation": null,
+ "preprocessing": null
+ },
+ "quantitative_analyses": {
+ "unitary_results": "Accuracy: 89.9% on the SuperGLUE RTE validation dataset.",
+ "intersectional_results": null
+ },
+ "model_caveats_and_recommendations": {
+ "caveats_and_recommendations": null
+ },
+ "model_ethical_considerations": {
+ "ethical_considerations": null
+ },
+ "model_usage": {
+ "archive_file": "superglue-rte-roberta.2021-04-09.tar.gz",
+ "training_config": "pair-classification/superglue_rte_roberta.jsonnet",
+ "install_instructions": "pip install allennlp==2.3.1 allennlp-models==2.3.1"
+ }
+}
diff --git a/allennlp_models/pair_classification/dataset_readers/__init__.py b/allennlp_models/pair_classification/dataset_readers/__init__.py
index ac48922f8..3e521f604 100644
--- a/allennlp_models/pair_classification/dataset_readers/__init__.py
+++ b/allennlp_models/pair_classification/dataset_readers/__init__.py
@@ -2,3 +2,6 @@
QuoraParaphraseDatasetReader,
)
from allennlp_models.pair_classification.dataset_readers.snli import SnliReader
+from allennlp_models.pair_classification.dataset_readers.transformer_superglue_rte import (
+ TransformerSuperGlueRteReader,
+)
diff --git a/allennlp_models/pair_classification/dataset_readers/transformer_superglue_rte.py b/allennlp_models/pair_classification/dataset_readers/transformer_superglue_rte.py
new file mode 100644
index 000000000..01f31e857
--- /dev/null
+++ b/allennlp_models/pair_classification/dataset_readers/transformer_superglue_rte.py
@@ -0,0 +1,117 @@
+import logging
+from typing import Any, Dict
+
+from overrides import overrides
+
+from allennlp.data.fields import MetadataField, TextField, LabelField
+from allennlp.common.file_utils import cached_path
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.instance import Instance
+from allennlp.data.token_indexers import PretrainedTransformerIndexer
+from allennlp.data.tokenizers import PretrainedTransformerTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+@DatasetReader.register("transformer_superglue_rte")
+class TransformerSuperGlueRteReader(DatasetReader):
+ """
+ Dataset reader for the SuperGLUE Recognizing Textual Entailment task, to be used with a transformer
+ model such as RoBERTa. The dataset is in the JSON Lines format.
+
+ It will generate `Instances` with the following fields:
+
+ * `tokens`, a `TextField` that contains the concatenation of premise and hypothesis,
+ * `label`, a `LabelField` containing the label, if one exists.
+ * `metadata`, a `MetadataField` that stores the instance's index in the file, the original premise,
+ the original hypothesis, both of these in tokenized form, and the gold label, accessible as
+ `metadata['index']`, `metadata['premise']`, `metadata['hypothesis']`, `metadata['tokens']`,
+ and `metadata['label']`.
+
+ # Parameters
+
+ type : `str`, optional (default=`'roberta-base'`)
+ This reader chooses tokenizer according to this setting.
+ """
+
+ def __init__(
+ self,
+ transformer_model_name: str = "roberta-base",
+ tokenizer_kwargs: Dict[str, Any] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(
+ manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
+ )
+ self._tokenizer = PretrainedTransformerTokenizer(
+ transformer_model_name,
+ add_special_tokens=False,
+ tokenizer_kwargs=tokenizer_kwargs,
+ )
+ self._token_indexers = {
+ "tokens": PretrainedTransformerIndexer(
+ transformer_model_name, tokenizer_kwargs=tokenizer_kwargs, max_length=512
+ )
+ }
+
+ @overrides
+ def _read(self, file_path: str):
+ # if `file_path` is a URL, redirect to the cache
+ file_path = cached_path(file_path, extract_archive=True)
+
+ logger.info("Reading file at %s", file_path)
+ yielded_relation_count = 0
+ from allennlp.common.file_utils import json_lines_from_file
+
+ for relation in self.shard_iterable(json_lines_from_file(file_path)):
+ premise = relation["premise"]
+ hypothesis = relation["hypothesis"]
+ if "label" in relation:
+ label = relation["label"]
+ else:
+ label = None
+ index = relation["idx"]
+
+ # todo: see if we even need this to be in a separate method
+ instance = self.text_to_instance(index, label, premise, hypothesis)
+
+ yield instance
+ yielded_relation_count += 1
+
+ @overrides
+ def text_to_instance(
+ self,
+ index: int,
+ label: str,
+ premise: str,
+ hypothesis: str,
+ ) -> Instance:
+ tokenized_premise = self._tokenizer.tokenize(premise)
+ tokenized_hypothesis = self._tokenizer.tokenize(hypothesis)
+
+ fields = {}
+
+ premise_and_hypothesis = TextField(
+ self._tokenizer.add_special_tokens(tokenized_premise, tokenized_hypothesis),
+ )
+ fields["tokens"] = TextField(premise_and_hypothesis)
+
+ # make the metadata
+ metadata = {
+ "premise": premise,
+ "premise_tokens": tokenized_premise,
+ "hypothesis": hypothesis,
+ "hypothesis_tokens": tokenized_hypothesis,
+ "index": index,
+ }
+ if label:
+ fields["label"] = LabelField(label)
+ metadata["label"] = label
+
+ fields["metadata"] = MetadataField(metadata)
+
+ return Instance(fields)
+
+ @overrides
+ def apply_token_indexers(self, instance: Instance) -> None:
+ instance["tokens"].token_indexers = self._token_indexers
diff --git a/test_fixtures/rc/superglue_rte.jsonl b/test_fixtures/rc/superglue_rte.jsonl
new file mode 100644
index 000000000..6bb3aa07f
--- /dev/null
+++ b/test_fixtures/rc/superglue_rte.jsonl
@@ -0,0 +1,4 @@
+{"premise": "No Weapons of Mass Destruction Found in Iraq Yet.", "hypothesis": "Weapons of Mass Destruction Found in Iraq.", "label": "not_entailment", "idx": 0}
+{"premise": "A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI.", "hypothesis": "Pope Benedict XVI is the new leader of the Roman Catholic Church.", "label": "entailment", "idx": 1}
+{"premise": "Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.", "hypothesis": "Herceptin can be used to treat breast cancer.", "label": "entailment", "idx": 2}
+{"premise": "Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment.", "hypothesis": "The previous name of Ho Chi Minh City was Saigon.", "label": "entailment", "idx": 3}
diff --git a/test_fixtures/rc/superglue_rte_no_labels.jsonl b/test_fixtures/rc/superglue_rte_no_labels.jsonl
new file mode 100644
index 000000000..10218b6e9
--- /dev/null
+++ b/test_fixtures/rc/superglue_rte_no_labels.jsonl
@@ -0,0 +1,4 @@
+{"premise": "No Weapons of Mass Destruction Found in Iraq Yet.", "hypothesis": "Weapons of Mass Destruction Found in Iraq.", "idx": 0}
+{"premise": "A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI.", "hypothesis": "Pope Benedict XVI is the new leader of the Roman Catholic Church.", "idx": 1}
+{"premise": "Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.", "hypothesis": "Herceptin can be used to treat breast cancer.", "idx": 2}
+{"premise": "Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment.", "hypothesis": "The previous name of Ho Chi Minh City was Saigon.", "idx": 3}
diff --git a/tests/pair_classification/dataset_readers/transformer_superglue_rte_test.py b/tests/pair_classification/dataset_readers/transformer_superglue_rte_test.py
new file mode 100644
index 000000000..071858828
--- /dev/null
+++ b/tests/pair_classification/dataset_readers/transformer_superglue_rte_test.py
@@ -0,0 +1,39 @@
+from allennlp.common.params import Params
+from allennlp.common.util import ensure_list
+from allennlp.data import DatasetReader
+import pytest
+
+from allennlp_models.pair_classification import TransformerSuperGlueRteReader
+from tests import FIXTURES_ROOT
+
+
+class TestTransformerSuperGlueRteReader:
+ def test_read_from_file_superglue_rte(self):
+ reader = TransformerSuperGlueRteReader()
+ instances = ensure_list(reader.read(FIXTURES_ROOT / "rc" / "superglue_rte.jsonl"))
+ assert len(instances) == 4
+
+ token_text = [t.text for t in instances[0].fields["tokens"].tokens]
+ assert token_text[:3] == ["", "No", "ĠWeapons"]
+ assert token_text[10:14] == [".", "", "", "Weapons"]
+ assert token_text[-3:] == ["ĠIraq", ".", ""]
+
+ assert instances[0].fields["label"].human_readable_repr() == "not_entailment"
+
+ assert instances[0].fields["metadata"]["label"] == "not_entailment"
+ assert instances[0].fields["metadata"]["index"] == 0
+
+ def test_read_from_file_superglue_rte_no_label(self):
+ reader = TransformerSuperGlueRteReader()
+ instances = ensure_list(reader.read(FIXTURES_ROOT / "rc" / "superglue_rte_no_labels.jsonl"))
+ assert len(instances) == 4
+
+ token_text = [t.text for t in instances[0].fields["tokens"].tokens]
+ assert token_text[:3] == ["", "No", "ĠWeapons"]
+ assert token_text[10:14] == [".", "", "", "Weapons"]
+ assert token_text[-3:] == ["ĠIraq", ".", ""]
+
+ assert "label" not in instances[0].fields
+ assert "label" not in instances[0].fields["metadata"]
+
+ assert instances[0].fields["metadata"]["index"] == 0
diff --git a/training_config/pair_classification/superglue_rte_roberta.jsonnet b/training_config/pair_classification/superglue_rte_roberta.jsonnet
new file mode 100644
index 000000000..27b661caa
--- /dev/null
+++ b/training_config/pair_classification/superglue_rte_roberta.jsonnet
@@ -0,0 +1,67 @@
+local transformer_model = "roberta-large-mnli";
+local transformer_dim = 1024;
+
+local epochs = 20;
+local batch_size = 64;
+
+local gpu_batch_size = 4;
+local gradient_accumulation_steps = batch_size / gpu_batch_size;
+
+{
+ "dataset_reader":{
+ "type": "transformer_superglue_rte"
+ },
+ "train_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/train.jsonl",
+ "validation_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/val.jsonl",
+ "test_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/test.jsonl",
+ "model": {
+ "type": "basic_classifier",
+ "text_field_embedder": {
+ "token_embedders": {
+ "tokens": {
+ "type": "pretrained_transformer",
+ "model_name": transformer_model,
+ "max_length": 512
+ }
+ }
+ },
+ "seq2vec_encoder": {
+ "type": "cls_pooler",
+ "embedding_dim": transformer_dim,
+ },
+ "feedforward": {
+ "input_dim": transformer_dim,
+ "num_layers": 1,
+ "hidden_dims": transformer_dim,
+ "activations": "tanh"
+ },
+ "dropout": 0.1,
+ "namespace": "tags"
+ },
+ "data_loader": {
+ "shuffle": true,
+ "batch_size": gpu_batch_size
+ },
+ "trainer": {
+ "optimizer": {
+ "type": "huggingface_adamw",
+ "weight_decay": 0.01,
+ "parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]],
+ "lr": 1e-6,
+ "eps": 1e-8,
+ "correct_bias": true
+ },
+ "learning_rate_scheduler": {
+ "type": "linear_with_warmup",
+ "warmup_steps": 100
+ },
+ // "grad_norm": 1.0,
+ "num_epochs": epochs,
+ "num_gradient_accumulation_steps": gradient_accumulation_steps,
+ "patience": 3,
+ "validation_metric": "+accuracy",
+ },
+ "random_seed": 42,
+ "numpy_seed": 42,
+ "pytorch_seed": 42,
+}