This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 176
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into OrderedRcExamples
- Loading branch information
Showing
6 changed files
with
241 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from allennlp_models.classification.dataset_readers.stanford_sentiment_tree_bank import ( | ||
StanfordSentimentTreeBankDatasetReader, | ||
) | ||
from allennlp_models.classification.dataset_readers.boolq import BoolQDatasetReader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import json | ||
import logging | ||
from typing import Optional, Iterable, Dict | ||
|
||
from allennlp.common.file_utils import cached_path | ||
from overrides import overrides | ||
from allennlp.data import DatasetReader, Tokenizer, TokenIndexer, Instance, Field | ||
from allennlp.data.tokenizers import WhitespaceTokenizer | ||
from allennlp.data.token_indexers import SingleIdTokenIndexer | ||
from allennlp.data.fields import TextField, LabelField | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@DatasetReader.register("boolq") | ||
class BoolQDatasetReader(DatasetReader): | ||
""" | ||
This DatasetReader is designed to read in the BoolQ data | ||
for binary QA task. It returns a dataset of instances with the | ||
following fields: | ||
The output of `read` is a list of `Instance` s with the fields: | ||
tokens : `TextField` and | ||
label : `LabelField` | ||
Registered as a `DatasetReader` with name "boolq". | ||
# Parameters | ||
tokenizer: `Tokenizer`, optional (default=`WhitespaceTokenizer()`) | ||
Tokenizer to use to split the input sequences into words or other kinds of tokens. | ||
token_indexers : `Dict[str, TokenIndexer]`, optional (default=`{"tokens": SingleIdTokenIndexer()}`) | ||
We use this to define the input representation for the text. See :class:`TokenIndexer`. | ||
""" | ||
|
||
def __init__( | ||
self, tokenizer: Tokenizer = None, token_indexers: Dict[str, TokenIndexer] = None, **kwargs | ||
): | ||
super().__init__( | ||
manual_distributed_sharding=True, manual_multiprocess_sharding=True, **kwargs | ||
) | ||
self.tokenizer = tokenizer or WhitespaceTokenizer() | ||
self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()} | ||
|
||
@overrides | ||
def _read(self, file_path) -> Iterable[Instance]: | ||
file_path = cached_path(file_path, extract_archive=True) | ||
with open(file_path) as f: | ||
logger.info("Reading instances from lines in file at: %s", file_path) | ||
for line in self.shard_iterable(f): | ||
record = json.loads(line.strip()) | ||
yield self.text_to_instance( | ||
passage=record.get("passage"), | ||
question=record.get("question"), | ||
label=record.get("label"), | ||
) | ||
|
||
@overrides | ||
def text_to_instance( # type: ignore | ||
self, passage: str, question: str, label: Optional[bool] = None | ||
) -> Instance: | ||
""" | ||
We take the passage and the question as input, tokenize and concat them. | ||
# Parameters | ||
passage : `str`, required. | ||
The passage in a given BoolQ record. | ||
question : `str`, required. | ||
The passage in a given BoolQ record. | ||
label : `bool`, optional, (default = `None`). | ||
The label for the passage and the question. | ||
# Returns | ||
An `Instance` containing the following fields: | ||
tokens : `TextField` | ||
The tokens in the concatenation of the passage and the question. | ||
label : `LabelField` | ||
The answer to the question. | ||
""" | ||
fields: Dict[str, Field] = {} | ||
|
||
# 80% of the question length in the training set is less than 60, 512 - 4 - 60 = 448. | ||
passage_tokens = self.tokenizer.tokenize(passage)[:448] | ||
question_tokens = self.tokenizer.tokenize(question)[:60] | ||
|
||
tokens = self.tokenizer.add_special_tokens(passage_tokens, question_tokens) | ||
text_field = TextField(tokens) | ||
fields["tokens"] = text_field | ||
|
||
if label is not None: | ||
label_field = LabelField(int(label), skip_indexing=True) | ||
fields["label"] = label_field | ||
return Instance(fields) | ||
|
||
def apply_token_indexers(self, instance: Instance) -> None: | ||
instance.fields["tokens"].token_indexers = self.token_indexers # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{"question": "do iran and afghanistan speak the same language", "passage": "Persian language -- Persian (/\u02c8p\u025c\u02d0r\u0292\u0259n, -\u0283\u0259n/), also known by its endonym Farsi (\u0641\u0627\u0631\u0633\u06cc f\u0101rsi (f\u0252\u02d0\u027e\u02c8si\u02d0) ( listen)), is one of the Western Iranian languages within the Indo-Iranian branch of the Indo-European language family. It is primarily spoken in Iran, Afghanistan (officially known as Dari since 1958), and Tajikistan (officially known as Tajiki since the Soviet era), and some other regions which historically were Persianate societies and considered part of Greater Iran. It is written in the Persian alphabet, a modified variant of the Arabic script, which itself evolved from the Aramaic alphabet.", "idx": 0, "label": true} | ||
{"question": "can you use oyster card at epsom station", "passage": "Epsom railway station -- Epsom railway station serves the town of Epsom in Surrey. It is located off Waterloo Road and is less than two minutes' walk from the High Street. It is not in the London Oyster card zone unlike Epsom Downs or Tattenham Corner stations. The station building was replaced in 2012/2013 with a new building with apartments above the station (see end of article).", "idx": 5, "label": false} | ||
{"question": "can you use oyster card at epsom station", "passage": "Epsom railway station -- Epsom railway station serves the town of Epsom in Surrey. It is located off Waterloo Road and is less than two minutes' walk from the High Street. It is not in the London Oyster card zone unlike Epsom Downs or Tattenham Corner stations. The station building was replaced in 2012/2013 with a new building with apartments above the station (see end of article).", "idx": 5, "label": false} | ||
{"question": "will there be a season 4 of da vinci's demons", "passage": "Da Vinci's Demons -- The series premiered in the United States on Starz on 12 April 2013, and its second season premiered on 22 March 2014. The series was renewed for a third season, which premiered on 24 October 2015. On 23 July 2015, Starz announced that the third season would be the show's last. However Goyer has left it open for a miniseries return.", "idx": 6, "label": false} | ||
{"question": "is the federal court the same as the supreme court", "passage": "Federal judiciary of the United States -- The federal courts are composed of three levels of courts. The Supreme Court of the United States is the court of last resort. It is generally an appellate court that operates under discretionary review, which means that the Court can choose which cases to hear, by granting writs of certiorari. There is therefore generally no basic right of appeal that extends automatically all the way to the Supreme Court. In a few situations (like lawsuits between state governments or some cases between the federal government and a state) it sits as a court of original jurisdiction.", "idx": 7, "label": false} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# -*- coding: utf-8 -*- | ||
from allennlp.common.util import ensure_list | ||
from allennlp.data.tokenizers import PretrainedTransformerTokenizer | ||
from allennlp.data.token_indexers import PretrainedTransformerIndexer | ||
|
||
from allennlp_models.classification import BoolQDatasetReader | ||
from tests import FIXTURES_ROOT | ||
|
||
|
||
class TestBoolqReader: | ||
boolq_path = FIXTURES_ROOT / "classification" / "boolq.jsonl" | ||
|
||
def test_boolq_dataset_reader_default_setting(self): | ||
reader = BoolQDatasetReader() | ||
instances = reader.read(self.boolq_path) | ||
instances = ensure_list(instances) | ||
|
||
assert len(instances) == 5 | ||
|
||
fields = instances[0].fields | ||
assert [t.text for t in fields["tokens"].tokens][:5] == [ | ||
"Persian", | ||
"language", | ||
"--", | ||
"Persian", | ||
"(/ˈpɜːrʒən,", | ||
] | ||
assert fields["label"].label == 1 | ||
|
||
fields = instances[1].fields | ||
assert [t.text for t in fields["tokens"].tokens][:5] == [ | ||
"Epsom", | ||
"railway", | ||
"station", | ||
"--", | ||
"Epsom", | ||
] | ||
assert fields["label"].label == 0 | ||
|
||
def test_boolq_dataset_reader_roberta_setting(self): | ||
reader = BoolQDatasetReader( | ||
tokenizer=PretrainedTransformerTokenizer("roberta-base", add_special_tokens=False), | ||
token_indexers={"tokens": PretrainedTransformerIndexer("roberta-base")}, | ||
) | ||
instances = reader.read(self.boolq_path) | ||
instances = ensure_list(instances) | ||
|
||
assert len(instances) == 5 | ||
|
||
fields = instances[0].fields | ||
assert [t.text for t in fields["tokens"].tokens][:5] == [ | ||
"<s>", | ||
"Pers", | ||
"ian", | ||
"Ġlanguage", | ||
"Ġ--", | ||
] | ||
assert [t.text for t in fields["tokens"].tokens][-5:] == [ | ||
"Ġspeak", | ||
"Ġthe", | ||
"Ġsame", | ||
"Ġlanguage", | ||
"</s>", | ||
] | ||
assert fields["label"].label == 1 | ||
|
||
fields = instances[1].fields | ||
assert [t.text for t in fields["tokens"].tokens][:5] == [ | ||
"<s>", | ||
"E", | ||
"ps", | ||
"om", | ||
"Ġrailway", | ||
] | ||
assert [t.text for t in fields["tokens"].tokens][-5:] == [ | ||
"Ġe", | ||
"ps", | ||
"om", | ||
"Ġstation", | ||
"</s>", | ||
] | ||
assert fields["label"].label == 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
local transformer_model = "roberta-large"; | ||
local transformer_dim = 1024; | ||
|
||
{ | ||
"dataset_reader":{ | ||
"type": "boolq", | ||
"token_indexers": { | ||
"tokens": { | ||
"type": "pretrained_transformer", | ||
"model_name": transformer_model, | ||
} | ||
}, | ||
"tokenizer": { | ||
"type": "pretrained_transformer", | ||
"model_name": transformer_model, | ||
} | ||
}, | ||
"train_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/train.jsonl", | ||
"validation_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/val.jsonl", | ||
"test_data_path": "https://storage.googleapis.com/allennlp-public-data/BoolQ.zip!BoolQ/test.jsonl", | ||
"model": { | ||
"type": "basic_classifier", | ||
"text_field_embedder": { | ||
"token_embedders": { | ||
"tokens": { | ||
"type": "pretrained_transformer", | ||
"model_name": transformer_model, | ||
} | ||
} | ||
}, | ||
"seq2vec_encoder": { | ||
"type": "bert_pooler", | ||
"pretrained_model": transformer_model, | ||
"dropout": 0.1, | ||
}, | ||
"namespace": "tags", | ||
"num_labels": 2, | ||
}, | ||
"data_loader": { | ||
"batch_sampler": { | ||
"type": "bucket", | ||
"sorting_keys": ["tokens"], | ||
"batch_size" : 2 | ||
} | ||
}, | ||
"trainer": { | ||
"num_epochs": 10, | ||
"validation_metric": "+accuracy", | ||
"learning_rate_scheduler": { | ||
"type": "slanted_triangular", | ||
"num_epochs": 10, | ||
"num_steps_per_epoch": 3088, | ||
"cut_frac": 0.06 | ||
}, | ||
"optimizer": { | ||
"type": "huggingface_adamw", | ||
"lr": 1e-5, | ||
"weight_decay": 0.1, | ||
}, | ||
"num_gradient_accumulation_steps": 16, | ||
}, | ||
} |