"""
Utilities for reading comprehension dataset readers.
"""
from collections import Counter, defaultdict
import logging
import string
from typing import Any, Dict, List, Tuple, Optional
from allennlp.data.fields import (
Field,
TextField,
IndexField,
MetadataField,
LabelField,
ListField,
SequenceLabelField,
)
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
logger = logging.getLogger(__name__)
# These are tokens and characters that are stripped by the standard SQuAD and TriviaQA evaluation
# scripts.
IGNORED_TOKENS = {"a", "an", "the"}
STRIPPED_CHARACTERS = string.punctuation + "".join(["‘", "’", "´", "`", "_"])
def normalize_text(text: str) -> str:
"""
Performs a normalization that is very similar to that done by the normalization functions in
SQuAD and TriviaQA.
This involves splitting and rejoining the text, and could be a somewhat expensive operation.
"""
return " ".join(
[
token
for token in text.lower().strip(STRIPPED_CHARACTERS).split()
if token not in IGNORED_TOKENS
]
)
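

# Illustrative usage sketch (added for clarity, not part of the original module):
# normalize_text lowercases, strips punctuation only from the ends of the whole
# string, and drops articles, mirroring the SQuAD/TriviaQA evaluation scripts.
# The helper below is hypothetical and exists only to demonstrate the behaviour.
def _example_normalize_text() -> None:
    assert normalize_text("The Norman conquest!") == "norman conquest"
    # Punctuation inside the string is kept; only leading/trailing characters are stripped.
    assert normalize_text("'1854-1855'") == "1854-1855"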
def char_span_to_token_span(
token_offsets: List[Optional[Tuple[int, int]]], character_span: Tuple[int, int]
) -> Tuple[Tuple[int, int], bool]:
"""
Converts a character span from a passage into the corresponding token span in the tokenized
version of the passage. If you pass in a character span that does not correspond to complete
tokens in the tokenized version, we'll do our best, but the behavior is officially undefined.
We return an error flag in this case, and have some debug logging so you can figure out the
cause of this issue (in SQuAD, these are mostly either tokenization problems or annotation
problems; there's a fair amount of both).
The basic outline of this method is to find the token span that has the same offsets as the
input character span. If the tokenizer tokenized the passage correctly and has matching
offsets, this is easy. We try to be a little smart about cases where they don't match exactly,
but mostly just find the closest thing we can.
The returned ``(begin, end)`` indices are `inclusive` for both ``begin`` and ``end``.
So, for example, ``(2, 2)`` is the one word span beginning at token index 2, ``(3, 4)`` is the
two-word span beginning at token index 3, and so on.
Returns
-------
token_span : ``Tuple[int, int]``
`Inclusive` span start and end token indices that match as closely as possible to the input
character spans.
error : ``bool``
Whether the token spans match the input character spans exactly. If this is ``False``, it
means there was an error in either the tokenization or the annotated character span.
"""
# We have token offsets into the passage from the tokenizer; we _should_ be able to just find
# the tokens that have the same offsets as our span.
error = False
start_index = 0
while start_index < len(token_offsets) and (
token_offsets[start_index] is None or token_offsets[start_index][0] < character_span[0]
):
start_index += 1
if start_index >= len(token_offsets):
raise ValueError(f"Character span %r outside the range of the given tokens.")
# start_index should now be pointing at the span start index.
if token_offsets[start_index][0] > character_span[0]:
if start_index <= 0:
raise ValueError(f"Character span %r outside the range of the given tokens.")
# In this case, a tokenization or labeling issue made us go too far - the character span we're looking for
# actually starts in the previous token. We'll back up one. Note that this might have us starting at a None
# token.
logger.debug("Bad labelling or tokenization - start offset doesn't match")
start_index -= 1
if token_offsets[start_index] is None or token_offsets[start_index][0] != character_span[0]:
error = True
end_index = start_index
while end_index < len(token_offsets) and (
token_offsets[end_index] is None or token_offsets[end_index][1] < character_span[1]
):
end_index += 1
if end_index >= len(token_offsets):
raise ValueError(f"Character span %r outside the range of the given tokens.")
if end_index == start_index and token_offsets[end_index][1] > character_span[1]:
# Looks like there was a token that should have been split, like "1854-1855", where the
# answer is "1854". We can't do much in this case, except keep the answer as the whole
# token.
logger.debug("Bad tokenization - end offset doesn't match")
elif token_offsets[end_index][1] > character_span[1]:
# This is a case where the given answer span is more than one token, and the last token is
# cut off for some reason, like "split with Luckett and Rober", when the original passage
# said "split with Luckett and Roberson". In this case, we'll just keep the end index
# where it is, and assume the intent was to mark the whole token.
logger.debug("Bad labelling or tokenization - end offset doesn't match")
if token_offsets[end_index][1] != character_span[1]:
error = True
return (start_index, end_index), error
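

# Illustrative usage sketch (added for clarity, not part of the original module).
# The offsets below are hypothetical character offsets for the passage "The cat sat.";
# the returned token span is inclusive on both ends.
def _example_char_span_to_token_span() -> None:
    offsets = [(0, 3), (4, 7), (8, 11), (11, 12)]  # "The", "cat", "sat", "."
    # The character span (4, 7) covers exactly the token "cat".
    assert char_span_to_token_span(offsets, (4, 7)) == ((1, 1), False)
    # The character span (4, 11) covers "cat sat", i.e. tokens 1 through 2.
    assert char_span_to_token_span(offsets, (4, 11)) == ((1, 2), False)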
def find_valid_answer_spans(
passage_tokens: List[Token], answer_texts: List[str]
) -> List[Tuple[int, int]]:
"""
Finds a list of token spans in ``passage_tokens`` that match the given ``answer_texts``. This
tries to find all spans that would evaluate to correct given the SQuAD and TriviaQA official
evaluation scripts, which do some normalization of the input text.
Note that this could return duplicate spans! The caller is expected to be able to handle
possible duplicates (as already happens in the SQuAD dev set, for instance).
"""
normalized_tokens = [token.text.lower().strip(STRIPPED_CHARACTERS) for token in passage_tokens]
# Because there could be many `answer_texts`, we'll do the most expensive pre-processing
# step once. This gives us a map from tokens to the position in the passage they appear.
word_positions: Dict[str, List[int]] = defaultdict(list)
for i, token in enumerate(normalized_tokens):
word_positions[token].append(i)
spans = []
for answer_text in answer_texts:
# For each answer, we'll first find all valid start positions in the passage. Then
# we'll grow each span to the same length as the number of answer tokens, and see if we
# have a match. We're a little tricky as we grow the span, skipping words that are
# already pruned from the normalized answer text, and stopping early if we don't match.
answer_tokens = answer_text.lower().strip(STRIPPED_CHARACTERS).split()
num_answer_tokens = len(answer_tokens)
for span_start in word_positions[answer_tokens[0]]:
span_end = span_start # span_end is _inclusive_
answer_index = 1
while answer_index < num_answer_tokens and span_end + 1 < len(normalized_tokens):
token = normalized_tokens[span_end + 1]
if answer_tokens[answer_index] == token:
answer_index += 1
span_end += 1
elif token in IGNORED_TOKENS:
span_end += 1
else:
break
if num_answer_tokens == answer_index:
spans.append((span_start, span_end))
return spans
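

# Illustrative usage sketch (added for clarity, not part of the original module).
# The tokens and answer text below are made up; the returned spans are inclusive
# token indices, and normalization lets the leading "the" still match.
def _example_find_valid_answer_spans() -> None:
    passage_tokens = [
        Token(text="The", idx=0),
        Token(text="Norman", idx=4),
        Token(text="conquest", idx=11),
        Token(text="of", idx=20),
        Token(text="England", idx=23),
    ]
    assert find_valid_answer_spans(passage_tokens, ["the Norman conquest"]) == [(0, 2)]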
def make_reading_comprehension_instance(
question_tokens: List[Token],
passage_tokens: List[Token],
token_indexers: Dict[str, TokenIndexer],
passage_text: str,
token_spans: List[Tuple[int, int]] = None,
answer_texts: List[str] = None,
additional_metadata: Dict[str, Any] = None,
) -> Instance:
"""
Converts a question, a passage, and an optional answer (or answers) to an ``Instance`` for use
in a reading comprehension model.
Creates an ``Instance`` with at least these fields: ``question`` and ``passage``, both
    ``TextFields``; and ``metadata``, a ``MetadataField``. Additionally, if ``token_spans`` is
    given, the ``Instance`` has ``span_start`` and ``span_end`` fields, which are both
    ``IndexFields``.
Parameters
----------
question_tokens : ``List[Token]``
An already-tokenized question.
passage_tokens : ``List[Token]``
An already-tokenized passage that contains the answer to the given question.
token_indexers : ``Dict[str, TokenIndexer]``
Determines how the question and passage ``TextFields`` will be converted into tensors that
get input to a model. See :class:`TokenIndexer`.
passage_text : ``str``
The original passage text. We need this so that we can recover the actual span from the
original passage that the model predicts as the answer to the question. This is used in
official evaluation scripts.
token_spans : ``List[Tuple[int, int]]``, optional
Indices into ``passage_tokens`` to use as the answer to the question for training. This is
a list because there might be several possible correct answer spans in the passage.
Currently, we just select the most frequent span in this list (i.e., SQuAD has multiple
annotations on the dev set; this will select the span that the most annotators gave as
correct).
answer_texts : ``List[str]``, optional
All valid answer strings for the given question. In SQuAD, e.g., the training set has
exactly one answer per question, but the dev and test sets have several. TriviaQA has many
possible answers, which are the aliases for the known correct entity. This is put into the
metadata for use with official evaluation scripts, but not used anywhere else.
additional_metadata : ``Dict[str, Any]``, optional
The constructed ``metadata`` field will by default contain ``original_passage``,
``token_offsets``, ``question_tokens``, ``passage_tokens``, and ``answer_texts`` keys. If
you want any other metadata to be associated with each instance, you can pass that in here.
This dictionary will get added to the ``metadata`` dictionary we already construct.
"""
additional_metadata = additional_metadata or {}
fields: Dict[str, Field] = {}
passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
# This is separate so we can reference it later with a known type.
passage_field = TextField(passage_tokens, token_indexers)
fields["passage"] = passage_field
fields["question"] = TextField(question_tokens, token_indexers)
metadata = {
"original_passage": passage_text,
"token_offsets": passage_offsets,
"question_tokens": [token.text for token in question_tokens],
"passage_tokens": [token.text for token in passage_tokens],
}
if answer_texts:
metadata["answer_texts"] = answer_texts
if token_spans:
# There may be multiple answer annotations, so we pick the one that occurs the most. This
# only matters on the SQuAD dev set, and it means our computed metrics ("start_acc",
# "end_acc", and "span_acc") aren't quite the same as the official metrics, which look at
# all of the annotations. This is why we have a separate official SQuAD metric calculation
# (the "em" and "f1" metrics use the official script).
candidate_answers: Counter = Counter()
for span_start, span_end in token_spans:
candidate_answers[(span_start, span_end)] += 1
span_start, span_end = candidate_answers.most_common(1)[0][0]
fields["span_start"] = IndexField(span_start, passage_field)
fields["span_end"] = IndexField(span_end, passage_field)
metadata.update(additional_metadata)
fields["metadata"] = MetadataField(metadata)
return Instance(fields)
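

# Illustrative usage sketch (added for clarity, not part of the original module).
# SingleIdTokenIndexer is assumed to be available from allennlp; the passage, tokens,
# offsets, and spans below are made up. Token spans are inclusive indices into the passage.
def _example_make_reading_comprehension_instance() -> Instance:
    from allennlp.data.token_indexers import SingleIdTokenIndexer

    passage_text = "The Norman conquest of England"
    passage_tokens = [
        Token(text="The", idx=0),
        Token(text="Norman", idx=4),
        Token(text="conquest", idx=11),
        Token(text="of", idx=20),
        Token(text="England", idx=23),
    ]
    question_tokens = [Token(text="What", idx=0), Token(text="happened", idx=5), Token(text="?", idx=13)]
    # The resulting instance has "question", "passage", "span_start", "span_end", and
    # "metadata" fields; the span indices point at the most frequently annotated span.
    return make_reading_comprehension_instance(
        question_tokens,
        passage_tokens,
        {"tokens": SingleIdTokenIndexer()},
        passage_text,
        token_spans=[(0, 2)],
        answer_texts=["The Norman conquest"],
    )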
def make_reading_comprehension_instance_quac(
question_list_tokens: List[List[Token]],
passage_tokens: List[Token],
token_indexers: Dict[str, TokenIndexer],
passage_text: str,
token_span_lists: List[List[Tuple[int, int]]] = None,
yesno_list: List[int] = None,
followup_list: List[int] = None,
additional_metadata: Dict[str, Any] = None,
num_context_answers: int = 0,
) -> Instance:
"""
    Converts a dialog's questions, a passage, and optional answer spans to an ``Instance`` for
    use in a reading comprehension model.
    Creates an ``Instance`` with at least these fields: ``question``, a ``ListField`` of
    ``TextFields``; ``passage``, a ``TextField``; and ``metadata``, a ``MetadataField``.
    Additionally, if ``token_span_lists`` is given, the ``Instance`` has ``span_start`` and
    ``span_end`` fields, which are both ``ListFields`` of ``IndexFields``.
Parameters
----------
question_list_tokens : ``List[List[Token]]``
        An already-tokenized list of questions. Each dialog may contain multiple questions.
passage_tokens : ``List[Token]``
An already-tokenized passage that contains the answer to the given question.
token_indexers : ``Dict[str, TokenIndexer]``
Determines how the question and passage ``TextFields`` will be converted into tensors that
get input to a model. See :class:`TokenIndexer`.
passage_text : ``str``
The original passage text. We need this so that we can recover the actual span from the
original passage that the model predicts as the answer to the question. This is used in
official evaluation scripts.
token_span_lists : ``List[List[Tuple[int, int]]]``, optional
Indices into ``passage_tokens`` to use as the answer to the question for training. This is
        a list of lists: first because there are multiple questions per dialog, and second
because there might be several possible correct answer spans in the passage.
Currently, we just select the last span in this list (i.e., QuAC has multiple
annotations on the dev set; this will select the last span, which was given by the original annotator).
yesno_list : ``List[int]``
        List of the affirmation bit for each question-answer pair.
followup_list : ``List[int]``
        List of the continuation bit for each question-answer pair.
num_context_answers : ``int``, optional
How many answers to encode into the passage.
additional_metadata : ``Dict[str, Any]``, optional
        The constructed ``metadata`` field will by default contain ``original_passage``,
        ``token_offsets``, ``question_tokens``, and ``passage_tokens`` keys. If
you want any other metadata to be associated with each instance, you can pass that in here.
This dictionary will get added to the ``metadata`` dictionary we already construct.
"""
additional_metadata = additional_metadata or {}
fields: Dict[str, Field] = {}
passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
# This is separate so we can reference it later with a known type.
passage_field = TextField(passage_tokens, token_indexers)
fields["passage"] = passage_field
fields["question"] = ListField(
[TextField(q_tokens, token_indexers) for q_tokens in question_list_tokens]
)
metadata = {
"original_passage": passage_text,
"token_offsets": passage_offsets,
"question_tokens": [
[token.text for token in question_tokens] for question_tokens in question_list_tokens
],
"passage_tokens": [token.text for token in passage_tokens],
}
p1_answer_marker_list: List[Field] = []
p2_answer_marker_list: List[Field] = []
p3_answer_marker_list: List[Field] = []
def get_tag(i, i_name):
# Generate a tag to mark previous answer span in the passage.
return "<{0:d}_{1:s}>".format(i, i_name)
def mark_tag(span_start, span_end, passage_tags, prev_answer_distance):
        if span_start < 0 or span_end < 0:
            raise ValueError(
                "Previous {0:d}th answer span should have been updated!".format(
                    prev_answer_distance
                )
            )
# Modify "tags" to mark previous answer span.
if span_start == span_end:
passage_tags[prev_answer_distance][span_start] = get_tag(prev_answer_distance, "")
else:
passage_tags[prev_answer_distance][span_start] = get_tag(prev_answer_distance, "start")
passage_tags[prev_answer_distance][span_end] = get_tag(prev_answer_distance, "end")
for passage_index in range(span_start + 1, span_end):
passage_tags[prev_answer_distance][passage_index] = get_tag(
prev_answer_distance, "in"
)
if token_span_lists:
span_start_list: List[Field] = []
span_end_list: List[Field] = []
p1_span_start, p1_span_end, p2_span_start = -1, -1, -1
p2_span_end, p3_span_start, p3_span_end = -1, -1, -1
        # Loop over the answer annotations for each question in the dialog.
for question_index, answer_span_lists in enumerate(token_span_lists):
span_start, span_end = answer_span_lists[-1] # Last one is the original answer
span_start_list.append(IndexField(span_start, passage_field))
span_end_list.append(IndexField(span_end, passage_field))
prev_answer_marker_lists = [
["O"] * len(passage_tokens),
["O"] * len(passage_tokens),
["O"] * len(passage_tokens),
["O"] * len(passage_tokens),
]
if question_index > 0 and num_context_answers > 0:
mark_tag(p1_span_start, p1_span_end, prev_answer_marker_lists, 1)
if question_index > 1 and num_context_answers > 1:
mark_tag(p2_span_start, p2_span_end, prev_answer_marker_lists, 2)
if question_index > 2 and num_context_answers > 2:
mark_tag(p3_span_start, p3_span_end, prev_answer_marker_lists, 3)
p3_span_start = p2_span_start
p3_span_end = p2_span_end
p2_span_start = p1_span_start
p2_span_end = p1_span_end
p1_span_start = span_start
p1_span_end = span_end
if num_context_answers > 2:
p3_answer_marker_list.append(
SequenceLabelField(
prev_answer_marker_lists[3], passage_field, label_namespace="answer_tags"
)
)
if num_context_answers > 1:
p2_answer_marker_list.append(
SequenceLabelField(
prev_answer_marker_lists[2], passage_field, label_namespace="answer_tags"
)
)
if num_context_answers > 0:
p1_answer_marker_list.append(
SequenceLabelField(
prev_answer_marker_lists[1], passage_field, label_namespace="answer_tags"
)
)
fields["span_start"] = ListField(span_start_list)
fields["span_end"] = ListField(span_end_list)
if num_context_answers > 0:
fields["p1_answer_marker"] = ListField(p1_answer_marker_list)
if num_context_answers > 1:
fields["p2_answer_marker"] = ListField(p2_answer_marker_list)
if num_context_answers > 2:
fields["p3_answer_marker"] = ListField(p3_answer_marker_list)
fields["yesno_list"] = ListField(
[LabelField(yesno, label_namespace="yesno_labels") for yesno in yesno_list]
)
fields["followup_list"] = ListField(
[LabelField(followup, label_namespace="followup_labels") for followup in followup_list]
)
metadata.update(additional_metadata)
fields["metadata"] = MetadataField(metadata)
return Instance(fields)
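

# Note added for clarity (not part of the original module): when num_context_answers > 0,
# each previous answer span is encoded over the passage as a SequenceLabelField in the
# "answer_tags" namespace. A span from k questions back is marked with "<k_start>",
# "<k_in>", and "<k_end>" tags (or "<k_>" for a single-token span), with "O" elsewhere.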
def handle_cannot(reference_answers: List[str]) -> List[str]:
"""
Process a list of reference answers.
    If at least half of the reference answers are "CANNOTANSWER", treat "CANNOTANSWER" as the
    gold answer. Otherwise, return only the answers that are not "CANNOTANSWER".
"""
num_cannot = 0
num_spans = 0
for ref in reference_answers:
if ref == "CANNOTANSWER":
num_cannot += 1
else:
num_spans += 1
if num_cannot >= num_spans:
reference_answers = ["CANNOTANSWER"]
else:
reference_answers = [x for x in reference_answers if x != "CANNOTANSWER"]
return reference_answers
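

# Illustrative usage sketch (added for clarity, not part of the original module).
def _example_handle_cannot() -> None:
    # Half or more of the references are "CANNOTANSWER", so that becomes the gold answer.
    assert handle_cannot(["CANNOTANSWER", "in 1066", "CANNOTANSWER"]) == ["CANNOTANSWER"]
    # Otherwise the "CANNOTANSWER" references are simply dropped.
    assert handle_cannot(["in 1066", "1066", "CANNOTANSWER"]) == ["in 1066", "1066"]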
def split_token_by_delimiter(token: Token, delimiter: str) -> List[Token]:
split_tokens = []
char_offset = token.idx
for sub_str in token.text.split(delimiter):
if sub_str:
split_tokens.append(Token(text=sub_str, idx=char_offset))
char_offset += len(sub_str)
split_tokens.append(Token(text=delimiter, idx=char_offset))
char_offset += len(delimiter)
if split_tokens:
split_tokens.pop(-1)
char_offset -= len(delimiter)
return split_tokens
else:
return [token]
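

# Illustrative usage sketch (added for clarity, not part of the original module).
# The sub-tokens keep character offsets (idx) consistent with the original token.
def _example_split_token_by_delimiter() -> None:
    parts = split_token_by_delimiter(Token(text="1854-1855", idx=10), "-")
    assert [(t.text, t.idx) for t in parts] == [("1854", 10), ("-", 14), ("1855", 15)]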
def split_tokens_by_hyphen(tokens: List[Token]) -> List[Token]:
hyphens = ["-", "–", "~"]
new_tokens: List[Token] = []
for token in tokens:
if any(hyphen in token.text for hyphen in hyphens):
unsplit_tokens = [token]
split_tokens: List[Token] = []
for hyphen in hyphens:
for unsplit_token in unsplit_tokens:
if hyphen in token.text:
split_tokens += split_token_by_delimiter(unsplit_token, hyphen)
else:
split_tokens.append(unsplit_token)
unsplit_tokens, split_tokens = split_tokens, []
new_tokens += unsplit_tokens
else:
new_tokens.append(token)
return new_tokens
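

# Illustrative usage sketch (added for clarity, not part of the original module).
# Hyphenated tokens are split into sub-tokens around "-", "–", and "~"; other tokens
# pass through unchanged.
def _example_split_tokens_by_hyphen() -> None:
    tokens = [Token(text="1854-1855", idx=0), Token(text="war", idx=10)]
    new_tokens = split_tokens_by_hyphen(tokens)
    assert [(t.text, t.idx) for t in new_tokens] == [
        ("1854", 0),
        ("-", 4),
        ("1855", 5),
        ("war", 10),
    ]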