diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8c31a46f3..ba073366b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added `superglue_record` to the rc readers for SuperGLUE's Reading Comprehension with Commonsense Reasoning (ReCoRD) task.
 - Added some additional `__init__()` parameters to the `T5` model in `allennlp_models.generation` for customizing beam search and other options.
 - Added a configuration file for fine-tuning `t5-11b` on CCN-DM (requires at least 8 GPUs).
 
diff --git a/allennlp_models/rc/dataset_readers/record_reader.py b/allennlp_models/rc/dataset_readers/record_reader.py
new file mode 100644
index 000000000..10ca889d4
--- /dev/null
+++ b/allennlp_models/rc/dataset_readers/record_reader.py
@@ -0,0 +1,543 @@
+"""
+Dataset reader for SuperGLUE's Reading Comprehension with Commonsense Reasoning task
+(Zhang et al., 2018).
+
+Reader implemented by Gabriel Orlanski.
+"""
+import json
+import logging
+from typing import Dict, List, Optional, Iterable, Union, Tuple, Any
+from pathlib import Path
+from allennlp.common.util import sanitize_wordpiece
+from overrides import overrides
+from allennlp.common.file_utils import cached_path
+from allennlp.data.dataset_readers.dataset_reader import DatasetReader
+from allennlp.data.fields import MetadataField, TextField, SpanField
+from allennlp.data.instance import Instance
+from allennlp_models.rc.dataset_readers.utils import char_span_to_token_span
+from allennlp.data.token_indexers import PretrainedTransformerIndexer
+from allennlp.data.tokenizers import Token, PretrainedTransformerTokenizer
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["RecordTaskReader"]
+
+
+# TODO: Optimize this reader
+
+
+@DatasetReader.register("superglue_record")
+class RecordTaskReader(DatasetReader):
+    """
+    Reader for the Reading Comprehension with Commonsense Reasoning (ReCoRD) task from SuperGLUE.
+    The task is detailed in the paper "ReCoRD: Bridging the Gap between Human and Machine
+    Commonsense Reading Comprehension" (arxiv.org/pdf/1810.12885.pdf) by Zhang et al. Leaderboards
+    and the official evaluation script for the ReCoRD task can be found at
+    sheng-z.github.io/ReCoRD-explorer/.
+
+    The reader expects a JSON file in the format described at
+    sheng-z.github.io/ReCoRD-explorer/dataset-readme.txt
+
+    # Parameters
+
+    transformer_model_name : `str`, optional (default=`"bert-base-cased"`)
+        The name of the pretrained transformer whose tokenizer and token indexer are used for
+        both the query and the passage.
+
+    length_limit : `int`, optional (default=`384`)
+        The maximum total number of tokens (query, context window, and special tokens) in a
+        single instance.
+
+    question_length_limit : `int`, optional (default=`64`)
+        The query is truncated if its length exceeds this limit.
+
+    stride : `int`, optional (default=`128`)
+        The number of tokens by which consecutive context windows overlap when a passage is
+        split across multiple instances.
+
+    raise_errors : `bool`, optional (default=`False`)
+        If the reader should raise errors or just continue.
+
+    tokenizer_kwargs : `Dict[str, Any]`, optional (default=`None`)
+        Extra keyword arguments passed to the tokenizer and token indexer.
+
+    one_instance_per_query : `bool`, optional (default=`False`)
+        If `True`, yield at most one instance per query rather than one per context window.
+
+    max_instances : `int`, optional (default=`None`)
+        If specified, stop after yielding this many instances.
+
+    kwargs : `Dict`
+        Keyword arguments to be passed to the DatasetReader parent class constructor.
+    """
+
+    def __init__(
+        self,
+        transformer_model_name: str = "bert-base-cased",
+        length_limit: int = 384,
+        question_length_limit: int = 64,
+        stride: int = 128,
+        raise_errors: bool = False,
+        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+        one_instance_per_query: bool = False,
+        max_instances: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+        """
+        Initialize the RecordTaskReader.
+        """
+        super(RecordTaskReader, self).__init__(
+            manual_distributed_sharding=True, max_instances=max_instances, **kwargs
+        )
+
+        self._kwargs = kwargs
+
+        # Save the values passed to __init__ to protected attributes.
+        self._model_name = transformer_model_name
+        self._tokenizer_kwargs = tokenizer_kwargs or {}
+        self._tokenizer = PretrainedTransformerTokenizer(
+            transformer_model_name,
+            add_special_tokens=False,
+            tokenizer_kwargs=tokenizer_kwargs,
+        )
+        self._token_indexers = {
+            "tokens": PretrainedTransformerIndexer(
+                transformer_model_name, tokenizer_kwargs=tokenizer_kwargs
+            )
+        }
+        self._length_limit = length_limit
+        self._query_len_limit = question_length_limit
+        self._stride = stride
+        self._raise_errors = raise_errors
+        self._cls_token = "@placeholder"
+        self._one_instance_per_query = one_instance_per_query
+
+    def _to_params(self) -> Dict[str, Any]:
+        """
+        Get the configuration dictionary for this class.
+
+        # Returns
+
+        `Dict[str, Any]` The config dict.
+        """
+        return {
+            "type": "superglue_record",
+            "transformer_model_name": self._model_name,
+            "length_limit": self._length_limit,
+            "question_length_limit": self._query_len_limit,
+            "stride": self._stride,
+            "raise_errors": self._raise_errors,
+            "tokenizer_kwargs": self._tokenizer_kwargs,
+            "one_instance_per_query": self._one_instance_per_query,
+            "max_instances": self.max_instances,
+            **self._kwargs,
+        }
+
+    @overrides
+    def _read(self, file_path: Union[Path, str]) -> Iterable[Instance]:
+        # If `file_path` is a URL, redirect to the cache.
+        file_path = cached_path(file_path)
+
+        # Read the 'data' key from the dataset.
+        logger.info(f"Reading '{file_path}'")
+        with open(file_path) as fp:
+            dataset = json.load(fp)["data"]
+        logger.info(f"Found {len(dataset)} examples from '{file_path}'")
+
+        # Keep track of certain stats while reading the file:
+        # examples_multiple_instance_count: The number of examples with more than one
+        #   instance. This can happen because there are multiple queries for a single
+        #   passage.
+        # examples_no_instance_count: The number of examples that produced no instances.
+        # passages_yielded: The total number of instances found/yielded.
+        examples_multiple_instance_count = 0
+        examples_no_instance_count = 0
+        passages_yielded = 0
+
+        # Iterate through every example from the ReCoRD data file.
+        for example in dataset:
+
+            # Get the instances for the current example.
+            instances_for_example = self.get_instances_from_example(example)
+
+            # Keep track of the number of instances yielded for this specific example.
+            # Since instances_for_example is a generator, we do not know its length
+            # ahead of time, so we count as we go.
+            instance_count = 0
+
+            # Iterate through the instances and yield them.
+            for instance in instances_for_example:
+                yield instance
+                instance_count += 1
+
+            if instance_count == 0:
+                logger.warning(f"Example '{example['id']}' had no instances.")
+                examples_no_instance_count += 1
+
+            # Record whether there was more than one instance for this example.
+            examples_multiple_instance_count += 1 if instance_count > 1 else 0
+
+            passages_yielded += instance_count
+
+            # Check to see if we are over the max_instances to yield.
+            if self.max_instances and passages_yielded > self.max_instances:
+                logger.info("Passed max instances")
+                break
+
+        # Log pertinent information.
+        if passages_yielded:
+            logger.info(
+                f"{examples_multiple_instance_count}/{passages_yielded} "
+                f"({examples_multiple_instance_count / passages_yielded * 100:.2f}%) "
+                f"examples had more than one instance"
+            )
+            logger.info(
+                f"{examples_no_instance_count}/{passages_yielded} "
+                f"({examples_no_instance_count / passages_yielded * 100:.2f}%) "
+                f"examples had no instances"
+            )
+        else:
+            logger.warning(f"Could not find any instances in '{file_path}'")
+
+    def get_instances_from_example(
+        self, example: Dict, always_add_answer_span: bool = False
+    ) -> Iterable[Instance]:
+        """
+        Helper function to get instances from an example.
+
+        Much of this comes from `transformer_squad.make_instances`.
+
+        # Parameters
+
+        example: `Dict[str, Any]`
+            The example dict.
+        always_add_answer_span: `bool`, optional (default=`False`)
+            Passed through to `text_to_instance`.
+
+        # Returns
+
+        `Iterable[Instance]` The instances for the example.
+        """
+        # Get the passage dict from the example; it has the text and the entities.
+        example_id: str = example["id"]
+        passage_dict: Dict = example["passage"]
+        passage_text: str = passage_dict["text"]
+
+        # Tokenize the passage.
+        tokenized_passage: List[Token] = self.tokenize_str(passage_text)
+
+        # TODO: Determine what to do with entities. SuperGLUE marks them
+        #  explicitly as input (https://arxiv.org/pdf/1905.00537.pdf).
+
+        # Get the queries from the example dict.
+        queries: List = example["qas"]
+        logger.debug(f"{len(queries)} queries for example {example_id}")
+
+        # Tokenize and get the context windows for each query.
+        for query in queries:
+
+            # Create the additional metadata dict that will be passed with extra
+            # data for each query. We store the query and example ids, all
+            # answers, and other data following `transformer_qa`.
+            additional_metadata = {
+                "id": query["id"],
+                "example_id": example_id,
+            }
+            instances_yielded = 0
+
+            # Tokenize, and truncate, the query based on the limit set in `__init__`.
+            tokenized_query = self.tokenize_str(query["query"])[: self._query_len_limit]
+
+            # Calculate how many tokens are left for the context window after accounting
+            # for the query and the special tokens, since the transformer has a hard
+            # limit on input length.
+            space_for_context = (
+                self._length_limit
+                - len(list(tokenized_query))
+                # Use getattr so the reader can be tested without loading a
+                # transformer model.
+                - len(getattr(self._tokenizer, "sequence_pair_start_tokens", []))
+                - len(getattr(self._tokenizer, "sequence_pair_mid_tokens", []))
+                - len(getattr(self._tokenizer, "sequence_pair_end_tokens", []))
+            )
+
+            # Check if answers exist for this query. If they do not, skip the query.
+            answers = query.get("answers", [])
+            if not answers:
+                logger.warning(f"Skipping {query['id']}, no answers")
+                continue
+
+            # Create the arguments needed for `char_span_to_token_span`.
+            token_offsets = [
+                (t.idx, t.idx + len(sanitize_wordpiece(t.text))) if t.idx is not None else None
+                for t in tokenized_passage
+            ]
+
+            # Get the token offsets for the answers in this passage. The span defaults
+            # to (-1, -1) in case no answer is found in the passage.
+            answer_token_start, answer_token_end = (-1, -1)
+            for answer in answers:
+
+                # Try to find the offsets.
+                offsets, _ = char_span_to_token_span(
+                    token_offsets, (answer["start"], answer["end"])
+                )
+
+                # If offsets for an answer were found, the answer is in the passage,
+                # and we can stop looking.
+                if offsets != (-1, -1):
+                    answer_token_start, answer_token_end = offsets
+                    break
+
+            # Go through the passage and find the window(s) that contain the answer.
+            stride_start = 0
+
+            while True:
+                tokenized_context_window = tokenized_passage[stride_start:]
+                tokenized_context_window = tokenized_context_window[:space_for_context]
+
+                # Get the answer token offsets w.r.t. the current window.
+                window_token_answer_span = (
+                    answer_token_start - stride_start,
+                    answer_token_end - stride_start,
+                )
+                if any(
+                    i < 0 or i >= len(tokenized_context_window) for i in window_token_answer_span
+                ):
+                    # The answer is not contained in the window.
+                    window_token_answer_span = None
+
+                # Unlike `transformer_squad`, there is no `skip_impossible_questions`
+                # option yet, so only windows that contain the answer produce instances.
+                if window_token_answer_span is not None:
+                    # The answer WAS found in the context window, so we can make an
+                    # instance for it.
+                    instance = self.text_to_instance(
+                        query["query"],
+                        tokenized_query,
+                        passage_text,
+                        tokenized_context_window,
+                        answers=[answer["text"] for answer in answers],
+                        token_answer_span=window_token_answer_span,
+                        additional_metadata=additional_metadata,
+                        always_add_answer_span=always_add_answer_span,
+                    )
+                    yield instance
+                    instances_yielded += 1
+
+                if instances_yielded == 1 and self._one_instance_per_query:
+                    break
+
+                stride_start += space_for_context
+
+                # If we have reached the end of the passage, stop.
+                if stride_start >= len(tokenized_passage):
+                    break
+
+                # Step back by `stride` tokens so that consecutive windows overlap and an
+                # answer that would otherwise straddle a window boundary still falls
+                # entirely inside one of the windows.
+                stride_start -= self._stride
+
+    def tokenize_slice(
+        self, text: str, start: Optional[int] = None, end: Optional[int] = None
+    ) -> Iterable[Token]:
+        """
+        Get and tokenize a span from a source text.
+
+        *Originally from `transformer_squad.py`.*
+
+        # Parameters
+
+        text: `str`
+            The text to draw from.
+        start: `int`
+            The start index for the span.
+        end: `int`
+            The end index for the span (exclusive).
+
+        # Returns
+
+        `Iterable[Token]` List of tokens for the retrieved span.
+        """
+        start = start or 0
+        end = end or len(text)
+        text_to_tokenize = text[start:end]
+
+        # If the slice does not start at the beginning of the text and the preceding
+        # character is a space, we need to tokenize in a special way because of a quirk
+        # of the RoBERTa tokenizer.
+        if start - 1 >= 0 and text[start - 1].isspace():
+
+            # Per the original tokenize_slice function, we add a garbage prefix before
+            # the actual text we want to tokenize so that the tokenizer does not add a
+            # beginning-of-sentence token.
+            prefix = "a "
+
+            # Tokenize the combined prefix and text.
+            wordpieces = self._tokenizer.tokenize(prefix + text_to_tokenize)
+
+            # Because we added the garbage prefix before tokenizing, we need to adjust
+            # each token's idx to account for it, so we subtract the length of the
+            # prefix from every idx.
+            for wordpiece in wordpieces:
+                if wordpiece.idx is not None:
+                    wordpiece.idx -= len(prefix)
+
+            # We do not want the garbage token, so we return all but the first token.
+            return wordpieces[1:]
+        else:
+            # No prefix is needed, so just return all of the tokens.
+            return self._tokenizer.tokenize(text_to_tokenize)
+
+    def tokenize_str(self, text: str) -> List[Token]:
+        """
+        Helper method to tokenize a string.
+
+        Adapted from `transformer_squad.make_instances`.
+
+        # Parameters
+
+        text: `str`
+            The string to tokenize.
+
+        # Returns
+
+        `List[Token]` The resulting tokens.
+ + """ + # We need to keep track of the current token index so that we can update + # the results from self.tokenize_slice such that they reflect their + # actual position in the string rather than their position in the slice + # passed to tokenize_slice. Also used to construct the slice. + token_index = 0 + + # Create the output list (can be any iterable) that will store the + # tokens we found. + tokenized_str = [] + + # Helper function to update the `idx` and add every wordpiece in the + # `tokenized_slice` to the `tokenized_str`. + def add_wordpieces(tokenized_slice: Iterable[Token]) -> None: + for wordpiece in tokenized_slice: + if wordpiece.idx is not None: + wordpiece.idx += token_index + tokenized_str.append(wordpiece) + + # Iterate through every character and their respective index in the text + # to create the slices to tokenize. + for i, c in enumerate(text): + + # Check if the current character is a space. If it is, we tokenize + # the slice of `text` from `token_index` to `i`. + if c.isspace(): + add_wordpieces(self.tokenize_slice(text, token_index, i)) + token_index = i + 1 + + # Add the end slice that is not collected by the for loop. + add_wordpieces(self.tokenize_slice(text, token_index, len(text))) + + return tokenized_str + + @staticmethod + def get_spans_from_text(text: str, spans: List[Tuple[int, int]]) -> List[str]: + """ + Helper function to get a span from a string + + # Parameter + + text: `str` + The source string + spans: `List[Tuple[int,int]]` + List of start and end indices for spans. + + Assumes that the end index is inclusive. Therefore, for start + index `i` and end index `j`, retrieves the span at `text[i:j+1]`. + + # Returns + + `List[str]` The extracted string from text. + """ + return [text[start : end + 1] for start, end in spans] + + @overrides + def text_to_instance( + self, + query: str, + tokenized_query: List[Token], + passage: str, + tokenized_passage: List[Token], + answers: List[str], + token_answer_span: Optional[Tuple[int, int]] = None, + additional_metadata: Optional[Dict[str, Any]] = None, + always_add_answer_span: Optional[bool] = False, + ) -> Instance: + """ + A lot of this comes directly from the `transformer_squad.text_to_instance` + """ + fields = {} + + # Create the query field from the tokenized question and context. Use + # `self._tokenizer.add_special_tokens` function to add the necessary + # special tokens to the query. + query_field = TextField( + self._tokenizer.add_special_tokens( + # The `add_special_tokens` function automatically adds in the + # separation token to mark the separation between the two lists of + # tokens. Therefore, we can create the query field WITH context + # through passing them both as arguments. + tokenized_query, + tokenized_passage, + ), + self._token_indexers, + ) + + # Add the query field to the fields dict that will be outputted as an + # instance. Do it here rather than assign above so that we can use + # attributes from `query_field` rather than continuously indexing + # `fields`. + fields["question_with_context"] = query_field + + # Calculate the index that marks the start of the context. + start_of_context = ( + +len(tokenized_query) + # Used getattr so I can test without having to load a + # transformer model. 
+ + len(getattr(self._tokenizer, "sequence_pair_start_tokens", [])) + + len(getattr(self._tokenizer, "sequence_pair_mid_tokens", [])) + ) + + # make the answer span + if token_answer_span is not None: + assert all(i >= 0 for i in token_answer_span) + assert token_answer_span[0] <= token_answer_span[1] + + fields["answer_span"] = SpanField( + token_answer_span[0] + start_of_context, + token_answer_span[1] + start_of_context, + query_field, + ) + # make the context span, i.e., the span of text from which possible + # answers should be drawn + fields["context_span"] = SpanField( + start_of_context, start_of_context + len(tokenized_passage) - 1, query_field + ) + + # make the metadata + metadata = { + "question": query, + "question_tokens": tokenized_query, + "context": passage, + "context_tokens": tokenized_passage, + "answers": answers or [], + } + if additional_metadata is not None: + metadata.update(additional_metadata) + fields["metadata"] = MetadataField(metadata) + + return Instance(fields) + + def _find_cls_index(self, tokens: List[Token]) -> int: + """ + From transformer_squad + Args: + self: + tokens: + + Returns: + + """ + return next(i for i, t in enumerate(tokens) if t.text == self._cls_token) diff --git a/test_fixtures/rc/record.json b/test_fixtures/rc/record.json new file mode 100644 index 000000000..8f2b8ab26 --- /dev/null +++ b/test_fixtures/rc/record.json @@ -0,0 +1,193 @@ +{ + "version": "1.0", + "data": [ + { + "id": "483577c837cdd4df5bbbbd5cfa3a77f6fea3519e", + "source": "Daily mail", + "passage": { + "text": "Tracy Morgan hasn't appeared on stage since the devastating New Jersey crash that nearly ended his life last summer, but all that will change this fall when he returns to host Saturday Night Live. NBC announced on Twitter Monday that Morgan, an SNL alum with seven seasons as a cast member under his belt, will headline the third episode of Season 41 airing October 17. 
For Morgan, 46, it will be a second time hosting the long-running variety show, the first since the June 2014 pileup on the New Jersey Turnpike that killed his friend and mentor James 'Jimmy Mack' McNair.\n@highlight\nMorgan, 46, will host third episode of season 41 of SNL airing October 17\n@highlight\nHe tweeted to his fans: 'Stoked to be going home...#SNL'\n@highlight\nFor the SNL alum who had spent seven years as cast member, it will be a second time hosting the show\n@highlight\nMorgan has been sidelined by severe head trauma suffered in deadly June 2014 crash on New Jersey Turnpike that killed his friend\n@highlight\nFirst episode of new SNL season will be hosted by Miley Cyrus, followed by Amy Schumer", + "entities": [ + { + "start": 0, + "end": 11 + }, + { + "start": 60, + "end": 69 + }, + { + "start": 185, + "end": 194 + }, + { + "start": 197, + "end": 199 + }, + { + "start": 214, + "end": 220 + }, + { + "start": 234, + "end": 239 + }, + { + "start": 245, + "end": 247 + }, + { + "start": 341, + "end": 349 + }, + { + "start": 374, + "end": 379 + }, + { + "start": 494, + "end": 512 + }, + { + "start": 548, + "end": 552 + }, + { + "start": 555, + "end": 564 + }, + { + "start": 567, + "end": 572 + }, + { + "start": 586, + "end": 591 + }, + { + "start": 638, + "end": 640 + }, + { + "start": 747, + "end": 749 + }, + { + "start": 851, + "end": 856 + }, + { + "start": 937, + "end": 955 + }, + { + "start": 1012, + "end": 1014 + }, + { + "start": 1041, + "end": 1051 + }, + { + "start": 1066, + "end": 1076 + } + ] + }, + "qas": [ + { + "id": "483577c837cdd4df5bbbbd5cfa3a77f6fea3519e-18db09b9e470ab13e21de901d213aff3db85d1e5-132", + "query": "On October 10, acclaimed comedian and star of the summer box office hit Trainwreck Amy Schumer will make her SNL debut, followed by @placeholder a week later.", + "answers": [ + { + "start": 0, + "end": 11, + "text": "Tracy Morgan" + }, + { + "start": 234, + "end": 239, + "text": "Morgan" + }, + { + "start": 374, + "end": 379, + "text": "Morgan" + }, + { + "start": 586, + "end": 591, + "text": "Morgan" + }, + { + "start": 851, + "end": 856, + "text": "Morgan" + } + ] + } + ] + }, + { + "id": "c1037ea3d376ca9e0371478795c24aaaa36f76be", + "source": "Daily mail", + "passage": { + "text": "For four years we have waited expectantly for the pitter patter of tiny paws. Soon, that wait could finally be over. Tian Tian, the UK's only female giant panda, has conceived and could give birth to a cub as early as August. However Edinburgh Zoo, where the pandas live, have warned people 'not to get too excited' as the process is 'extremely complex'. Moreover, on the two previous occasions keepers inseminated Tian Tian - whose name means 'Sweetie' - she has failed to produce a panda cub. 
She was artificially inseminated again in March this year, but keepers at the zoo say implantation - when a fertilised egg attaches to the uterus - has not yet occurred.\n@highlight\nTian Tian has conceived and could give birth to a cub as early as August\n@highlight\nShe has been inseminated twice before but so far failed to produce a cub\n@highlight\nTian Tian and Yang Guang arrived in 2011 from China to great fanfare\n@highlight\nOn loan at \u00a3600k a year, became first giant pandas to live in UK for 17 years", + "entities": [ + { + "start": 117, + "end": 125 + }, + { + "start": 132, + "end": 133 + }, + { + "start": 234, + "end": 246 + }, + { + "start": 415, + "end": 423 + }, + { + "start": 445, + "end": 451 + }, + { + "start": 676, + "end": 684 + }, + { + "start": 844, + "end": 852 + }, + { + "start": 858, + "end": 867 + }, + { + "start": 890, + "end": 894 + }, + { + "start": 986, + "end": 987 + } + ] + }, + "qas": [ + { + "id": "c1037ea3d376ca9e0371478795c24aaaa36f76be-f25ae34afd9730c53b1f4cb9546a09fc1e66de82-57", + "query": "Under the terms of the agreement any cubs will return to @placeholder at the age of two, the age at which they would normally leave their mother in the wild.", + "answers": [ + { + "start": 890, + "end": 894, + "text": "China" + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/rc/dataset_readers/record_reader_test.py b/tests/rc/dataset_readers/record_reader_test.py new file mode 100644 index 000000000..404fd2860 --- /dev/null +++ b/tests/rc/dataset_readers/record_reader_test.py @@ -0,0 +1,412 @@ +import pytest +from allennlp.data.tokenizers import WhitespaceTokenizer +from allennlp.data.token_indexers import SingleIdTokenIndexer +from tests import FIXTURES_ROOT +import re +from typing import List + +from allennlp_models.rc.dataset_readers.record_reader import RecordTaskReader + +""" +Tests for the ReCoRD reader from SuperGLUE +""" + + +# TODO: Add full integration tests + + +class TestRecordReader: + @pytest.fixture + def reader(self): + yield RecordTaskReader(length_limit=256) + + @pytest.fixture + def small_reader(self, reader): + # Some tests need the transformer tokenizer, but not the long lengths. + # Nice Middle ground. + reader._length_limit = 24 + reader._query_len_limit = 8 + reader._stride = 4 + return reader + + @pytest.fixture + def whitespace_reader(self, small_reader): + # Set the tokenizer to whitespace tokenization for ease of use and + # testing. Easier to test than using a transformer tokenizer. + small_reader._tokenizer = WhitespaceTokenizer() + small_reader._token_indexers = SingleIdTokenIndexer() + yield small_reader + + @pytest.fixture + def passage(self): + return ( + "Reading Comprehension with Commonsense Reasoning Dataset ( ReCoRD ) " + "is a large-scale reading comprehension dataset which requires " + "commonsense reasoning" + ) + + @pytest.fixture + def record_name_passage(self, passage): + """ + From the passage above, this is the snippet that contains the phrase + "Reading Comprehension with Commonsense Reasoning Dataset". The returned + object is a tuple with (start: int, end: int, text: str). 
+ """ + start = 0 + end = 56 + yield start, end, passage[start:end] + + @pytest.fixture + def tokenized_passage(self, passage): + tokenizer = WhitespaceTokenizer() + return tokenizer.tokenize(passage) + + @pytest.fixture + def answers(self): + return [ + {"start": 58, "end": 64, "text": "ReCoRD"}, + {"start": 128, "end": 149, "text": "commonsense reasoning"}, + {"start": 256, "end": 512, "text": "Should not exist"}, + ] + + @pytest.fixture + def example_basic(self): + + return { + "id": "dummy1", + "source": "ReCoRD docs", + "passage": { + "text": "ReCoRD contains 120,000+ queries from 70,000+ news articles. Each " + "query has been validated by crowdworkers. Unlike existing reading " + "comprehension datasets, ReCoRD contains a large portion of queries " + "requiring commonsense reasoning, thus presenting a good challenge " + "for future research to bridge the gap between human and machine " + "commonsense reading comprehension .", + "entities": [ + {"start": 0, "end": 6}, + {"start": 156, "end": 162}, + {"start": 250, "end": 264}, + ], + }, + "qas": [ + { + "id": "dummyA1", + "query": "@placeholder is a dataset", + "answers": [ + {"start": 0, "end": 6, "text": "ReCoRD"}, + {"start": 156, "end": 162, "text": "ReCoRD"}, + ], + }, + { + "id": "dummayA2", + "query": "ReCoRD presents a @placeholder with the commonsense reading " + "comprehension task", + "answers": [ + {"start": 250, "end": 264, "text": "good challenge"}, + ], + }, + ], + } + + @pytest.fixture + def curiosity_example(self): + """ + Bug where most examples did not return any instances, so doing + regression testing on this real example that did not return anything. + """ + return { + "id": "d978b083f3f97a2ab09771c72398cfbac094f818", + "source": "Daily mail", + "passage": { + "text": "By Sarah Griffiths PUBLISHED: 12:30 EST, 10 July 2013 | UPDATED: " + "12:37 EST, 10 July 2013 Nasa's next Mars rover has been given a " + "mission to find signs of past life and to collect and store rock " + "from the the red planet that will one day be sent back to Earth. It " + "will demonstrate technology for a human exploration of the planet " + "and look for signs of life. The space agency has revealed what the " + "rover, known as Mars 2020, will look like. Scroll down for video... " + "Nasa's next Mars rover (plans pictured) has been given a mission to " + "find signs of past life and to collect and store rock from the the " + "red planet that will one day be sent back to Earth. 
Mars 2020 will " + "also demonstrate technology for a human exploration of the " + "planet\n@highlight\nMars 2020 will collect up to 31 rock and soil " + "samples from the red planet and will look for signs of " + "extraterrestrial life\n@highlight\nThe new rover will use the same " + "landing system as Curiosity and share its frame, which has saved " + "Nasa $1 billion\n@highlight\nThe mission will bring the sapec agency " + "a step closer to meeting President Obama's challenge to send humans " + "to Mars in the next decade", + "entities": [ + {"start": 3, "end": 17}, + {"start": 89, "end": 92}, + {"start": 101, "end": 104}, + {"start": 252, "end": 256}, + {"start": 411, "end": 414}, + {"start": 463, "end": 466}, + {"start": 475, "end": 478}, + {"start": 643, "end": 647}, + {"start": 650, "end": 653}, + {"start": 742, "end": 745}, + {"start": 926, "end": 934}, + {"start": 973, "end": 976}, + {"start": 1075, "end": 1079}, + {"start": 1111, "end": 1114}, + ], + }, + "qas": [ + { + "id": "d978b083f3f97a2ab09771c72398cfbac094f818" + "-04b6e904611f0d706521db167a05a11bf693e40e-61", + "query": "The 2020 mission plans on building on the accomplishments of " + "@placeholder and other Mars missions.", + "answers": [{"start": 926, "end": 934, "text": "Curiosity"}], + } + ], + } + + @pytest.fixture + def skyfall_example(self): + """ + Another example that was not returning instances + """ + return { + "id": "6f1ca8baf24bf9e5fc8e33b4b3b04bd54370b25f", + "source": "Daily mail", + "passage": { + "text": "They're both famous singers who have lent their powerful voices to " + "James " + "Bond films. And it seems the Oscars' stage wasn't big enough to " + "accommodate larger-and-life divas Adele and Dame Shirley Bassey, " + "at least at the same time. Instead of the two songstresses dueting or " + "sharing the stage, each performed her theme song separately during " + "Sunday night's ceremony. Scroll down for video Battle of the divas: " + "Adele and Dame Shirley Bassey separately sang James Bond theme songs " + "during Sunday night's Oscar ceremony Shirley performed first, " + "singing Goldfinger nearly 50 years since she first recorded the song " + "for " + "the 1964 Bond film of the same name.\n@highlight\nAdele awarded Oscar " + "for Best Original Score for Skyfall", + "entities": [ + {"start": 67, "end": 76}, + {"start": 102, "end": 107}, + {"start": 171, "end": 175}, + {"start": 181, "end": 199}, + {"start": 407, "end": 411}, + {"start": 417, "end": 435}, + {"start": 453, "end": 462}, + {"start": 498, "end": 502}, + {"start": 513, "end": 519}, + {"start": 546, "end": 555}, + {"start": 620, "end": 623}, + {"start": 659, "end": 663}, + {"start": 673, "end": 701}, + {"start": 707, "end": 713}, + ], + }, + "qas": [ + { + "id": "6f1ca8baf24bf9e5fc8e33b4b3b04bd54370b25f" + "-98823006424cc595642b5ae5fa1b533bbd215a56-105", + "query": "The full works: Adele was accompanied by an orchestra, choir and " + "light display during her performance of @placeholder", + "answers": [{"start": 707, "end": 713, "text": "Skyfall"}], + } + ], + } + + @staticmethod + def _token_list_to_str(tokens) -> List[str]: + return list(map(str, tokens)) + + ##################################################################### + # Unittests # + ##################################################################### + def test_tokenize_slice_bos(self, whitespace_reader, passage, record_name_passage): + """ + Test `tokenize_slice` with a string that is at the beginning of the + text. This means that `start`=0. 
+ """ + result = list( + whitespace_reader.tokenize_slice( + passage, record_name_passage[0], record_name_passage[1] + ) + ) + + assert len(result) == 6 + + expected = ["Reading", "Comprehension", "with", "Commonsense", "Reasoning", "Dataset"] + for i in range(len(result)): + assert str(result[i]) == expected[i] + + def test_tokenize_slice_prefix(self, whitespace_reader, passage, record_name_passage): + result = list( + whitespace_reader.tokenize_slice( + passage, record_name_passage[0] + 8, record_name_passage[1] + ) + ) + + expected = ["Comprehension", "with", "Commonsense", "Reasoning", "Dataset"] + assert len(result) == len(expected) + + for i in range(len(result)): + assert str(result[i]) == expected[i] + + def test_tokenize_str(self, whitespace_reader, record_name_passage): + result = list(whitespace_reader.tokenize_str(record_name_passage[-1])) + expected = ["Reading", "Comprehension", "with", "Commonsense", "Reasoning", "Dataset"] + assert len(result) == len(expected) + + for i in range(len(result)): + assert str(result[i]) == expected[i] + + def test_get_instances_from_example(self, small_reader, tokenized_passage, example_basic): + # TODO: Make better + result = list(small_reader.get_instances_from_example(example_basic)) + + result_text = " ".join([t.text for t in result[0]["question_with_context"].tokens]) + assert len(result) == 2 + assert len(result[0]["question_with_context"].tokens) == small_reader._length_limit + assert "@" in result_text + assert "place" in result_text + assert "holder" in result_text + + result_text = " ".join([t.text for t in result[1]["question_with_context"].tokens]) + assert len(result[1]["question_with_context"]) == small_reader._length_limit + assert "@" in result_text + assert "place" in result_text + assert "holder" not in result_text + + def test_get_instances_from_example_fields( + self, small_reader, tokenized_passage, example_basic + ): + results = list(small_reader.get_instances_from_example(example_basic)) + expected_keys = [ + "question_with_context", + "context_span", + # "cls_index", + "answer_span", + "metadata", + ] + for i in range(len(results)): + assert len(results[i].fields) == len( + expected_keys + ), f"results[{i}] has incorrect number of fields" + for k in expected_keys: + assert k in results[i].fields, f"results[{i}] is missing {k}" + + ##################################################################### + # Regression Test # + ##################################################################### + + def test_get_instances_from_example_curiosity(self, reader, curiosity_example): + tokenized_answer = " ".join(map(str, reader.tokenize_str("Curiosity"))) + results = list(reader.get_instances_from_example(curiosity_example)) + assert len(results) == 2 + assert tokenized_answer in " ".join(map(str, results[0]["question_with_context"].tokens)) + assert tokenized_answer in " ".join(map(str, results[1]["question_with_context"].tokens)) + + # TODO: Make this its own test. + # Kind of forced this extra test in here because I added it while + # solving this bug, so just left it instead of creating another + # unittest. + reader._one_instance_per_query = True + results = list(reader.get_instances_from_example(curiosity_example)) + assert len(results) == 1 + assert tokenized_answer in " ".join(map(str, results[0]["question_with_context"].tokens)) + + def test_get_instances_from_example_skyfall(self, reader, skyfall_example): + """ + This will fail for the time being. 
+ """ + tokenized_answer = self._token_list_to_str(reader.tokenize_str("Skyfall")) + + results = list(reader.get_instances_from_example(skyfall_example)) + + assert len(results) == 1 + assert ( + self._token_list_to_str(results[0]["question_with_context"][-3:-1]) == tokenized_answer + ) + + def test_tokenize_str_roberta(self): + reader = RecordTaskReader(transformer_model_name="roberta-base", length_limit=256) + result = reader.tokenize_str("The new rover.") + result = list(map(lambda t: t.text[1:], result)) + assert len(result) == 4 + assert result == ["he", "new", "rover", ""] + + def test_read(self, small_reader): + instances = list(small_reader.read(FIXTURES_ROOT.joinpath("rc/record.json"))) + assert len(instances) == 2 + + tokens = self._token_list_to_str(instances[0].fields["question_with_context"]) + assert tokens == [ + "[CLS]", + "On", + "October", + "10", + ",", + "acclaimed", + "comedian", + "and", + "star", + "[SEP]", + "Tracy", + "Morgan", + "hasn", + "'", + "t", + "appeared", + "on", + "stage", + "since", + "the", + "devastating", + "New", + "Jersey", + "[SEP]", + ] + answer_span = instances[0].fields["answer_span"] + assert tokens[answer_span.span_start : answer_span.span_end + 1] == ["Tracy", "Morgan"] + + tokens = self._token_list_to_str(instances[1].fields["question_with_context"]) + assert tokens == [ + "[CLS]", + "Under", + "the", + "terms", + "of", + "the", + "agreement", + "any", + "cu", + "[SEP]", + "arrived", + "in", + "2011", + "from", + "China", + "to", + "great", + "fan", + "##fare", + "@", + "highlight", + "On", + "loan", + "[SEP]", + ] + answer_span = instances[1].fields["answer_span"] + assert tokens[answer_span.span_start : answer_span.span_end + 1] == ["China"] + + def test_to_params(self, small_reader): + assert small_reader.to_params() == { + "type": "superglue_record", + "transformer_model_name": "bert-base-cased", + "length_limit": 24, + "question_length_limit": 8, + "stride": 4, + "raise_errors": False, + "tokenizer_kwargs": {}, + "one_instance_per_query": False, + "max_instances": None, + }
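
Usage note (not part of the diff): the snippet below is a minimal sketch of how the new reader might be exercised end to end, mirroring `test_read` above. It assumes the repository's `test_fixtures/rc/record.json` fixture is available and that the `bert-base-cased` tokenizer can be loaded; the variable names are illustrative only.

    from allennlp_models.rc.dataset_readers.record_reader import RecordTaskReader

    # Defaults shown explicitly; `length_limit` caps query + context window + special tokens.
    reader = RecordTaskReader(
        transformer_model_name="bert-base-cased",
        length_limit=384,
        question_length_limit=64,
        stride=128,
    )

    # `read` streams one Instance per (query, context window) pair.
    for instance in reader.read("test_fixtures/rc/record.json"):
        tokens = instance["question_with_context"].tokens
        span = instance["answer_span"]
        print(len(tokens), tokens[span.span_start : span.span_end + 1])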
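
The overlapping-window logic in `get_instances_from_example` can be hard to follow from the loop alone, so here is a self-contained sketch of the same arithmetic with a hypothetical helper and made-up numbers (an illustration, not code from the reader): each window holds `space_for_context` tokens, and consecutive windows overlap by `stride` tokens.

    def context_windows(num_tokens: int, space_for_context: int, stride: int):
        """Yield (start, end) token offsets for overlapping context windows."""
        start = 0
        while True:
            end = min(start + space_for_context, num_tokens)
            yield start, end
            if end >= num_tokens:
                break
            # Advance by a full window, then step back by `stride`,
            # so consecutive windows share `stride` tokens.
            start += space_for_context - stride

    print(list(context_windows(num_tokens=30, space_for_context=12, stride=4)))
    # [(0, 12), (8, 20), (16, 28), (24, 30)]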