
Merge pull request #20 from michaelfeil/add-optimum
add optimum dependencies
michaelfeil authored Oct 30, 2023
2 parents 180de84 + b7a4d0e commit 3360b6e
Showing 9 changed files with 1,463 additions and 565 deletions.
9 changes: 6 additions & 3 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/convert.py
@@ -1,10 +1,13 @@
-from infinity_emb.fastapi_schemas.pymodels import OpenAIEmbeddingResult
+from typing import Any, Dict, Iterable, Union
 
 from infinity_emb.inference.primitives import NpEmbeddingType
 
 
 def list_embeddings_to_response(
-    embeddings: NpEmbeddingType, model: str, usage: int
-) -> OpenAIEmbeddingResult:
+    embeddings: Union[NpEmbeddingType, Iterable[NpEmbeddingType]],
+    model: str,
+    usage: int,
+) -> Dict[str, Any]:
     return dict(
         model=model,
         data=[
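Note: list_embeddings_to_response now returns a plain dict in the OpenAI embeddings shape instead of constructing an OpenAIEmbeddingResult. A minimal usage sketch; the model name and the per-item fields in data are illustrative, since the list comprehension is cut off in this hunk:

import numpy as np

from infinity_emb.fastapi_schemas.convert import list_embeddings_to_response

embeddings = [np.random.rand(384).astype(np.float32) for _ in range(2)]
resp = list_embeddings_to_response(embeddings, model="my-embedding-model", usage=7)
# resp is a dict roughly of the form:
# {"model": "my-embedding-model",
#  "data": [{"object": "embedding", "embedding": [...], "index": 0}, ...],
#  "usage": {...}}
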
17 changes: 9 additions & 8 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -1,5 +1,6 @@
 import asyncio
 import bisect
+import os
 import queue
 import random
 import threading
@@ -38,7 +39,7 @@ async def extend(self, items: List[PrioritizedQueueItem]):
 
         self._sync_event.set()
 
-    def pop_optimal_batch(
+    def pop_optimal_batches(
         self, size: int, timeout=0.2, latest_first=False
     ) -> Union[List[List[EmbeddingResult]], None]:
         """
@@ -97,7 +98,7 @@ async def extend(self, items: List[PrioritizedQueueItem]):
         with self._lock_queue_event:
             self._queue.extend(items)
 
-    def pop_optimal_batch(
+    def pop_optimal_batches(
         self, size: int, timeout=0.2, **kwargs
     ) -> Union[List[List[EmbeddingResult]], None]:
         """
@@ -119,7 +120,7 @@ def pop_optimal_batch(
             return None
 
         # slice as many batches as possible
-        n_batches = max(1, len(self._queue) // size)
+        n_batches = min(32, max(1, len(self._queue) // size))
        size_batches = size * n_batches
 
        with self._lock_queue_event:
@@ -133,11 +134,11 @@ def pop_optimal_batch(
            # optimal padding per batch
            new_items_l.sort()
 
-            new_items = []
+            new_items: List[List[EmbeddingResult]] = []
            for i in range(n_batches):
                mini_batch = new_items_l[size * i : size * (i + 1)]
-                mini_batch = [mi.item for mi in mini_batch]
-                new_items.append(mini_batch)
+                mini_batch_e: List[EmbeddingResult] = [mi.item for mi in mini_batch]
+                new_items.append(mini_batch_e)
                # runtime checks
                # assert 1 <= len(mini_batch) <= size
                # if i > 0:
@@ -170,7 +171,7 @@ def __init__(
         self,
         model: BaseTransformer,
         max_batch_size: int,
-        max_queue_wait: int = 32_000,
+        max_queue_wait: int = int(os.environ.get("INFINITY_QUEUE_SIZE", 32_000)),
         batch_delay: float = 5e-3,
         verbose=False,
     ) -> None:
@@ -273,7 +274,7 @@ def _preprocess_batch(self):
            # decision to attemp to pop a batch
            # -> will happen if a single datapoint is available
 
-            batches = self._queue_prio.pop_optimal_batch(
+            batches = self._queue_prio.pop_optimal_batches(
                self.max_batch_size, latest_first=False
            )
            if not batches:
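Note: pop_optimal_batch is renamed to pop_optimal_batches and now pops up to 32 mini-batches per call, and the queue limit max_queue_wait can be overridden through the INFINITY_QUEUE_SIZE environment variable. A simplified, standalone sketch of the slicing idea (the lock, timeout handling, and the PrioritizedQueueItem wrapper are omitted; the helper name is made up):

from typing import List, TypeVar

T = TypeVar("T")


def slice_batches(sorted_items: List[T], size: int) -> List[List[T]]:
    # Mirrors: n_batches = min(32, max(1, len(self._queue) // size))
    n_batches = min(32, max(1, len(sorted_items) // size))
    taken = sorted_items[: size * n_batches]
    return [taken[size * i : size * (i + 1)] for i in range(n_batches)]
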
19 changes: 1 addition & 18 deletions libs/infinity_emb/infinity_emb/inference/primitives.py
@@ -6,18 +6,15 @@
 
 import numpy as np
 
-# from infinity_emb.inference.threading_asyncio import EventTS
-
 NpEmbeddingType = np.ndarray
 
 
 @dataclass(order=True)
 class EmbeddingResult:
     sentence: str = field(compare=False)
+    future: asyncio.Future = field(compare=False)
     uuid: str = field(default_factory=lambda: str(uuid4()), compare=False)
     embedding: Optional[NpEmbeddingType] = field(default=None, compare=False)
-    future: Optional[asyncio.Future] = field(default=None, compare=False)
-    # event: Optional[EventTS] = field(default=None, compare=False)
 
 
 @dataclass(order=True)
@@ -32,17 +29,3 @@ class OverloadStatus:
     queue_fraction: float
     queue_absolute: int
     results_absolute: int
-
-
-if __name__ == "__main__":
-    import bisect
-    from concurrent.futures import ThreadPoolExecutor
-
-    tp = ThreadPoolExecutor()
-    r1 = EmbeddingResult(5, "hello")
-    r2 = EmbeddingResult(6, "hello_")
-    r3 = EmbeddingResult(6, "hello_")
-    r1 < r2
-    l1 = []
-    bisect.insort(l1, r1)
-    bisect.insort(l1, r2)
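Note: future moves before uuid and loses its default, so it is now a required field of EmbeddingResult; call sites must pass it explicitly (see the select_model.py change below, which passes future=None). An illustrative construction inside an event loop:

import asyncio

from infinity_emb.inference.primitives import EmbeddingResult


async def main() -> None:
    loop = asyncio.get_running_loop()
    # future is now the second, non-defaulted field
    result = EmbeddingResult(sentence="hello world", future=loop.create_future())
    print(result.uuid)  # uuid is still generated automatically


asyncio.run(main())
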
5 changes: 3 additions & 2 deletions libs/infinity_emb/infinity_emb/inference/select_model.py
@@ -1,4 +1,5 @@
 from time import perf_counter
+from typing import Tuple
 
 from infinity_emb.inference.primitives import EmbeddingResult, NpEmbeddingType
 from infinity_emb.log_handler import logger
@@ -38,8 +39,8 @@ def select_model_to_functional(
 
 def runtime_check_callable(
     model: BaseTransformer, sample=["warmup"], log=True
-) -> float:
-    inp = [EmbeddingResult(sentence=s) for s in sample]  # type: ignore
+) -> Tuple[float, float]:
+    inp = [EmbeddingResult(sentence=s, future=None) for s in sample]  # type: ignore
     start = perf_counter()
     sentences = [item.sentence for item in inp]
     feat = model.encode_pre(sentences)
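Note: runtime_check_callable now returns a Tuple[float, float] instead of a single float, and builds its warm-up EmbeddingResult with an explicit future=None to match the new dataclass layout. What the second float measures is not visible in this hunk; below is only a hedged sketch of the timing pattern the shown lines use, with a hypothetical helper name:

from time import perf_counter


def time_encode_pre(model, sample=("warmup",)) -> float:
    # Hypothetical helper mirroring the measurement in runtime_check_callable:
    # pre-process the warm-up sample and report the elapsed wall-clock time.
    sentences = list(sample)
    start = perf_counter()
    model.encode_pre(sentences)
    return perf_counter() - start
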
35 changes: 32 additions & 3 deletions libs/infinity_emb/infinity_emb/transformer/sentence_transformer.py
@@ -11,6 +11,13 @@
 from infinity_emb.log_handler import logger
 from infinity_emb.transformer.abstract import BaseTransformer
 
+try:
+    from optimum.bettertransformer import BetterTransformer
+
+    OPTIMUM_AVAILABLE = True
+except ImportError:
+    OPTIMUM_AVAILABLE = False
+
 __all__ = [
     "SentenceTransformerPatched",
     "CT2SentenceTransformer",
@@ -23,12 +30,34 @@ class SentenceTransformerPatched(SentenceTransformer, BaseTransformer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         device = self._target_device
-        self.eval()
         self.to(device)
         # make a copy of the tokenizer,
         # to be able to could the tokens in another thread
         # without corrupting the original.
-        self._infinity_tokenizer = copy.deepcopy(self._first_module().tokenizer)
+        fm = self._first_module()
+        self._infinity_tokenizer = copy.deepcopy(fm.tokenizer)
+        if OPTIMUM_AVAILABLE and not os.environ.get("INFINITY_DISABLE_OPTIMUM", False):
+            logger.info(
+                "Adding optimizations via Huggingface optimum. "
+                "Disable by setting the env var `INFINITY_DISABLE_OPTIMUM`"
+            )
+            try:
+                fm.auto_model = BetterTransformer.transform(fm.auto_model)
+            except Exception as ex:
+                logger.exception(f"BetterTransformer failed with {ex}")
+                exit(1)
+        else:
+            logger.info("No optimizations via Huggingface optimum.")
+
+        self.eval()
+        if self._target_device.type == "cuda" and os.environ.get(
+            "INFINITY_TORCH_ENABLE_HALF", False
+        ):
+            logger.info(
+                "Switching to half() precision (fp16)."
+                "Enabled by the setting the env var `INFINITY_TORCH_ENABLE_HALF`"
+            )
+            self.half()
 
     def encode_pre(self, sentences) -> Dict[str, Tensor]:
         features = self.tokenize(sentences)
@@ -108,7 +137,7 @@ def __init__(
         compute_type="default",
         force=False,
         vmap: Union[str, None] = None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
         self[0] = CT2Transformer(
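Note: SentenceTransformerPatched now applies Huggingface optimum's BetterTransformer when the optional dependency is installed, and two environment variables gate the behaviour: INFINITY_DISABLE_OPTIMUM skips the BetterTransformer transform, and INFINITY_TORCH_ENABLE_HALF switches a CUDA model to fp16. A usage sketch; the model name is illustrative, and because the code only checks truthiness of os.environ.get, any non-empty value counts as "set":

import os

# Set before constructing the model; the values are read in __init__.
os.environ["INFINITY_DISABLE_OPTIMUM"] = "1"    # skip the BetterTransformer transform
os.environ["INFINITY_TORCH_ENABLE_HALF"] = "1"  # fp16 on CUDA devices

from infinity_emb.transformer.sentence_transformer import SentenceTransformerPatched

model = SentenceTransformerPatched("BAAI/bge-small-en-v1.5")  # illustrative model name
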
(4 more changed files not shown.)
