diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index c37c12ca2a8939..f25baf2cf1b6e5 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -1,13 +1,17 @@ from functools import partial, reduce -from typing import Callable, Dict, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union import transformers -from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available +from .. import PretrainedConfig, is_tf_available, is_torch_available from ..utils import logging from .config import OnnxConfig +if TYPE_CHECKING: + from transformers import PreTrainedModel, TFPreTrainedModel + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name if is_torch_available(): @@ -505,7 +509,7 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type: @staticmethod def get_model_from_feature( feature: str, model: str, framework: str = "pt", cache_dir: str = None - ) -> Union[PreTrainedModel, TFPreTrainedModel]: + ) -> Union["PreTrainedModel", "TFPreTrainedModel"]: """ Attempts to retrieve a model from a model's name and the feature to be enabled. @@ -533,7 +537,7 @@ def get_model_from_feature( @staticmethod def check_supported_model_or_raise( - model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default" + model: Union["PreTrainedModel", "TFPreTrainedModel"], feature: str = "default" ) -> Tuple[str, Callable]: """ Check whether or not the model has the requested features.