Commit

latest updates

michaelfeil committed Oct 30, 2023
1 parent a06af4e commit b7a4d0e
Showing 8 changed files with 49 additions and 63 deletions.
9 changes: 6 additions & 3 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/convert.py
@@ -1,10 +1,13 @@
-from infinity_emb.fastapi_schemas.pymodels import OpenAIEmbeddingResult
+from typing import Any, Dict, Iterable, Union
+
 from infinity_emb.inference.primitives import NpEmbeddingType


 def list_embeddings_to_response(
-    embeddings: NpEmbeddingType, model: str, usage: int
-) -> OpenAIEmbeddingResult:
+    embeddings: Union[NpEmbeddingType, Iterable[NpEmbeddingType]],
+    model: str,
+    usage: int,
+) -> Dict[str, Any]:
     return dict(
         model=model,
         data=[
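For orientation, a self-contained sketch of the new signature in use. The function body past `data=[` is folded in the hunk above, so the list comprehension and the usage dict below are an assumed completion in the OpenAI response style, not a verbatim copy of the file:

# Assumed completion of list_embeddings_to_response, for illustration only.
from typing import Any, Dict, Iterable, Union

import numpy as np

NpEmbeddingType = np.ndarray  # mirrors infinity_emb.inference.primitives


def list_embeddings_to_response(
    embeddings: Union[NpEmbeddingType, Iterable[NpEmbeddingType]],
    model: str,
    usage: int,
) -> Dict[str, Any]:
    # plain dicts instead of the pydantic OpenAIEmbeddingResult; FastAPI can
    # serialize either, and skipping model validation is cheaper per request
    return dict(
        model=model,
        data=[
            dict(object="embedding", embedding=emb, index=count)
            for count, emb in enumerate(embeddings)
        ],
        usage=dict(prompt_tokens=usage, total_tokens=usage),
    )


resp = list_embeddings_to_response([np.zeros(2), np.ones(2)], model="bert", usage=4)
assert resp["data"][1]["index"] == 1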
14 changes: 7 additions & 7 deletions libs/infinity_emb/infinity_emb/inference/batch_handler.py
@@ -39,7 +39,7 @@ async def extend(self, items: List[PrioritizedQueueItem]):

         self._sync_event.set()

-    def pop_optimal_batch(
+    def pop_optimal_batches(
         self, size: int, timeout=0.2, latest_first=False
     ) -> Union[List[List[EmbeddingResult]], None]:
         """
@@ -98,7 +98,7 @@ async def extend(self, items: List[PrioritizedQueueItem]):
         with self._lock_queue_event:
             self._queue.extend(items)

-    def pop_optimal_batch(
+    def pop_optimal_batches(
         self, size: int, timeout=0.2, **kwargs
     ) -> Union[List[List[EmbeddingResult]], None]:
         """
@@ -120,7 +120,7 @@ def pop_optimal_batch(
             return None

         # slice as many batches as possible
-        n_batches = max(1, len(self._queue) // size)
+        n_batches = min(32, max(1, len(self._queue) // size))
         size_batches = size * n_batches

         with self._lock_queue_event:
@@ -134,11 +134,11 @@ def pop_optimal_batch(
         # optimal padding per batch
         new_items_l.sort()

-        new_items = []
+        new_items: List[List[EmbeddingResult]] = []
         for i in range(n_batches):
             mini_batch = new_items_l[size * i : size * (i + 1)]
-            mini_batch = [mi.item for mi in mini_batch]
-            new_items.append(mini_batch)
+            mini_batch_e: List[EmbeddingResult] = [mi.item for mi in mini_batch]
+            new_items.append(mini_batch_e)
             # runtime checks
             # assert 1 <= len(mini_batch) <= size
             # if i > 0:
@@ -274,7 +274,7 @@ def _preprocess_batch(self):
             # decision to attemp to pop a batch
             # -> will happen if a single datapoint is available

-            batches = self._queue_prio.pop_optimal_batch(
+            batches = self._queue_prio.pop_optimal_batches(
                 self.max_batch_size, latest_first=False
             )
             if not batches:
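The renamed pop_optimal_batches now hands back up to 32 mini-batches of `size` items per call instead of a single batch. Below is a standalone sketch of just the slicing logic; the lock, timeout, futures and the PrioritizedQueueItem wrapper of the real class are omitted and the names are illustrative:

# Illustrative slicing: sort pending items so each mini-batch pads similarly,
# then cut at most 32 batches of `size` items out of the queue.
from typing import List


def pop_optimal_batches(queue: List[str], size: int) -> List[List[str]]:
    if not queue:
        return []
    n_batches = min(32, max(1, len(queue) // size))  # same cap as in the diff
    taken, queue[:] = queue[: size * n_batches], queue[size * n_batches :]
    taken.sort(key=len)  # similar lengths together means less padding per batch
    return [taken[size * i : size * (i + 1)] for i in range(n_batches)]


queue = [f"sentence {'x' * i}" for i in range(70)]
batches = pop_optimal_batches(queue, size=16)
print(len(batches), [len(b) for b in batches])  # 4 [16, 16, 16, 16]
print(len(queue))  # 6 items stay queued for the next call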
19 changes: 1 addition & 18 deletions libs/infinity_emb/infinity_emb/inference/primitives.py
@@ -6,18 +6,15 @@

 import numpy as np

-# from infinity_emb.inference.threading_asyncio import EventTS
-
 NpEmbeddingType = np.ndarray


 @dataclass(order=True)
 class EmbeddingResult:
     sentence: str = field(compare=False)
-    future: asyncio.Future = field(compare=False)
     uuid: str = field(default_factory=lambda: str(uuid4()), compare=False)
     embedding: Optional[NpEmbeddingType] = field(default=None, compare=False)
+    future: Optional[asyncio.Future] = field(default=None, compare=False)
-    # event: Optional[EventTS] = field(default=None, compare=False)


 @dataclass(order=True)
@@ -32,17 +29,3 @@ class OverloadStatus:
     queue_fraction: float
     queue_absolute: int
     results_absolute: int
-
-
-if __name__ == "__main__":
-    import bisect
-    from concurrent.futures import ThreadPoolExecutor
-
-    tp = ThreadPoolExecutor()
-    r1 = EmbeddingResult(5, "hello")
-    r2 = EmbeddingResult(6, "hello_")
-    r3 = EmbeddingResult(6, "hello_")
-    r1 < r2
-    l1 = []
-    bisect.insort(l1, r1)
-    bisect.insort(l1, r2)
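With `future` now optional and defaulted, an EmbeddingResult can be built without a running event loop, which the warmup call in select_model.py below takes advantage of. A self-contained sketch; the class body mirrors the diff, the usage lines are illustrative:

import asyncio
from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4

import numpy as np

NpEmbeddingType = np.ndarray


@dataclass(order=True)
class EmbeddingResult:
    sentence: str = field(compare=False)
    uuid: str = field(default_factory=lambda: str(uuid4()), compare=False)
    embedding: Optional[NpEmbeddingType] = field(default=None, compare=False)
    future: Optional[asyncio.Future] = field(default=None, compare=False)


r = EmbeddingResult(sentence="warmup")  # no asyncio.Future required anymore
assert r.future is None and r.embedding is None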
5 changes: 3 additions & 2 deletions libs/infinity_emb/infinity_emb/inference/select_model.py
@@ -1,4 +1,5 @@
 from time import perf_counter
+from typing import Tuple

 from infinity_emb.inference.primitives import EmbeddingResult, NpEmbeddingType
 from infinity_emb.log_handler import logger
@@ -38,8 +39,8 @@ def select_model_to_functional(

 def runtime_check_callable(
     model: BaseTransformer, sample=["warmup"], log=True
-) -> float:
-    inp = [EmbeddingResult(sentence=s) for s in sample]  # type: ignore
+) -> Tuple[float, float]:
+    inp = [EmbeddingResult(sentence=s, future=None) for s in sample]  # type: ignore
     start = perf_counter()
     sentences = [item.sentence for item in inp]
     feat = model.encode_pre(sentences)
22 changes: 7 additions & 15 deletions libs/infinity_emb/infinity_emb/transformer/sentence_transformer.py
@@ -30,7 +30,6 @@ class SentenceTransformerPatched(SentenceTransformer, BaseTransformer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         device = self._target_device
-        self.eval()
         self.to(device)
         # make a copy of the tokenizer,
         # to be able to could the tokens in another thread
@@ -50,16 +49,15 @@ def __init__(self, *args, **kwargs):
         else:
             logger.info("No optimizations via Huggingface optimum.")

-        if self._target_device.type == "cuda" and not os.environ.get(
-            "INFINITY_DISABLE_FLASH", False
+        self.eval()
+        if self._target_device.type == "cuda" and os.environ.get(
+            "INFINITY_TORCH_ENABLE_HALF", False
         ):
             logger.info(
-                "Adding flash_attention."
-                "Disable by setting the env var `INFINITY_DISABLE_FLASH`"
+                "Switching to half() precision (fp16)."
+                "Enabled by the setting the env var `INFINITY_TORCH_ENABLE_HALF`"
             )
-            self._use_flash_attn = True
-        else:
-            self._use_flash_attn = False
+            self.half()

     def encode_pre(self, sentences) -> Dict[str, Tensor]:
         features = self.tokenize(sentences)
@@ -74,13 +72,7 @@ def encode_core(self, features: Dict[str, Tensor]) -> Tensor:
         with torch.inference_mode():
             device = self._target_device
             features = util.batch_to_device(features, device)
-            if self._use_flash_attn:
-                with torch.backends.cuda.sdp_kernel(
-                    enable_flash=True, enable_math=True, enable_mem_efficient=True
-                ):
-                    out_features = self.forward(features)["sentence_embedding"]
-            else:
-                out_features = self.forward(features)["sentence_embedding"]
+            out_features = self.forward(features)["sentence_embedding"]

         return out_features

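The flash-attention toggle (INFINITY_DISABLE_FLASH) is gone; half precision is now an explicit opt-in. A minimal sketch of the same pattern outside the patched class, with the device check simplified to torch.cuda.is_available() and a stock checkpoint assumed:

# Opt-in fp16: only on CUDA and only when the env var is set.
import os

import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
model.eval()
if torch.cuda.is_available() and os.environ.get("INFINITY_TORCH_ENABLE_HALF", False):
    model.half()  # fp16 weights roughly halve GPU memory and speed up matmuls

Since os.environ.get returns the raw string, any non-empty value (even "0") enables the switch, which matches the check in the diff.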
14 changes: 7 additions & 7 deletions libs/infinity_emb/poetry.lock

Some generated files are not rendered by default.

@@ -13,7 +13,7 @@
 from infinity_emb.transformer.utils import InferenceEngine

 PREFIX = "/v1_ct2"
-model = pytest.DEFAULT_BERT_MODEL
+model: str = pytest.DEFAULT_BERT_MODEL
 batch_size = 64 if torch.cuda.is_available() else 8

 app = create_server(
27 changes: 17 additions & 10 deletions libs/infinity_emb/tests/script_live.py
@@ -1,5 +1,7 @@
+import concurrent.futures
 import json
 import timeit
+from functools import partial

 import numpy as np
 import requests
@@ -9,6 +11,7 @@


 def embedding_live_performance():
+    tp = concurrent.futures.ThreadPoolExecutor()
     sample = [f"Test count {i} {(list(range(i % (384))))} " for i in range(2048)]

     json_d = json.dumps({"input": sample, "model": "model"})
@@ -21,28 +24,32 @@ def embedding_live_performance():
     print(f"batch_size is {batch_size}, model={model_name}")
     model = SentenceTransformer(model_name_or_path=model_name)

-    def local(data: str):
-        enc = model.encode(data, batch_size=batch_size)
-        assert len(enc) == len(data)
-        return enc
+    def local(data: list[str], iters=1):
+        data_in = data * iters
+        enc = model.encode(data_in, batch_size=batch_size)
+        assert len(enc) == len(data_in)
+        return enc[: len(data)]

-    def remote(json_data: bytes):
-        req = session.post(f"{LIVE_URL}/embeddings", data=json_data)
-        assert req.status_code == 200
-        return req
+    def remote(json_data: bytes, iters=1):
+        fn = partial(session.post, data=json_data)
+        req = list(tp.map(fn, [f"{LIVE_URL}/embeddings"] * iters))
+        assert req[0].status_code == 200
+        return req[0]

     local_resp = local(sample)
     remote_resp = [d["embedding"] for d in remote(json_d).json()["data"]]
     np.testing.assert_almost_equal(local_resp, remote_resp, 6)
     print("Both methods provide the identical output.")

     print("Measuring latency via SentenceTransformers")
-    latency_st = timeit.timeit("local(sample)", number=2, globals=locals())
+    latency_st = timeit.timeit("local(sample, iters=5)", number=2, globals=locals())
     print("SentenceTransformers latency: ", latency_st)
     model = None

     print("Measuring latency via requests")
-    latency_request = timeit.timeit("remote(json_d)", number=2, globals=locals())
+    latency_request = timeit.timeit(
+        "remote(json_d, iters=5)", number=2, globals=locals()
+    )
     print(f"Request latency: {latency_request}")

     assert latency_st * 1.1 > latency_request
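The reworked `remote` helper overlaps identical POSTs through a thread pool, presumably so the server's dynamic batching is actually exercised during the latency measurement. The same pattern in isolation; URL and payload are placeholders:

# Fire `iters` identical POST requests concurrently and collect the responses.
import concurrent.futures
from functools import partial

import requests


def post_concurrently(url: str, payload: bytes, iters: int = 5) -> list:
    fn = partial(requests.post, data=payload)
    with concurrent.futures.ThreadPoolExecutor() as tp:
        # tp.map preserves submission order, so responses line up with the calls
        return list(tp.map(fn, [url] * iters))


# responses = post_concurrently("http://localhost:7997/v1/embeddings",
#                               b'{"input": ["hello"], "model": "model"}')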
