Merge pull request #260 from michaelfeil/better-bettertransformer-interface

improve bettertransformer interface
michaelfeil authored Jun 10, 2024
2 parents f7bcc60 + 3075c4b commit ce3e96c
Showing 6 changed files with 42 additions and 18 deletions.
4 changes: 4 additions & 0 deletions libs/infinity_emb/infinity_emb/infinity_server.py
@@ -159,6 +159,10 @@ async def _models():
),
capabilities=engine.capabilities,
backend=engine_args.engine.name,
embedding_dtype=engine_args.embedding_dtype.name,
dtype=engine_args.dtype.name,
revision=engine_args.revision,
lengths_via_tokenize=engine_args.lengths_via_tokenize,
device=engine_args.device.name,
)
)
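The four keyword arguments added above surface additional model metadata through the `/models` endpoint. Purely as an illustration, here is a sketch of one response entry after this change; the `id` field and all values are placeholders, not taken from the diff.

# Illustrative only: a hypothetical entry from GET /models after this commit.
# Field names follow the keyword arguments added above; values are placeholders.
example_model_entry = {
    "id": "some-org/some-embedding-model",  # placeholder model name
    "capabilities": ["embed"],
    "backend": "torch",
    "embedding_dtype": "float32",  # new: engine_args.embedding_dtype.name
    "dtype": "float16",  # new: engine_args.dtype.name
    "revision": None,  # new: engine_args.revision
    "lengths_via_tokenize": False,  # new: engine_args.lengths_via_tokenize
    "device": "cuda",
}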
18 changes: 17 additions & 1 deletion libs/infinity_emb/infinity_emb/transformer/acceleration.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING

from infinity_emb._optional_imports import CHECK_OPTIMUM
from infinity_emb.primitives import Device

if CHECK_OPTIMUM.is_available:
from optimum.bettertransformer import ( # type: ignore[import-untyped]
@@ -13,8 +14,23 @@

from transformers import PreTrainedModel # type: ignore[import-untyped]

from infinity_emb.args import EngineArgs


def to_bettertransformer(
model: "PreTrainedModel", engine_args: "EngineArgs", logger: "Logger"
):
if not engine_args.bettertransformer:
return model

if engine_args.device == Device.mps or (
hasattr(model, "device") and model.device.type == "mps"
):
logger.warning(
"BetterTransformer is not available for MPS device. Continue without bettertransformer modeling code."
)
return model

def to_bettertransformer(model: "PreTrainedModel", logger: "Logger"):
if os.environ.get("INFINITY_DISABLE_OPTIMUM", False): # OLD VAR
logger.warning(
"DEPRECATED `INFINITY_DISABLE_OPTIMUM` - setting optimizations via BetterTransformer,"
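For context, a minimal sketch of the consolidated helper after this commit, reconstructed from the hunk above. The direct BetterTransformer import, the return type annotation, and the try/except tail are assumptions about code not shown in the diff; the real module guards the import behind CHECK_OPTIMUM and still honors the deprecated INFINITY_DISABLE_OPTIMUM variable.

from logging import Logger
from typing import TYPE_CHECKING

from optimum.bettertransformer import BetterTransformer  # optional dependency

from infinity_emb.primitives import Device

if TYPE_CHECKING:
    from infinity_emb.args import EngineArgs
    from transformers import PreTrainedModel


def to_bettertransformer(
    model: "PreTrainedModel", engine_args: "EngineArgs", logger: Logger
) -> "PreTrainedModel":
    # The feature flag and the MPS guard now live here instead of in every caller.
    if not engine_args.bettertransformer:
        return model
    if engine_args.device == Device.mps or (
        hasattr(model, "device") and model.device.type == "mps"
    ):
        logger.warning(
            "BetterTransformer is not available for MPS device. "
            "Continue without bettertransformer modeling code."
        )
        return model
    # Assumption: the unchanged tail of the function performs the conversion and
    # falls back gracefully for unsupported architectures.
    try:
        model = BetterTransformer.transform(model)
    except NotImplementedError as ex:
        logger.warning(f"BetterTransformer is not available for this model: {ex}")
    return model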
11 changes: 5 additions & 6 deletions libs/infinity_emb/infinity_emb/transformer/classifier/torch.py
@@ -1,7 +1,6 @@
from infinity_emb._optional_imports import CHECK_TRANSFORMERS
from infinity_emb.args import EngineArgs
from infinity_emb.log_handler import logger
from infinity_emb.primitives import Device
from infinity_emb.transformer.abstract import BaseClassifer
from infinity_emb.transformer.acceleration import to_bettertransformer

@@ -31,11 +30,11 @@ def __init__(
if self._pipe.device.type != "cpu": # and engine_args.dtype == "float16":
self._pipe.model = self._pipe.model.half()

if not (engine_args.device == Device.mps or not engine_args.bettertransformer):
self._pipe.model = to_bettertransformer(
self._pipe.model,
logger,
)
self._pipe.model = to_bettertransformer(
self._pipe.model,
engine_args,
logger,
)

self._infinity_tokenizer = AutoTokenizer.from_pretrained(
engine_args.model_name_or_path,
10 changes: 5 additions & 5 deletions libs/infinity_emb/infinity_emb/transformer/crossencoder/torch.py
@@ -55,11 +55,11 @@ def __init__(self, *, engine_args: EngineArgs):
self._infinity_tokenizer = copy.deepcopy(self.tokenizer)
self.model.eval() # type: ignore

if not (self._target_device.type == "mps" or not engine_args.bettertransformer):
self.model = to_bettertransformer(
self.model, # type: ignore
logger,
)
self.model = to_bettertransformer(
self.model, # type: ignore
engine_args,
logger,
)

if self._target_device.type == "cuda" and engine_args.dtype in [
Dtype.auto,
@@ -68,11 +68,11 @@ def __init__(self, *, engine_args=EngineArgs):
self.eval()
self.engine_args = engine_args

if not (self.device.type == "mps" or not engine_args.bettertransformer):
fm.auto_model = to_bettertransformer(
fm.auto_model,
logger,
)
fm.auto_model = to_bettertransformer(
fm.auto_model,
engine_args,
logger,
)

if self.device.type == "cuda" and engine_args.dtype in [
Dtype.auto,
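The three caller diffs above all make the same change: the per-caller MPS / feature-flag guard is dropped and engine_args is passed straight to the helper, which now decides for itself. A small runnable sketch of the before/after call shape, using hypothetical stand-ins (ToyArgs, ToyModel, convert) rather than the real infinity_emb classes:

from dataclasses import dataclass
from logging import Logger, getLogger


@dataclass
class ToyArgs:
    bettertransformer: bool = True
    device: str = "cpu"


@dataclass
class ToyModel:
    converted: bool = False


def convert(model: ToyModel, engine_args: ToyArgs, logger: Logger) -> ToyModel:
    # Stand-in for to_bettertransformer: the guards now live inside the helper.
    if not engine_args.bettertransformer or engine_args.device == "mps":
        return model
    logger.info("converting model (stubbed)")
    return ToyModel(converted=True)


logger = getLogger(__name__)

# Before this commit every caller wrapped the call in its own guard, e.g.:
#   if not (engine_args.device == Device.mps or not engine_args.bettertransformer):
#       self.model = to_bettertransformer(self.model, logger)
# After this commit the call is unconditional and engine_args is passed through:
model = convert(ToyModel(), ToyArgs(), logger)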
@@ -24,16 +24,21 @@ class ClipLikeModel(BaseClipVisionModel):
def __init__(self, *, engine_args: EngineArgs):
CHECK_TORCH.mark_required()
CHECK_TRANSFORMERS.mark_required()

self.model = AutoModel.from_pretrained(
engine_args.model_name_or_path,
revision=engine_args.revision,
trust_remote_code=engine_args.trust_remote_code,
# attn_implementation="eager" if engine_args.bettertransformer else None,
)
if torch.cuda.is_available():
self.model = self.model.cuda()
if engine_args.dtype in (Dtype.float16, Dtype.auto):
self.model = self.model.half()
# self.model = to_bettertransformer(
# self.model,
# engine_args,
# logger,
# )
self.processor = AutoProcessor.from_pretrained(
engine_args.model_name_or_path,
revision=engine_args.revision,
