diff --git a/CHANGELOG.md b/CHANGELOG.md
index f988fbe2b..d53d235d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed bug in `experiment_from_huggingface.jsonnet` and `experiment.jsonnet` by changing `min_count` to have key `labels` instead of `answers`. Resolves failure of model checks that involve calling `_extend` in `vocabulary.py`
 - `TransformerQA` now outputs span probabilities as well as scores.
+- `TransformerQAPredictor` now implements `predictions_to_labeled_instances`, which is required for the interpret module.
 
 ### Added
 
diff --git a/allennlp_models/rc/predictors/transformer_qa.py b/allennlp_models/rc/predictors/transformer_qa.py
index 2f18802e6..e98e9a490 100644
--- a/allennlp_models/rc/predictors/transformer_qa.py
+++ b/allennlp_models/rc/predictors/transformer_qa.py
@@ -1,12 +1,18 @@
-from typing import List, Dict, Any
+import logging
+from typing import List, Dict, Any, Optional
 
-from allennlp.models import Model
 from overrides import overrides
+import numpy
 
+from allennlp.models import Model
 from allennlp.common.util import JsonDict, sanitize
 from allennlp.data import Instance, DatasetReader
 from allennlp.predictors.predictor import Predictor
+from allennlp.data.fields import SpanField
+
+logger = logging.getLogger(__name__)
+
 
 @Predictor.register("transformer_qa")
 class TransformerQAPredictor(Predictor):
@@ -47,23 +53,50 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
         assert len(results) == 1
         return results[0]
 
+    @overrides
+    def predictions_to_labeled_instances(
+        self, instance: Instance, outputs: Dict[str, numpy.ndarray]
+    ) -> List[Instance]:
+        new_instance = instance.duplicate()
+        span_start = int(outputs["best_span"][0])
+        span_end = int(outputs["best_span"][1])
+
+        start_of_context = (
+            len(self._dataset_reader._tokenizer.sequence_pair_start_tokens)
+            + len(instance["metadata"]["question_tokens"])
+            + len(self._dataset_reader._tokenizer.sequence_pair_mid_tokens)
+        )
+
+        answer_span = SpanField(
+            start_of_context + span_start,
+            start_of_context + span_end,
+            instance["question_with_context"],
+        )
+        new_instance.add_field("answer_span", answer_span)
+        return [new_instance]
+
     @overrides
     def _json_to_instance(self, json_dict: JsonDict) -> Instance:
-        raise NotImplementedError(
-            "This predictor maps a question to multiple instances. "
-            "Please use _json_to_instances instead."
+        logger.warning(
+            "This method is implemented only for use in interpret modules. "
+            "The predictor maps a question to multiple instances. "
+            "Please use _json_to_instances instead for all non-interpret uses."
         )
+        return self._json_to_instances(json_dict, qid=-1)[0]
 
-    def _json_to_instances(self, json_dict: JsonDict) -> List[Instance]:
+    def _json_to_instances(self, json_dict: JsonDict, qid: Optional[int] = None) -> List[Instance]:
         # We allow the passage / context to be specified with either key.
         # But we do it this way so that a 'KeyError: context' exception will be raised
         # when neither key is specified, since the 'context' key is the default and
         # the 'passage' key was only added to be compatible with the input for other
         # RC models.
+        # If `qid` is `None`, a fresh id is taken from `self._next_qid`.
         context = json_dict["passage"] if "passage" in json_dict else json_dict["context"]
         result: List[Instance] = []
+        question_id = qid or self._next_qid
+
         for instance in self._dataset_reader.make_instances(
-            qid=str(self._next_qid),
+            qid=str(question_id),
             question=json_dict["question"],
             answers=[],
             context=context,
@@ -72,7 +105,8 @@ def _json_to_instances(self, json_dict: JsonDict) -> List[Instance]:
         ):
             self._dataset_reader.apply_token_indexers(instance)
             result.append(instance)
-        self._next_qid += 1
+        if qid is None:
+            self._next_qid += 1
         return result
 
     @overrides
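For reference, a minimal usage sketch of what this change enables: AllenNLP's saliency interpreters call `Predictor.json_to_labeled_instances`, which relies on `_json_to_instance` and `predictions_to_labeled_instances`, so interpret calls like the one below become possible for `TransformerQAPredictor`. The model archive path is a placeholder, not an artifact of this PR.

# Illustrative sketch only; assumes a trained TransformerQA archive is available
# locally (the path below is a placeholder).
from allennlp.predictors import Predictor
from allennlp.interpret.saliency_interpreters import SimpleGradient

import allennlp_models.rc  # noqa: F401 -- registers the "transformer_qa" predictor

predictor = Predictor.from_path(
    "/path/to/transformer-qa.tar.gz", predictor_name="transformer_qa"
)

# SimpleGradient calls predictor.json_to_labeled_instances(), which uses the
# _json_to_instance() and predictions_to_labeled_instances() methods added above.
interpreter = SimpleGradient(predictor)
saliency = interpreter.saliency_interpret_from_json(
    {
        "question": "Who wrote Pride and Prejudice?",
        "context": "Pride and Prejudice was written by Jane Austen.",
    }
)
print(saliency)  # gradient-based saliency scores per input token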