Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Making TransformerQAPredictor compatible with interpret modules #249

Merged
merged 4 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed bug in `experiment_from_huggingface.jsonnet` and `experiment.jsonnet` by changing `min_count` to have key `labels` instead of `answers`. Resolves failure of model checks that involve calling `_extend` in `vocabulary.py`
- `TransformerQA` now outputs span probabilities as well as scores.
- `TransformerQAPredictor` now implements `predictions_to_labeled_instances`, which is required for the interpret module.

### Added

Expand Down
50 changes: 42 additions & 8 deletions allennlp_models/rc/predictors/transformer_qa.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from typing import List, Dict, Any
import logging
from typing import List, Dict, Any, Optional

from allennlp.models import Model
from overrides import overrides
import numpy

from allennlp.models import Model
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import Instance, DatasetReader
from allennlp.predictors.predictor import Predictor

from allennlp.data.fields import SpanField

logger = logging.getLogger(__name__)


@Predictor.register("transformer_qa")
class TransformerQAPredictor(Predictor):
Expand Down Expand Up @@ -47,23 +53,50 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
assert len(results) == 1
return results[0]

@overrides
def predictions_to_labeled_instances(
    self, instance: Instance, outputs: Dict[str, numpy.ndarray]
) -> List[Instance]:
    """
    Build a labeled copy of `instance` from the span in `outputs["best_span"]`.

    The first two entries of `outputs["best_span"]` are read as integer start
    and end indices. Both are shifted by the combined length of the tokenizer's
    sequence-pair start tokens, the question tokens stored in the instance
    metadata, and the tokenizer's sequence-pair mid tokens (presumably the
    position at which the context begins inside `question_with_context`).
    The shifted span is added to a duplicate of `instance` as the
    `answer_span` field.

    Returns a one-element list containing that duplicate.
    """
    labeled = instance.duplicate()

    best_span = outputs["best_span"]
    first, last = int(best_span[0]), int(best_span[1])

    # Everything that sits in front of the context contributes to the shift.
    tokenizer = self._dataset_reader._tokenizer
    prefix_lengths = (
        len(tokenizer.sequence_pair_start_tokens),
        len(instance["metadata"]["question_tokens"]),
        len(tokenizer.sequence_pair_mid_tokens),
    )
    context_offset = sum(prefix_lengths)

    labeled.add_field(
        "answer_span",
        SpanField(
            context_offset + first,
            context_offset + last,
            instance["question_with_context"],
        ),
    )
    return [labeled]

@overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
raise NotImplementedError(
"This predictor maps a question to multiple instances. "
"Please use _json_to_instances instead."
logger.warning(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dirkgr I'm not sure if just adding a warning instead of an error actually breaks anything.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it breaks anything, but people don't read warnings.

"This method is implemented only for use in interpret modules."
"The predictor maps a question to multiple instances. "
"Please use _json_to_instances instead for all non-interpret uses. "
)
return self._json_to_instances(json_dict, qid=-1)[0]

def _json_to_instances(self, json_dict: JsonDict) -> List[Instance]:
def _json_to_instances(self, json_dict: JsonDict, qid: Optional[int] = None) -> List[Instance]:
# We allow the passage / context to be specified with either key.
# But we do it this way so that a 'KeyError: context' exception will be raised
# when neither key is specified, since the 'context' key is the default and
# the 'passage' key was only added to be compatible with the input for other
# RC models.
# if `qid` is `None`, it is updated using self._next_qid
context = json_dict["passage"] if "passage" in json_dict else json_dict["context"]
result: List[Instance] = []
question_id = qid or self._next_qid

for instance in self._dataset_reader.make_instances(
qid=str(self._next_qid),
qid=str(question_id),
question=json_dict["question"],
answers=[],
context=context,
Expand All @@ -72,7 +105,8 @@ def _json_to_instances(self, json_dict: JsonDict) -> List[Instance]:
):
self._dataset_reader.apply_token_indexers(instance)
result.append(instance)
self._next_qid += 1
if qid is None:
self._next_qid += 1
return result

@overrides
Expand Down