From d9ee1d67fefc1ae0d97f95598b496a28c4ed4f45 Mon Sep 17 00:00:00 2001 From: calpt <36051308+calpt@users.noreply.github.com> Date: Tue, 10 Nov 2020 12:44:16 +0100 Subject: [PATCH] Support XModelWithHeads in pipelines --- src/transformers/pipelines.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/pipelines.py b/src/transformers/pipelines.py index 45331fe4e9..8757532615 100755 --- a/src/transformers/pipelines.py +++ b/src/transformers/pipelines.py @@ -29,6 +29,8 @@ import numpy as np +from transformers.modeling_auto import MODEL_WITH_HEADS_MAPPING + from .configuration_auto import AutoConfig from .configuration_utils import PretrainedConfig from .data import SquadExample, squad_convert_examples_to_features @@ -614,6 +616,8 @@ def check_model_type(self, supported_models: Union[List[str], dict]): """ if not isinstance(supported_models, list): # Create from a model mapping supported_models = [item[1].__name__ for item in supported_models.items()] + for item in MODEL_WITH_HEADS_MAPPING.values(): + supported_models.append(item.__name__) if self.model.__class__.__name__ not in supported_models: raise PipelineException( self.task,