This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Merge branch 'main' into OrderedRcExamples
dirkgr authored Mar 25, 2021
2 parents 4e2e8fc + 8aabfe5 commit 040bc9a
Showing 6 changed files with 241 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Evaluating RC task card and associated LERC model card
- Compatibility with PyTorch 1.8
- Allows the order of examples in the task cards to be specified explicitly
- Dataset reader for SuperGLUE BoolQ

### Changed

1 change: 1 addition & 0 deletions allennlp_models/classification/dataset_readers/__init__.py
@@ -1,3 +1,4 @@
from allennlp_models.classification.dataset_readers.stanford_sentiment_tree_bank import (
    StanfordSentimentTreeBankDatasetReader,
)
from allennlp_models.classification.dataset_readers.boolq import BoolQDatasetReader
90 changes: 90 additions & 0 deletions allennlp_models/classification/dataset_readers/boolq.py
@@ -0,0 +1,90 @@
import json
import logging
from typing import Optional, Iterable, Dict

from allennlp.common.file_utils import cached_path
from overrides import overrides
from allennlp.data import DatasetReader, Tokenizer, TokenIndexer, Instance, Field
from allennlp.data.tokenizers import WhitespaceTokenizer
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.fields import TextField, LabelField

logger = logging.getLogger(__name__)


@DatasetReader.register("boolq")
class BoolQDatasetReader(DatasetReader):
    """
    Reads the BoolQ data for the binary question-answering task. The output of
    `read` is a list of `Instance` s with the fields:

        tokens : `TextField`
        label : `LabelField`

    Registered as a `DatasetReader` with name "boolq".

    # Parameters

    tokenizer : `Tokenizer`, optional (default=`WhitespaceTokenizer()`)
        Tokenizer to use to split the input sequences into words or other kinds of tokens.
    token_indexers : `Dict[str, TokenIndexer]`, optional (default=`{"tokens": SingleIdTokenIndexer()}`)
        We use this to define the input representation for the text. See :class:`TokenIndexer`.
    """

    def __init__(
        self, tokenizer: Tokenizer = None, token_indexers: Dict[str, TokenIndexer] = None, **kwargs
    ):
        super().__init__(
            manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs
        )
        self.tokenizer = tokenizer or WhitespaceTokenizer()
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    @overrides
    def _read(self, file_path) -> Iterable[Instance]:
        file_path = cached_path(file_path, extract_archive=True)
        with open(file_path) as f:
            logger.info("Reading instances from lines in file at: %s", file_path)
            for line in self.shard_iterable(f):
                record = json.loads(line.strip())
                yield self.text_to_instance(
                    passage=record.get("passage"),
                    question=record.get("question"),
                    label=record.get("label"),
                )

    @overrides
    def text_to_instance(  # type: ignore
        self, passage: str, question: str, label: Optional[bool] = None
    ) -> Instance:
        """
        We take the passage and the question as input, tokenize and concatenate them.

        # Parameters

        passage : `str`, required.
            The passage in a given BoolQ record.
        question : `str`, required.
            The question in a given BoolQ record.
        label : `bool`, optional, (default = `None`).
            The label for the passage and the question.

        # Returns

        An `Instance` containing the following fields:
            tokens : `TextField`
                The tokens in the concatenation of the passage and the question.
            label : `LabelField`
                The answer to the question.
        """
        fields: Dict[str, Field] = {}

        # 80% of the questions in the training set are shorter than 60 tokens, so the
        # passage keeps 512 - 4 (special tokens) - 60 = 448 tokens.
        passage_tokens = self.tokenizer.tokenize(passage)[:448]
        question_tokens = self.tokenizer.tokenize(question)[:60]

        tokens = self.tokenizer.add_special_tokens(passage_tokens, question_tokens)
        text_field = TextField(tokens)
        fields["tokens"] = text_field

        if label is not None:
            label_field = LabelField(int(label), skip_indexing=True)
            fields["label"] = label_field
        return Instance(fields)

    def apply_token_indexers(self, instance: Instance) -> None:
        instance.fields["tokens"].token_indexers = self.token_indexers  # type: ignore
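
For a quick sanity check outside the test suite, the reader can be exercised directly. A minimal sketch against the fixture added in this commit (the expected values in the comments follow the fixture's records):

from allennlp_models.classification.dataset_readers.boolq import BoolQDatasetReader

# Default setting: whitespace tokenization with a single-id token indexer.
reader = BoolQDatasetReader()
instances = list(reader.read("test_fixtures/classification/boolq.jsonl"))

print(len(instances))  # 5
print([t.text for t in instances[0].fields["tokens"].tokens][:3])  # ['Persian', 'language', '--']
print(instances[0].fields["label"].label)  # 1, since the record's label is true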
5 changes: 5 additions & 0 deletions test_fixtures/classification/boolq.jsonl
@@ -0,0 +1,5 @@
{"question": "do iran and afghanistan speak the same language", "passage": "Persian language -- Persian (/\u02c8p\u025c\u02d0r\u0292\u0259n, -\u0283\u0259n/), also known by its endonym Farsi (\u0641\u0627\u0631\u0633\u06cc f\u0101rsi (f\u0252\u02d0\u027e\u02c8si\u02d0) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.", "idx": 0, "label": true}
{"question": "can you use oyster card at epsom station", "passage": "Epsom railway station -- Epsom railway station serves the town of Epsom in Surrey. It is located off Waterloo Road and is less than two minutes' walk from the High Street. It is not in the London Oyster card zone unlike Epsom Downs or Tattenham Corner stations. The station building was replaced in 2012/2013 with a new building with apartments above the station (see end of article).", "idx": 5, "label": false}
{"question": "can you use oyster card at epsom station", "passage": "Epsom railway station -- Epsom railway station serves the town of Epsom in Surrey. It is located off Waterloo Road and is less than two minutes' walk from the High Street. It is not in the London Oyster card zone unlike Epsom Downs or Tattenham Corner stations. The station building was replaced in 2012/2013 with a new building with apartments above the station (see end of article).", "idx": 5, "label": false}
{"question": "will there be a season 4 of da vinci's demons", "passage": "Da Vinci's Demons -- The series premiered in the United States on Starz on 12 April 2013, and its second season premiered on 22 March 2014. The series was renewed for a third season, which premiered on 24 October 2015. On 23 July 2015, Starz announced that the third season would be the show's last. However Goyer has left it open for a miniseries return.", "idx": 6, "label": false}
{"question": "is the federal court the same as the supreme court", "passage": "Federal judiciary of the United States -- The federal courts are composed of three levels of courts. The Supreme Court of the United States is the court of last resort. It is generally an appellate court that operates under discretionary review, which means that the Court can choose which cases to hear, by granting writs of certiorari. There is therefore generally no basic right of appeal that extends automatically all the way to the Supreme Court. In a few situations (like lawsuits between state governments or some cases between the federal government and a state) it sits as a court of original jurisdiction.", "idx": 7, "label": false}
82 changes: 82 additions & 0 deletions tests/classification/dataset_readers/boolq.py
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
from allennlp.common.util import ensure_list
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.data.token_indexers import PretrainedTransformerIndexer

from allennlp_models.classification import BoolQDatasetReader
from tests import FIXTURES_ROOT


class TestBoolqReader:
    boolq_path = FIXTURES_ROOT / "classification" / "boolq.jsonl"

    def test_boolq_dataset_reader_default_setting(self):
        reader = BoolQDatasetReader()
        instances = reader.read(self.boolq_path)
        instances = ensure_list(instances)

        assert len(instances) == 5

        fields = instances[0].fields
        assert [t.text for t in fields["tokens"].tokens][:5] == [
            "Persian",
            "language",
            "--",
            "Persian",
            "(/ˈpɜːrʒən,",
        ]
        assert fields["label"].label == 1

        fields = instances[1].fields
        assert [t.text for t in fields["tokens"].tokens][:5] == [
            "Epsom",
            "railway",
            "station",
            "--",
            "Epsom",
        ]
        assert fields["label"].label == 0

    def test_boolq_dataset_reader_roberta_setting(self):
        reader = BoolQDatasetReader(
            tokenizer=PretrainedTransformerTokenizer("roberta-base", add_special_tokens=False),
            token_indexers={"tokens": PretrainedTransformerIndexer("roberta-base")},
        )
        instances = reader.read(self.boolq_path)
        instances = ensure_list(instances)

        assert len(instances) == 5

        fields = instances[0].fields
        assert [t.text for t in fields["tokens"].tokens][:5] == [
            "<s>",
            "Pers",
            "ian",
            "Ġlanguage",
            "Ġ--",
        ]
        assert [t.text for t in fields["tokens"].tokens][-5:] == [
            "Ġspeak",
            "Ġthe",
            "Ġsame",
            "Ġlanguage",
            "</s>",
        ]
        assert fields["label"].label == 1

        fields = instances[1].fields
        assert [t.text for t in fields["tokens"].tokens][:5] == [
            "<s>",
            "E",
            "ps",
            "om",
            "Ġrailway",
        ]
        assert [t.text for t in fields["tokens"].tokens][-5:] == [
            "Ġe",
            "ps",
            "om",
            "Ġstation",
            "</s>",
        ]
        assert fields["label"].label == 0
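
These tests can be run on their own by passing the file path to pytest explicitly; a sketch (the RoBERTa test downloads the `roberta-base` tokenizer from the HuggingFace hub on first run):

import pytest

# Passing the path directly lets pytest collect TestBoolqReader.
pytest.main(["tests/classification/dataset_readers/boolq.py", "-v"])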
62 changes: 62 additions & 0 deletions training_config/classification/boolq_roberta.jsonnet
@@ -0,0 +1,62 @@
local transformer_model = "roberta-large";
local transformer_dim = 1024;

{
    "dataset_reader": {
        "type": "boolq",
        "token_indexers": {
            "tokens": {
                "type": "pretrained_transformer",
                "model_name": transformer_model,
            }
        },
        "tokenizer": {
            "type": "pretrained_transformer",
            "model_name": transformer_model,
        }
    },
    "train_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/train.jsonl",
    "validation_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/val.jsonl",
    "test_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/test.jsonl",
    "model": {
        "type": "basic_classifier",
        "text_field_embedder": {
            "token_embedders": {
                "tokens": {
                    "type": "pretrained_transformer",
                    "model_name": transformer_model,
                }
            }
        },
        "seq2vec_encoder": {
            "type": "bert_pooler",
            "pretrained_model": transformer_model,
            "dropout": 0.1,
        },
        "namespace": "tags",
        "num_labels": 2,
    },
    "data_loader": {
        "batch_sampler": {
            "type": "bucket",
            "sorting_keys": ["tokens"],
            "batch_size": 2
        }
    },
    "trainer": {
        "num_epochs": 10,
        "validation_metric": "+accuracy",
        "learning_rate_scheduler": {
            "type": "slanted_triangular",
            "num_epochs": 10,
            "num_steps_per_epoch": 3088,
            "cut_frac": 0.06
        },
        "optimizer": {
            "type": "huggingface_adamw",
            "lr": 1e-5,
            "weight_decay": 0.1,
        },
        "num_gradient_accumulation_steps": 16,
    },
}
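
The config can be launched through AllenNLP's standard training entry point. A sketch, assuming allennlp and allennlp-models are installed; the serialization directory is an arbitrary example:

from allennlp.commands.train import train_model_from_file

import allennlp_models.classification  # noqa: F401  -- registers the "boolq" reader

# "/tmp/boolq_roberta" is an arbitrary example serialization directory.
train_model_from_file(
    "training_config/classification/boolq_roberta.jsonnet",
    "/tmp/boolq_roberta",
)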
