Commit: Adding a dataset reader and training config for RoBERTa on SuperGLUE RTE

Co-authored-by: Dirk Groeneveld <[email protected]>
1 parent 419bc90 · commit c733f83 · showing 9 changed files with 306 additions and 1 deletion.
70 changes: 70 additions & 0 deletions
allennlp_models/modelcards/pair-classification-roberta-rte.json
@@ -0,0 +1,70 @@
{
    "id": "pair-classification-roberta-rte",
    "registered_model_name": "roberta-rte",
    "registered_predictor_name": null,
    "display_name": "RoBERTa RTE",
    "task_id": "pair_classification",
    "model_details": {
"description": "The model implements a pair classification model patterned after the proposed model in [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding (Devlin et al, 2018)](https://api.semanticscholar.org/CorpusID:52967399), fine-tuned on the MultiNLI corpus. It predicts labels with a linear layer on top of word piece embeddings.", | ||
"short_description": "A pair classification model patterned after the proposed model in Devlin et al, fine-tuned on the SuperGLUE RTE corpus", | ||
"developed_by": "Devlin et al", | ||
"contributed_by": "Jacob Morrison", | ||
"date": "2021-04-09", | ||
"version": "1", | ||
"model_type": "RoBERTa", | ||
"paper": { | ||
"citation": "\n@article{Liu2019RoBERTaAR,\ntitle={RoBERTa: A Robustly Optimized BERT Pretraining Approach},\nauthor={Y. Liu and Myle Ott and Naman Goyal and Jingfei Du and Mandar Joshi and Danqi Chen and Omer Levy and M. Lewis and L. Zettlemoyer and V. Stoyanov},\njournal={ArXiv},\nyear={2019},\nvolume={abs/1907.11692}}\n", | ||
"title": "RoBERTa: A Robustly Optimized BERT Pretraining Approach", | ||
"url": "https://api.semanticscholar.org/CorpusID:198953378" | ||
}, | ||
"license": null, | ||
"contact": "[email protected]" | ||
}, | ||
"intended_use": { | ||
"primary_uses": null, | ||
"primary_users": null, | ||
"out_of_scope_use_cases": null | ||
}, | ||
"factors": { | ||
"relevant_factors": null, | ||
"evaluation_factors": null | ||
}, | ||
"metrics": { | ||
"model_performance_measures": "Accuracy", | ||
"decision_thresholds": null, | ||
"variation_approaches": null | ||
}, | ||
"evaluation_data": { | ||
"dataset": { | ||
"name": "SuperGLUE Recognizing Textual Entailment validation set", | ||
"url": "https://super.gluebenchmark.com/tasks", | ||
"processed_url": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/val.jsonl" | ||
}, | ||
"motivation": null, | ||
"preprocessing": null | ||
}, | ||
"training_data": { | ||
"dataset": { | ||
"name": "SuperGLUE Recognizing Textual Entailment training set", | ||
"url": "https://super.gluebenchmark.com/tasks", | ||
"processed_url": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/train.jsonl" | ||
}, | ||
"motivation": null, | ||
"preprocessing": null | ||
}, | ||
"quantitative_analyses": { | ||
"unitary_results": "Accuracy: 89.9% on the SuperGLUE RTE validation dataset.", | ||
"intersectional_results": null | ||
}, | ||
"model_caveats_and_recommendations": { | ||
"caveats_and_recommendations": null | ||
}, | ||
"model_ethical_considerations": { | ||
"ethical_considerations": null | ||
}, | ||
"model_usage": { | ||
"archive_file": "superglue-rte-roberta.2021-04-09.tar.gz", | ||
"training_config": "pair-classification/superglue_rte_roberta.jsonnet", | ||
"install_instructions": "pip install allennlp==2.3.1 allennlp-models==2.3.1" | ||
} | ||
} |
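For orientation, here is a minimal inference sketch against this card's artifacts. Note two assumptions not in the commit itself: the full archive URL (the card only lists a file name, so the standard allennlp-public-models bucket is assumed) and the `textual_entailment` predictor name, which is the stock pair-classification predictor in allennlp-models.

```python
# Minimal inference sketch; URL and predictor name are assumptions (see above).
# Requires the pinned versions: pip install allennlp==2.3.1 allennlp-models==2.3.1
from allennlp.predictors import Predictor
import allennlp_models.pair_classification  # noqa: F401 -- registers the model and reader

predictor = Predictor.from_path(
    "https://storage.googleapis.com/allennlp-public-models/"
    "superglue-rte-roberta.2021-04-09.tar.gz",
    predictor_name="textual_entailment",
)
result = predictor.predict(
    premise="No Weapons of Mass Destruction Found in Iraq Yet.",
    hypothesis="Weapons of Mass Destruction Found in Iraq.",
)
print(result["label"], result["probs"])  # predicted label and class probabilities
```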
117 changes: 117 additions & 0 deletions
allennlp_models/pair_classification/dataset_readers/transformer_superglue_rte.py
@@ -0,0 +1,117 @@
import logging
from typing import Any, Dict, Optional

from overrides import overrides

from allennlp.data.fields import MetadataField, TextField, LabelField
from allennlp.common.file_utils import cached_path, json_lines_from_file
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer

logger = logging.getLogger(__name__)


@DatasetReader.register("transformer_superglue_rte")
class TransformerSuperGlueRteReader(DatasetReader):
    """
    Dataset reader for the SuperGLUE Recognizing Textual Entailment task, to be used with a
    transformer model such as RoBERTa. The dataset is in the JSON Lines format.

    It will generate `Instances` with the following fields:

    * `tokens`, a `TextField` that contains the concatenation of premise and hypothesis,
    * `label`, a `LabelField` containing the label, if one exists.
    * `metadata`, a `MetadataField` that stores the instance's index in the file, the original
      premise, the original hypothesis, both of these in tokenized form, and the gold label,
      accessible as `metadata['index']`, `metadata['premise']`, `metadata['hypothesis']`,
      `metadata['premise_tokens']`, `metadata['hypothesis_tokens']`, and `metadata['label']`.

    # Parameters

    transformer_model_name : `str`, optional (default=`'roberta-base'`)
        This reader chooses the tokenizer and token indexer according to this setting.
    """

    def __init__(
        self,
        transformer_model_name: str = "roberta-base",
        tokenizer_kwargs: Dict[str, Any] = None,
        **kwargs
    ) -> None:
        super().__init__(
            manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
        )
        self._tokenizer = PretrainedTransformerTokenizer(
            transformer_model_name,
            add_special_tokens=False,
            tokenizer_kwargs=tokenizer_kwargs,
        )
        self._token_indexers = {
            "tokens": PretrainedTransformerIndexer(
                transformer_model_name, tokenizer_kwargs=tokenizer_kwargs, max_length=512
            )
        }

    @overrides
    def _read(self, file_path: str):
        # if `file_path` is a URL (possibly pointing inside a zip archive), redirect to the cache
        file_path = cached_path(file_path, extract_archive=True)

        logger.info("Reading file at %s", file_path)
        yielded_relation_count = 0
        for relation in self.shard_iterable(json_lines_from_file(file_path)):
            premise = relation["premise"]
            hypothesis = relation["hypothesis"]
            # The unlabeled test set has no "label" key.
            label = relation.get("label")
            index = relation["idx"]

            # todo: see if we even need this to be in a separate method
            instance = self.text_to_instance(index, label, premise, hypothesis)

            yield instance
            yielded_relation_count += 1

    @overrides
    def text_to_instance(
        self,
        index: int,
        label: Optional[str],
        premise: str,
        hypothesis: str,
    ) -> Instance:
        tokenized_premise = self._tokenizer.tokenize(premise)
        tokenized_hypothesis = self._tokenizer.tokenize(hypothesis)

        fields = {}

        # Concatenate premise and hypothesis with the model's special tokens,
        # e.g. `<s> premise </s> </s> hypothesis </s>` for RoBERTa.
        premise_and_hypothesis = TextField(
            self._tokenizer.add_special_tokens(tokenized_premise, tokenized_hypothesis),
        )
        fields["tokens"] = premise_and_hypothesis

        # make the metadata
        metadata = {
            "premise": premise,
            "premise_tokens": tokenized_premise,
            "hypothesis": hypothesis,
            "hypothesis_tokens": tokenized_hypothesis,
            "index": index,
        }
        if label:
            fields["label"] = LabelField(label)
            metadata["label"] = label

        fields["metadata"] = MetadataField(metadata)

        return Instance(fields)

    @overrides
    def apply_token_indexers(self, instance: Instance) -> None:
        instance["tokens"].token_indexers = self._token_indexers
4 changes: 4 additions & 0 deletions
test fixture superglue_rte.jsonl (read in the tests below as FIXTURES_ROOT / "rc" / "superglue_rte.jsonl")
@@ -0,0 +1,4 @@
{"premise": "No Weapons of Mass Destruction Found in Iraq Yet.", "hypothesis": "Weapons of Mass Destruction Found in Iraq.", "label": "not_entailment", "idx": 0} | ||
{"premise": "A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI.", "hypothesis": "Pope Benedict XVI is the new leader of the Roman Catholic Church.", "label": "entailment", "idx": 1} | ||
{"premise": "Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.", "hypothesis": "Herceptin can be used to treat breast cancer.", "label": "entailment", "idx": 2} | ||
{"premise": "Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment.", "hypothesis": "The previous name of Ho Chi Minh City was Saigon.", "label": "entailment", "idx": 3} |
4 changes: 4 additions & 0 deletions
test fixture superglue_rte_no_labels.jsonl (read in the tests below as FIXTURES_ROOT / "rc" / "superglue_rte_no_labels.jsonl")
@@ -0,0 +1,4 @@
{"premise": "No Weapons of Mass Destruction Found in Iraq Yet.", "hypothesis": "Weapons of Mass Destruction Found in Iraq.", "idx": 0} | ||
{"premise": "A place of sorrow, after Pope John Paul II died, became a place of celebration, as Roman Catholic faithful gathered in downtown Chicago to mark the installation of new Pope Benedict XVI.", "hypothesis": "Pope Benedict XVI is the new leader of the Roman Catholic Church.", "idx": 1} | ||
{"premise": "Herceptin was already approved to treat the sickest breast cancer patients, and the company said, Monday, it will discuss with federal regulators the possibility of prescribing the drug for more breast cancer patients.", "hypothesis": "Herceptin can be used to treat breast cancer.", "idx": 2} | ||
{"premise": "Judie Vivian, chief executive at ProMedica, a medical service company that helps sustain the 2-year-old Vietnam Heart Institute in Ho Chi Minh City (formerly Saigon), said that so far about 1,500 children have received treatment.", "hypothesis": "The previous name of Ho Chi Minh City was Saigon.", "idx": 3} |
39 changes: 39 additions & 0 deletions
tests/pair_classification/dataset_readers/transformer_superglue_rte_test.py
@@ -0,0 +1,39 @@
from allennlp.common.util import ensure_list

from allennlp_models.pair_classification import TransformerSuperGlueRteReader
from tests import FIXTURES_ROOT


class TestTransformerSuperGlueRteReader:
    def test_read_from_file_superglue_rte(self):
        reader = TransformerSuperGlueRteReader()
        instances = ensure_list(reader.read(FIXTURES_ROOT / "rc" / "superglue_rte.jsonl"))
        assert len(instances) == 4

        token_text = [t.text for t in instances[0].fields["tokens"].tokens]
        assert token_text[:3] == ["<s>", "No", "ĠWeapons"]
        assert token_text[10:14] == [".", "</s>", "</s>", "Weapons"]
        assert token_text[-3:] == ["ĠIraq", ".", "</s>"]

        assert instances[0].fields["label"].human_readable_repr() == "not_entailment"

        assert instances[0].fields["metadata"]["label"] == "not_entailment"
        assert instances[0].fields["metadata"]["index"] == 0

    def test_read_from_file_superglue_rte_no_label(self):
        reader = TransformerSuperGlueRteReader()
        instances = ensure_list(
            reader.read(FIXTURES_ROOT / "rc" / "superglue_rte_no_labels.jsonl")
        )
        assert len(instances) == 4

        token_text = [t.text for t in instances[0].fields["tokens"].tokens]
        assert token_text[:3] == ["<s>", "No", "ĠWeapons"]
        assert token_text[10:14] == [".", "</s>", "</s>", "Weapons"]
        assert token_text[-3:] == ["ĠIraq", ".", "</s>"]

        assert "label" not in instances[0].fields
        assert "label" not in instances[0].fields["metadata"]

        assert instances[0].fields["metadata"]["index"] == 0
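The `Ġ` prefix in the expected wordpieces is RoBERTa's byte-level BPE marker for a preceding space, which is why "No" (sentence-initial) has none and "ĠWeapons" does. To run just this module, something like the following sketch (assuming pytest is installed and the repository root is the working directory):

```python
import pytest

# Equivalent to: pytest tests/pair_classification/dataset_readers/transformer_superglue_rte_test.py -v
pytest.main(
    ["tests/pair_classification/dataset_readers/transformer_superglue_rte_test.py", "-v"]
)
```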
67 changes: 67 additions & 0 deletions
training_config/pair_classification/superglue_rte_roberta.jsonnet
@@ -0,0 +1,67 @@
local transformer_model = "roberta-large-mnli";
local transformer_dim = 1024;

local epochs = 20;
local batch_size = 64;

local gpu_batch_size = 4;
local gradient_accumulation_steps = batch_size / gpu_batch_size;

{
    "dataset_reader": {
        "type": "transformer_superglue_rte",
        // keep the reader's tokenizer consistent with the model below
        "transformer_model_name": transformer_model
    },
    "train_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/train.jsonl",
    "validation_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/val.jsonl",
    "test_data_path": "https://dl.fbaipublicfiles.com/glue/superglue/data/v2/RTE.zip!RTE/test.jsonl",
    "model": {
        "type": "basic_classifier",
        "text_field_embedder": {
            "token_embedders": {
                "tokens": {
                    "type": "pretrained_transformer",
                    "model_name": transformer_model,
                    "max_length": 512
                }
            }
        },
        "seq2vec_encoder": {
            "type": "cls_pooler",
            "embedding_dim": transformer_dim
        },
        "feedforward": {
            "input_dim": transformer_dim,
            "num_layers": 1,
            "hidden_dims": transformer_dim,
            "activations": "tanh"
        },
        "dropout": 0.1,
        "namespace": "tags"
    },
    "data_loader": {
        "shuffle": true,
        "batch_size": gpu_batch_size
    },
    "trainer": {
        "optimizer": {
            "type": "huggingface_adamw",
            "weight_decay": 0.01,
            "parameter_groups": [[["bias", "LayerNorm\\.weight", "layer_norm\\.weight"], {"weight_decay": 0}]],
            "lr": 1e-6,
            "eps": 1e-8,
            "correct_bias": true
        },
        "learning_rate_scheduler": {
            "type": "linear_with_warmup",
            "warmup_steps": 100
        },
        // "grad_norm": 1.0,
        "num_epochs": epochs,
        "num_gradient_accumulation_steps": gradient_accumulation_steps,
        "patience": 3,
        "validation_metric": "+accuracy"
    },
    "random_seed": 42,
    "numpy_seed": 42,
    "pytorch_seed": 42
}
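With `gpu_batch_size = 4` and `gradient_accumulation_steps = 64 / 4 = 16`, the effective batch size works out to the intended 64. A sketch of launching training with this config follows; the serialization directory is illustrative, and the repository root is assumed as the working directory:

```python
# Equivalent to:
#   allennlp train training_config/pair_classification/superglue_rte_roberta.jsonnet \
#     -s /tmp/superglue_rte_roberta --include-package allennlp_models
import sys

from allennlp.commands import main

sys.argv = [
    "allennlp",
    "train",
    "training_config/pair_classification/superglue_rte_roberta.jsonnet",
    "-s", "/tmp/superglue_rte_roberta",  # serialization directory (illustrative)
    "--include-package", "allennlp_models",
]
main()
```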