diff --git a/libs/infinity_emb/infinity_emb/infinity_server.py b/libs/infinity_emb/infinity_emb/infinity_server.py
index 37440ca5..12c0c3ad 100644
--- a/libs/infinity_emb/infinity_emb/infinity_server.py
+++ b/libs/infinity_emb/infinity_emb/infinity_server.py
@@ -159,6 +159,10 @@ async def _models():
                     ),
                     capabilities=engine.capabilities,
                     backend=engine_args.engine.name,
+                    embedding_dtype=engine_args.embedding_dtype.name,
+                    dtype=engine_args.dtype.name,
+                    revision=engine_args.revision,
+                    lengths_via_tokenize=engine_args.lengths_via_tokenize,
                     device=engine_args.device.name,
                 )
             )
diff --git a/libs/infinity_emb/infinity_emb/transformer/acceleration.py b/libs/infinity_emb/infinity_emb/transformer/acceleration.py
index 2ed8f009..c0b88b9e 100644
--- a/libs/infinity_emb/infinity_emb/transformer/acceleration.py
+++ b/libs/infinity_emb/infinity_emb/transformer/acceleration.py
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING
 
 from infinity_emb._optional_imports import CHECK_OPTIMUM
+from infinity_emb.primitives import Device
 
 if CHECK_OPTIMUM.is_available:
     from optimum.bettertransformer import (  # type: ignore[import-untyped]
@@ -13,8 +14,23 @@
 
     from transformers import PreTrainedModel  # type: ignore[import-untyped]
 
+    from infinity_emb.args import EngineArgs
+
+
+def to_bettertransformer(
+    model: "PreTrainedModel", engine_args: "EngineArgs", logger: "Logger"
+):
+    if not engine_args.bettertransformer:
+        return model
+
+    if engine_args.device == Device.mps or (
+        hasattr(model, "device") and model.device.type == "mps"
+    ):
+        logger.warning(
+            "BetterTransformer is not available for MPS device. Continue without bettertransformer modeling code."
+        )
+        return model
 
-def to_bettertransformer(model: "PreTrainedModel", logger: "Logger"):
     if os.environ.get("INFINITY_DISABLE_OPTIMUM", False):  # OLD VAR
         logger.warning(
             "DEPRECATED `INFINITY_DISABLE_OPTIMUM` - setting optimizations via BetterTransformer,"
diff --git a/libs/infinity_emb/infinity_emb/transformer/classifier/torch.py b/libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
index 8dc1b433..f5c96da6 100644
--- a/libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
+++ b/libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
@@ -1,7 +1,6 @@
 from infinity_emb._optional_imports import CHECK_TRANSFORMERS
 from infinity_emb.args import EngineArgs
 from infinity_emb.log_handler import logger
-from infinity_emb.primitives import Device
 from infinity_emb.transformer.abstract import BaseClassifer
 from infinity_emb.transformer.acceleration import to_bettertransformer
 
@@ -31,11 +30,11 @@ def __init__(
         if self._pipe.device.type != "cpu":  # and engine_args.dtype == "float16":
             self._pipe.model = self._pipe.model.half()
 
-        if not (engine_args.device == Device.mps or not engine_args.bettertransformer):
-            self._pipe.model = to_bettertransformer(
-                self._pipe.model,
-                logger,
-            )
+        self._pipe.model = to_bettertransformer(
+            self._pipe.model,
+            engine_args,
+            logger,
+        )
 
         self._infinity_tokenizer = AutoTokenizer.from_pretrained(
             engine_args.model_name_or_path,
diff --git a/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py b/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
index 49e9de26..e3cc9d53 100644
--- a/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
+++ b/libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
@@ -55,11 +55,11 @@ def __init__(self, *, engine_args: EngineArgs):
         self._infinity_tokenizer = copy.deepcopy(self.tokenizer)
         self.model.eval()  # type: ignore
 
-        if not (self._target_device.type == "mps" or not engine_args.bettertransformer):
-            self.model = to_bettertransformer(
-                self.model,  # type: ignore
-                logger,
-            )
+        self.model = to_bettertransformer(
+            self.model,  # type: ignore
+            engine_args,
+            logger,
+        )
 
         if self._target_device.type == "cuda" and engine_args.dtype in [
             Dtype.auto,
diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
index c6785351..e28047d1 100644
--- a/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
+++ b/libs/infinity_emb/infinity_emb/transformer/embedder/sentence_transformer.py
@@ -68,11 +68,11 @@ def __init__(self, *, engine_args=EngineArgs):
         self.eval()
         self.engine_args = engine_args
 
-        if not (self.device.type == "mps" or not engine_args.bettertransformer):
-            fm.auto_model = to_bettertransformer(
-                fm.auto_model,
-                logger,
-            )
+        fm.auto_model = to_bettertransformer(
+            fm.auto_model,
+            engine_args,
+            logger,
+        )
 
         if self.device.type == "cuda" and engine_args.dtype in [
             Dtype.auto,
diff --git a/libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py b/libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py
index 6c372713..57f51ac5 100644
--- a/libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py
+++ b/libs/infinity_emb/infinity_emb/transformer/vision/torch_vision.py
@@ -24,16 +24,21 @@ class ClipLikeModel(BaseClipVisionModel):
     def __init__(self, *, engine_args: EngineArgs):
         CHECK_TORCH.mark_required()
         CHECK_TRANSFORMERS.mark_required()
-
         self.model = AutoModel.from_pretrained(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
             trust_remote_code=engine_args.trust_remote_code,
+            # attn_implementation="eager" if engine_args.bettertransformer else None,
         )
         if torch.cuda.is_available():
             self.model = self.model.cuda()
             if engine_args.dtype in (Dtype.float16, Dtype.auto):
                 self.model = self.model.half()
+        # self.model = to_bettertransformer(
+        #     self.model,
+        #     engine_args,
+        #     logger,
+        # )
         self.processor = AutoProcessor.from_pretrained(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
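
Note, not part of the patch: the refactor moves the MPS and disabled-bettertransformer guards out of each backend and into to_bettertransformer itself, so callers just pass engine_args and let the helper decide whether to convert the model. A minimal sketch of the resulting call pattern under that assumption; the model name and the EngineArgs keyword arguments shown here are illustrative, not taken from the patch.

    from infinity_emb.args import EngineArgs
    from infinity_emb.log_handler import logger
    from infinity_emb.primitives import Device
    from infinity_emb.transformer.acceleration import to_bettertransformer
    from transformers import AutoModel

    # example/assumed configuration; the field names match those referenced above
    engine_args = EngineArgs(
        model_name_or_path="BAAI/bge-small-en-v1.5",  # example model
        device=Device.cpu,
        bettertransformer=True,
    )
    model = AutoModel.from_pretrained(engine_args.model_name_or_path)

    # the helper now returns the model unchanged when bettertransformer is
    # disabled or when running on MPS, so callers no longer repeat that check
    model = to_bettertransformer(model, engine_args, logger)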