Skip to content

Commit

Permalink
Remove imports and use forward references in ONNX feature (#17926)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger authored Jun 29, 2022
1 parent 5cdfff5 commit 47b9165
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/transformers/onnx/features.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from functools import partial, reduce
from typing import Callable, Dict, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type, Union

import transformers

from .. import PretrainedConfig, PreTrainedModel, TFPreTrainedModel, is_tf_available, is_torch_available
from .. import PretrainedConfig, is_tf_available, is_torch_available
from ..utils import logging
from .config import OnnxConfig


if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel


logger = logging.get_logger(__name__) # pylint: disable=invalid-name

if is_torch_available():
Expand Down Expand Up @@ -505,7 +509,7 @@ def get_model_class_for_feature(feature: str, framework: str = "pt") -> Type:
@staticmethod
def get_model_from_feature(
feature: str, model: str, framework: str = "pt", cache_dir: str = None
) -> Union[PreTrainedModel, TFPreTrainedModel]:
) -> Union["PreTrainedModel", "TFPreTrainedModel"]:
"""
Attempts to retrieve a model from a model's name and the feature to be enabled.
Expand Down Expand Up @@ -533,7 +537,7 @@ def get_model_from_feature(

@staticmethod
def check_supported_model_or_raise(
model: Union[PreTrainedModel, TFPreTrainedModel], feature: str = "default"
model: Union["PreTrainedModel", "TFPreTrainedModel"], feature: str = "default"
) -> Tuple[str, Callable]:
"""
Check whether or not the model has the requested features.
Expand Down

0 comments on commit 47b9165

Please sign in to comment.