This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Roberta data reader #247

Merged 45 commits on Apr 16, 2021
45 commits
ecd82f9
Adding span probabilities to output_dict
jacob-morrison Apr 6, 2021
6403395
Fixing bracket placement
jacob-morrison Apr 6, 2021
1e59646
Making test variable names more specific
jacob-morrison Apr 6, 2021
9c2ac75
Running formatter
jacob-morrison Apr 6, 2021
bd171d6
Updating changelog
jacob-morrison Apr 6, 2021
6dc59e0
Fixing probability and updating test
jacob-morrison Apr 6, 2021
4519269
Fixing test
jacob-morrison Apr 6, 2021
506f37b
Adding initial dataset reader for SuperGLUE RTE
jacob-morrison Apr 7, 2021
a0881c3
Merge branch 'main' of https://github.com/allenai/allennlp-models int…
jacob-morrison Apr 7, 2021
bd14a42
Updating config and comment
jacob-morrison Apr 7, 2021
9d1eaa5
Removing unnecessary variables
jacob-morrison Apr 7, 2021
9cecb0a
updates
jacob-morrison Apr 7, 2021
aca7e7b
fix
jacob-morrison Apr 7, 2021
983a9a7
Updating requirements.txt + some formatting stuff
jacob-morrison Apr 7, 2021
50789f6
Updating changelog
jacob-morrison Apr 7, 2021
61a927b
Removing unused dependencies
jacob-morrison Apr 7, 2021
27e7f3b
Attempting to extract the archive
jacob-morrison Apr 7, 2021
4b3f8f6
trying new file path
jacob-morrison Apr 8, 2021
84268fd
Adding tokenizer indexer stuff
jacob-morrison Apr 8, 2021
15cc637
Removing extra bit
jacob-morrison Apr 8, 2021
da55d31
adding import
jacob-morrison Apr 8, 2021
84ca375
Updating field name
jacob-morrison Apr 8, 2021
f2eded9
Adding max length field
jacob-morrison Apr 8, 2021
3c1c8c4
Lowering # epochs
jacob-morrison Apr 8, 2021
67da005
Changing back to 10
jacob-morrison Apr 8, 2021
44333c6
Switching to lower epochs again
jacob-morrison Apr 8, 2021
97c9993
trying mnli hyperparameters
jacob-morrison Apr 8, 2021
82f23df
piqa parameters
jacob-morrison Apr 8, 2021
aa71cbc
updating model file
jacob-morrison Apr 9, 2021
6f3803a
back to basic classifier
jacob-morrison Apr 9, 2021
b772561
Fiddling with piqa params for superglue rte
jacob-morrison Apr 9, 2021
fb1a26e
Trying a different learning rate
jacob-morrison Apr 9, 2021
7d8b4d5
trying a finetuned roberta
jacob-morrison Apr 9, 2021
d2b663e
Moving from rc/ to pair_classification/
jacob-morrison Apr 9, 2021
39df3e5
Updating model card
jacob-morrison Apr 10, 2021
833c536
updating readme
jacob-morrison Apr 10, 2021
b5c5bc3
fixing typo
jacob-morrison Apr 10, 2021
a4ba192
Merge branch 'main' into roberta-data-reader
jacob-morrison Apr 10, 2021
27fa6bd
Merge branch 'main' into roberta-data-reader
dirkgr Apr 13, 2021
803e952
Making changes
jacob-morrison Apr 15, 2021
86868f2
Merge branch 'main' into roberta-data-reader
jacob-morrison Apr 15, 2021
15ce524
Updating version number
jacob-morrison Apr 15, 2021
25cfb52
Merge branch 'roberta-data-reader' of https://github.com/allenai/alle…
jacob-morrison Apr 15, 2021
d778bb5
Update readme
jacob-morrison Apr 15, 2021
9521e6a
reformat
jacob-morrison Apr 15, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -29,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added script that produces the coref training data.
- Added tests for using `allennlp predict` on multitask models.

- Added reader and training config for RoBERTa on SuperGLUE's Recognizing Textual Entailment task.

## [v2.2.0](https://github.com/allenai/allennlp-models/releases/tag/v2.2.0) - 2021-03-26

1 change: 1 addition & 0 deletions README.md
@@ -156,6 +156,7 @@ Here is a list of pre-trained models currently available.
- [`pair-classification-decomposable-attention-elmo`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-decomposable-attention-elmo.json) - The decomposable attention model (Parikh et al, 2017) combined with ELMo embeddings trained on SNLI.
- [`pair-classification-esim`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-esim.json) - Enhanced LSTM trained on SNLI.
- [`pair-classification-roberta-mnli`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-roberta-mnli.json) - RoBERTa finetuned on MNLI.
- [`pair-classification-roberta-rte`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-roberta-rte.json) - A pair classification model patterned after the proposed model in Devlin et al, fine-tuned on the SuperGLUE RTE corpus
- [`pair-classification-roberta-snli`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/pair-classification-roberta-snli.json) - RoBERTa finetuned on SNLI.
- [`rc-bidaf-elmo`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/rc-bidaf-elmo.json) - BiDAF model with ELMo embeddings instead of GloVe.
- [`rc-bidaf`](https://github.com/allenai/allennlp-models/tree/main/allennlp_models/modelcards/rc-bidaf.json) - BiDAF model with GloVe embeddings.
70 changes: 70 additions & 0 deletions allennlp_models/modelcards/pair-classification-roberta-rte.json
@@ -0,0 +1,70 @@
{
"id": "pair-classification-roberta-rte",
"registered_model_name": "roberta-rte",
"registered_predictor_name": null,
"display_name": "RoBERTa RTE",
"task_id": "pair_classification",
"model_details": {
"description": "The model implements a pair classification model patterned after the proposed model in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al, 2018)](https://api.semanticscholar.org/CorpusID:52967399), fine-tuned on the SuperGLUE RTE corpus. It predicts labels with a linear layer on top of word piece embeddings.",
"short_description": "A pair classification model patterned after the proposed model in Devlin et al, fine-tuned on the SuperGLUE RTE corpus",
"developed_by": "Devlin et al",
"contributed_by": "Jacob Morrison",
"date": "2021-04-09",
"version": "1",
"model_type": "RoBERTa",
"paper": {
"citation": "\n@article{Liu2019RoBERTaAR,\ntitle={RoBERTa: A Robustly Optimized BERT Pretraining Approach},\nauthor={Y. Liu and Myle Ott and Naman Goyal and Jingfei Du and Mandar Joshi and Danqi Chen and Omer Levy and M. Lewis and L. Zettlemoyer and V. Stoyanov},\njournal={ArXiv},\nyear={2019},\nvolume={abs/1907.11692}}\n",
"title": "RoBERTa: A Robustly Optimized BERT Pretraining Approach",
"url": "https://api.semanticscholar.org/CorpusID:198953378"
},
"license": null,
"contact": "[email protected]"
},
"intended_use": {
"primary_uses": null,
"primary_users": null,
"out_of_scope_use_cases": null
},
"factors": {
"relevant_factors": null,
"evaluation_factors": null
},
"metrics": {
"model_performance_measures": "Accuracy",
"decision_thresholds": null,
"variation_approaches": null
},
"evaluation_data": {
"dataset": {
"name": "SuperGLUE Recognizing Textual Entailment validation set",
"url": "https://super.gluebenchmark.com/tasks",
"processed_url": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/val.jsonl"
},
"motivation": null,
"preprocessing": null
},
"training_data": {
"dataset": {
"name": "SuperGLUE Recognizing Textual Entailment training set",
"url": "https://super.gluebenchmark.com/tasks",
"processed_url": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/train.jsonl"
},
"motivation": null,
"preprocessing": null
},
"quantitative_analyses": {
"unitary_results": "Accuracy: 89.9% on the SuperGLUE RTE validation dataset.",
"intersectional_results": null
},
"model_caveats_and_recommendations": {
"caveats_and_recommendations": null
},
"model_ethical_considerations": {
"ethical_considerations": null
},
"model_usage": {
"archive_file": "superglue-rte-roberta.2021-04-09.tar.gz",
"training_config": "pair-classification/superglue_rte_roberta.jsonnet",
"install_instructions": "pip install allennlp==2.3.1 allennlp-models==2.3.1"
}
}
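The `processed_url` values in the model card use an `archive!member` form, which `cached_path(..., extract_archive=True)` in the reader resolves to a file inside the downloaded archive. The sketch below only illustrates how such a string divides into its two parts; `split_archive_path` is a hypothetical helper written for this note, not AllenNLP's implementation.

```python
from typing import Optional, Tuple


def split_archive_path(path: str) -> Tuple[str, Optional[str]]:
    """Split 'archive!member' into (archive, member).

    Hypothetical helper: returns (path, None) when there is no '!' separator.
    """
    if "!" not in path:
        return path, None
    # Split on the last '!' so that the archive location stays intact.
    archive, member = path.rsplit("!", 1)
    return archive, member


archive, member = split_archive_path(
    "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/val.jsonl"
)
print(archive)
print(member)
```

Under this reading, the part before `!` is the zip file to download and the part after it is the path of the JSON Lines file within the extracted archive.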
@@ -2,3 +2,6 @@
QuoraParaphraseDatasetReader,
)
from allennlp_models.pair_classification.dataset_readers.snli import SnliReader
from allennlp_models.pair_classification.dataset_readers.transformer_superglue_rte import (
TransformerSuperGlueRteReader,
)
@@ -0,0 +1,117 @@
import logging
from typing import Any, Dict, Optional

from overrides import overrides

from allennlp.data.fields import MetadataField, TextField, LabelField
from allennlp.common.file_utils import cached_path, json_lines_from_file
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

logger = logging.getLogger(__name__)


@DatasetReader.register("transformer_superglue_rte")
class TransformerSuperGlueRteReader(DatasetReader):
    """
    Dataset reader for the SuperGLUE Recognizing Textual Entailment task, to be used with a transformer
    model such as RoBERTa. The dataset is in the JSON Lines format.

    It will generate `Instances` with the following fields:

     * `tokens`, a `TextField` that contains the concatenation of premise and hypothesis,
     * `label`, a `LabelField` containing the label, if one exists.
     * `metadata`, a `MetadataField` that stores the instance's index in the file, the original premise,
       the original hypothesis, both of these in tokenized form, and the gold label, accessible as
       `metadata['index']`, `metadata['premise']`, `metadata['hypothesis']`,
       `metadata['premise_tokens']`, `metadata['hypothesis_tokens']`, and `metadata['label']`.

    # Parameters

    transformer_model_name : `str`, optional (default=`'roberta-base'`)
        This reader chooses its tokenizer and token indexer according to this setting.
    tokenizer_kwargs : `Dict[str, Any]`, optional (default=`None`)
        Extra keyword arguments passed on to the tokenizer and the token indexer.
    """

    def __init__(
        self,
        transformer_model_name: str = "roberta-base",
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> None:
        super().__init__(
            manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
        )
        # Special tokens are added later, once premise and hypothesis are combined.
        self._tokenizer = PretrainedTransformerTokenizer(
            transformer_model_name,
            add_special_tokens=False,
            tokenizer_kwargs=tokenizer_kwargs,
        )
        self._token_indexers = {
            "tokens": PretrainedTransformerIndexer(
                transformer_model_name, tokenizer_kwargs=tokenizer_kwargs, max_length=512
            )
        }

    @overrides
    def _read(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path, extract_archive=True)

        logger.info("Reading file at %s", file_path)

        for relation in self.shard_iterable(json_lines_from_file(file_path)):
            premise = relation["premise"]
            hypothesis = relation["hypothesis"]
            # Unlabeled files (e.g. the test split) have no "label" key.
            label = relation.get("label")
            index = relation["idx"]

            # todo: see if we even need this to be in a separate method
            instance = self.text_to_instance(index, label, premise, hypothesis)

            yield instance

    @overrides
    def text_to_instance(
        self,
        index: int,
        label: Optional[str],
        premise: str,
        hypothesis: str,
    ) -> Instance:
        tokenized_premise = self._tokenizer.tokenize(premise)
        tokenized_hypothesis = self._tokenizer.tokenize(hypothesis)

        fields = {}

        # Join the two token lists with the model's special tokens, then wrap them once.
        premise_and_hypothesis = self._tokenizer.add_special_tokens(
            tokenized_premise, tokenized_hypothesis
        )
        fields["tokens"] = TextField(premise_and_hypothesis)

        # make the metadata
        metadata = {
            "premise": premise,
            "premise_tokens": tokenized_premise,
            "hypothesis": hypothesis,
            "hypothesis_tokens": tokenized_hypothesis,
            "index": index,
        }
        if label:
            fields["label"] = LabelField(label)
            metadata["label"] = label

        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)

    @overrides
    def apply_token_indexers(self, instance: Instance) -> None:
        instance["tokens"].token_indexers = self._token_indexers
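The reader tokenizes premise and hypothesis without special tokens and then joins them with `add_special_tokens`. Judging by the assertions in this PR's tests (`<s>` first, `</s>`, `</s>` between the two segments, `</s>` last), the combined sequence for RoBERTa looks like the layout sketched below. `roberta_pair_layout` is a hypothetical stand-in that works on plain strings; it is not the tokenizer's own code, and the toy tokens are made up for illustration.

```python
from typing import List


def roberta_pair_layout(premise: List[str], hypothesis: List[str]) -> List[str]:
    """Hypothetical sketch of the RoBERTa sentence-pair layout:
    <s> premise </s> </s> hypothesis </s>
    """
    return ["<s>"] + premise + ["</s>", "</s>"] + hypothesis + ["</s>"]


# Toy word pieces, not real tokenizer output.
layout = roberta_pair_layout(["No", "Ġweapons", "."], ["Weapons", "."])
print(layout)
```

Keeping the special tokens out of `tokenize` and adding them in one place means the two segments can be stored separately in the metadata while the model still sees a single well-formed input.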
4 changes: 4 additions & 0 deletions test_fixtures/rc/superglue_rte.jsonl
@@ -0,0 +1,4 @@
{"premise": "No Weapons of Mass Destruction Found in Iraq Yet.", "hypothesis": "Weapons of Mass Destruction Found in Iraq.", "label": "not_entailment", "idx": 0}
{"premise": "A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI.", "hypothesis": "Pope Benedict XVI is the new leader of the Roman Catholic Church.", "label": "entailment", "idx": 1}
{"premise": "Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.", "hypothesis": "Herceptin can be used to treat breast cancer.", "label": "entailment", "idx": 2}
{"premise": "Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment.", "hypothesis": "The previous name of Ho Chi Minh City was Saigon.", "label": "entailment", "idx": 3}
4 changes: 4 additions & 0 deletions test_fixtures/rc/superglue_rte_no_labels.jsonl
@@ -0,0 +1,4 @@
{"premise": "No Weapons of Mass Destruction Found in Iraq Yet.", "hypothesis": "Weapons of Mass Destruction Found in Iraq.", "idx": 0}
{"premise": "A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI.", "hypothesis": "Pope Benedict XVI is the new leader of the Roman Catholic Church.", "idx": 1}
{"premise": "Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.", "hypothesis": "Herceptin can be used to treat breast cancer.", "idx": 2}
{"premise": "Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment.", "hypothesis": "The previous name of Ho Chi Minh City was Saigon.", "idx": 3}
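The two fixtures differ only in whether each record carries a `label` key, which is the case the reader has to handle for unlabeled data. A minimal standard-library sketch of that parsing step follows; `iter_rte_records` is a hypothetical helper for illustration and does not use AllenNLP's `json_lines_from_file`.

```python
import json
from typing import Iterable, Iterator, Optional, Tuple


def iter_rte_records(
    lines: Iterable[str],
) -> Iterator[Tuple[int, Optional[str], str, str]]:
    """Yield (idx, label, premise, hypothesis); label is None when absent."""
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines
        record = json.loads(line)
        yield record["idx"], record.get("label"), record["premise"], record["hypothesis"]


labeled = (
    '{"premise": "No Weapons of Mass Destruction Found in Iraq Yet.", '
    '"hypothesis": "Weapons of Mass Destruction Found in Iraq.", '
    '"label": "not_entailment", "idx": 0}'
)
unlabeled = (
    '{"premise": "No Weapons of Mass Destruction Found in Iraq Yet.", '
    '"hypothesis": "Weapons of Mass Destruction Found in Iraq.", "idx": 0}'
)
records = list(iter_rte_records([labeled, unlabeled]))
print(records[0][1], records[1][1])
```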
@@ -0,0 +1,39 @@
from allennlp.common.params import Params
from allennlp.common.util import ensure_list
from allennlp.data import DatasetReader
import pytest

from allennlp_models.pair_classification import TransformerSuperGlueRteReader
from tests import FIXTURES_ROOT


class TestTransformerSuperGlueRteReader:
    def test_read_from_file_superglue_rte(self):
        reader = TransformerSuperGlueRteReader()
        instances = ensure_list(reader.read(FIXTURES_ROOT / "rc" / "superglue_rte.jsonl"))
        assert len(instances) == 4

        token_text = [t.text for t in instances[0].fields["tokens"].tokens]
        assert token_text[:3] == ["<s>", "No", "ĠWeapons"]
        assert token_text[10:14] == [".", "</s>", "</s>", "Weapons"]
        assert token_text[-3:] == ["ĠIraq", ".", "</s>"]

        assert instances[0].fields["label"].human_readable_repr() == "not_entailment"

        assert instances[0].fields["metadata"]["label"] == "not_entailment"
        assert instances[0].fields["metadata"]["index"] == 0

    def test_read_from_file_superglue_rte_no_label(self):
        reader = TransformerSuperGlueRteReader()
        instances = ensure_list(reader.read(FIXTURES_ROOT / "rc" / "superglue_rte_no_labels.jsonl"))
        assert len(instances) == 4

        token_text = [t.text for t in instances[0].fields["tokens"].tokens]
        assert token_text[:3] == ["<s>", "No", "ĠWeapons"]
        assert token_text[10:14] == [".", "</s>", "</s>", "Weapons"]
        assert token_text[-3:] == ["ĠIraq", ".", "</s>"]

        assert "label" not in instances[0].fields
        assert "label" not in instances[0].fields["metadata"]

        assert instances[0].fields["metadata"]["index"] == 0
67 changes: 67 additions & 0 deletions training_config/pair_classification/superglue_rte_roberta.jsonnet
@@ -0,0 +1,67 @@
local transformer_model = "roberta-large-mnli";
local transformer_dim = 1024;

local epochs = 20;
local batch_size = 64;

local gpu_batch_size = 4;
local gradient_accumulation_steps = batch_size / gpu_batch_size;

{
"dataset_reader":{
"type": "transformer_superglue_rte"
},
"train_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/train.jsonl",
"validation_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/val.jsonl",
"test_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/test.jsonl",
"model": {
"type": "basic_classifier",
"text_field_embedder": {
"token_embedders": {
"tokens": {
"type": "pretrained_transformer",
"model_name": transformer_model,
"max_length": 512
}
}
},
"seq2vec_encoder": {
"type": "cls_pooler",
"embedding_dim": transformer_dim,
},
"feedforward": {
"input_dim": transformer_dim,
"num_layers": 1,
"hidden_dims": transformer_dim,
"activations": "tanh"
},
"dropout": 0.1,
"namespace": "tags"
},
"data_loader": {
"shuffle": true,
"batch_size": gpu_batch_size
},
"trainer": {
"optimizer": {
"type": "huggingface_adamw",
"weight_decay": 0.01,
"parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]],
"lr": 1e-6,
"eps": 1e-8,
"correct_bias": true
},
"learning_rate_scheduler": {
"type": "linear_with_warmup",
"warmup_steps": 100
},
// "grad_norm": 1.0,
"num_epochs": epochs,
"num_gradient_accumulation_steps": gradient_accumulation_steps,
"patience": 3,
"validation_metric": "+accuracy",
},
"random_seed": 42,
"numpy_seed": 42,
"pytorch_seed": 42,
}
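The config derives `num_gradient_accumulation_steps` from the target batch size (64) and the per-GPU batch size (4), and pairs it with a `linear_with_warmup` schedule that has 100 warmup steps. The Python sketch below works through that arithmetic. The warmup formula is a common linear warmup and decay formulation and is only assumed to resemble what the trainer does; `total_steps=1000` is a made-up value, and both helper names are hypothetical.

```python
def accumulation_steps(batch_size: int, gpu_batch_size: int) -> int:
    """Number of small batches accumulated per optimizer step."""
    if batch_size % gpu_batch_size != 0:
        raise ValueError("batch_size must be a multiple of gpu_batch_size")
    return batch_size // gpu_batch_size


def linear_warmup_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """Learning-rate multiplier: linear ramp to 1.0, then linear decay to 0.0.

    Assumed formulation for illustration; the trainer's own schedule may differ.
    """
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


steps = accumulation_steps(batch_size=64, gpu_batch_size=4)
print(steps)

peak_lr = 1e-6  # "lr" from the optimizer block of the config
for step in (0, 50, 100, 550, 1000):
    # total_steps=1000 is hypothetical; the real value depends on dataset size and epochs.
    print(step, peak_lr * linear_warmup_factor(step, warmup_steps=100, total_steps=1000))
```

With these settings each optimizer update accumulates gradients from 16 batches of 4 examples, so the effective batch size stays at 64 while only 4 examples need to fit in GPU memory at a time.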