Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dynamic batching embedding/reranking #774

Merged
merged 43 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
43da415
draft static batching embedding/reranking on single gaudi card
Spycsh Oct 9, 2024
36d3d95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2024
e72938d
fix
Spycsh Oct 9, 2024
8bd8dbd
resolve segfault, deadlock and other issues
Spycsh Oct 10, 2024
e036caf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2024
3464641
narrow down default timeout
Spycsh Oct 10, 2024
18fa6c8
Merge branch 'main' into static_batching
Spycsh Oct 11, 2024
353815a
add dockerfile
Spycsh Oct 11, 2024
24f3c3d
fix hpu local microservice start
Spycsh Oct 11, 2024
efc37d8
openai format
Spycsh Oct 11, 2024
e20ac3b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 11, 2024
48a4a52
configurable timeout
Spycsh Oct 12, 2024
8183f3b
lower timeout
Spycsh Oct 12, 2024
2ebc439
fix
Spycsh Oct 12, 2024
43633df
lower default timeout
Spycsh Oct 12, 2024
b51d219
bf16
Spycsh Oct 12, 2024
08d5c05
log, pad max_len
Spycsh Oct 12, 2024
7329404
autocast, 128
Spycsh Oct 12, 2024
6396e54
fix acc issue
Spycsh Oct 14, 2024
9df1e8b
perf fallback with no acc drop
Spycsh Oct 14, 2024
8ae56eb
revert no-padding ones
Spycsh Oct 14, 2024
8ff9a59
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2024
08b11ec
fix hpu graph wrapper
Spycsh Oct 14, 2024
33698f9
Merge branch 'static_batching' of https://github.com/Spycsh/GenAIComp…
Spycsh Oct 14, 2024
6492e70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2024
8f8861c
add padding batch
Spycsh Oct 15, 2024
268600a
Merge branch 'static_batching' of https://github.com/Spycsh/GenAIComp…
Spycsh Oct 15, 2024
4411462
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
51c3306
habana 1.18
Spycsh Oct 15, 2024
357f302
static -> dynamic
Spycsh Oct 15, 2024
c3aaed7
add UT, add param in_single_process
Spycsh Oct 15, 2024
a556581
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
fda14f5
add docker file
Spycsh Oct 15, 2024
b7d8a5e
Merge branch 'static_batching' of https://github.com/Spycsh/GenAIComp…
Spycsh Oct 15, 2024
7ac3324
Merge branch 'main' into static_batching
Spycsh Oct 15, 2024
d9ac817
fix case doc empty, and pass model id from env
Spycsh Oct 15, 2024
a337c2d
Merge branch 'static_batching' of https://github.com/Spycsh/GenAIComp…
Spycsh Oct 15, 2024
ab21b1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 15, 2024
d5afdc6
Merge branch 'main' into static_batching
chensuyue Oct 17, 2024
f42e7d8
Merge branch 'main' into static_batching
Spycsh Oct 18, 2024
4fd090d
CI
Spycsh Oct 18, 2024
3fbe230
Merge branch 'static_batching' of https://github.com/Spycsh/GenAIComp…
Spycsh Oct 18, 2024
8526275
Merge branch 'main' into static_batching
ZePan110 Nov 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/docker/compose/embeddings-compose-cd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ services:
build:
dockerfile: comps/embeddings/predictionguard/Dockerfile
image: ${REGISTRY:-opea}/embedding-predictionguard:${TAG:-latest}
embedding-reranking-local:
build:
dockerfile: comps/embeddings/tei/langchain/Dockerfile.dynamic_batching
image: ${REGISTRY:-opea}/embedding-reranking-local:${TAG:-latest}
69 changes: 66 additions & 3 deletions comps/cores/mega/micro_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@

import asyncio
import multiprocessing
import os
from collections import defaultdict, deque
from enum import Enum
from typing import Any, List, Optional, Type

from ..proto.docarray import TextDoc
from .constants import ServiceRoleType, ServiceType
from .logger import CustomLogger
from .utils import check_ports_availability

# Registry of microservices created via the register decorator, keyed by service name.
opea_microservices = {}

logger = CustomLogger("micro_service")
# NOTE(review): os.getenv returns the raw string whenever LOGFLAG is set, so any
# non-empty value (even "false" or "0") enables logging; the default is the bool False.
logflag = os.getenv("LOGFLAG", False)


class MicroService:
"""MicroService class to create a microservice."""
Expand All @@ -31,6 +38,9 @@
provider: Optional[str] = None,
provider_endpoint: Optional[str] = None,
use_remote_service: Optional[bool] = False,
dynamic_batching: bool = False,
dynamic_batching_timeout: int = 1,
dynamic_batching_max_batch_size: int = 32,
):
"""Init the microservice."""
self.name = f"{name}/{self.__class__.__name__}" if name else self.__class__.__name__
Expand All @@ -43,6 +53,9 @@
self.input_datatype = input_datatype
self.output_datatype = output_datatype
self.use_remote_service = use_remote_service
self.dynamic_batching = dynamic_batching
self.dynamic_batching_timeout = dynamic_batching_timeout
self.dynamic_batching_max_batch_size = dynamic_batching_max_batch_size
self.uvicorn_kwargs = {}

if ssl_keyfile:
Expand All @@ -58,10 +71,50 @@

self.server = self._get_server()
self.app = self.server.app
# create a batch request processor loop if using dynamic batching
if self.dynamic_batching:
self.buffer_lock = asyncio.Lock()
self.request_buffer = defaultdict(deque)

@self.app.on_event("startup")
async def startup_event():
asyncio.create_task(self._dynamic_batch_processor())

self.event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.event_loop)
self.event_loop.run_until_complete(self._async_setup())

async def _dynamic_batch_processor(self):
    """Background task that periodically drains buffered requests and runs batched inference.

    Every ``dynamic_batching_timeout`` seconds this loop takes up to
    ``dynamic_batching_max_batch_size`` queued requests per service type out of
    ``self.request_buffer`` (under ``self.buffer_lock``), calls
    ``dynamic_batching_infer`` once per service type, and resolves each request's
    ``response`` future with its corresponding result.

    Fix over the original: an exception raised by ``dynamic_batching_infer`` used to
    kill this ``while True`` task silently, stranding every pending and future request.
    Failures are now propagated to the waiting callers via ``set_exception`` and the
    loop keeps running.
    """
    if logflag:
        logger.info("dynamic batch processor looping...")
    while True:
        await asyncio.sleep(self.dynamic_batching_timeout)
        # {ServiceType.EMBEDDING: [{"request": xx, "response": yy}, ...], ...}
        runtime_batch: dict[Enum, list[dict]] = {}

        async with self.buffer_lock:
            # Prepare the runtime batch; access to the buffer is locked.
            if self.request_buffer:
                for service_type, request_lst in self.request_buffer.items():
                    batch = []
                    # Grab min(MAX_BATCH_SIZE, REQUEST_SIZE) requests from the buffer.
                    for _ in range(min(self.dynamic_batching_max_batch_size, len(request_lst))):
                        batch.append(request_lst.popleft())
                    runtime_batch[service_type] = batch

        # Run batched inference on each batch and set results (outside the lock,
        # so new requests can keep queuing while inference runs).
        for service_type, batch in runtime_batch.items():
            if not batch:
                continue
            try:
                results = await self.dynamic_batching_infer(service_type, batch)
                for req, result in zip(batch, results):
                    req["response"].set_result(result)
            except Exception as exc:  # noqa: BLE001 - boundary: keep the processor alive
                logger.error(f"dynamic batch inference failed for {service_type}: {exc}")
                # Propagate the failure to the waiting callers instead of leaving
                # their futures pending forever.
                for req in batch:
                    if not req["response"].done():
                        req["response"].set_exception(exc)
async def dynamic_batching_infer(self, service_type: Enum, batch: list[dict]):
    """Hook for subclasses: perform one batched inference call.

    Args:
        service_type: Which service the batch belongs to (e.g. embedding, reranking).
        batch: List of ``{"request": ..., "response": future}`` entries to process.

    Raises:
        NotImplementedError: Always, until a subclass provides an implementation.
    """
    raise NotImplementedError("Unimplemented dynamic batching inference!")

def _validate_env(self):
"""Check whether to use the microservice locally."""
if self.use_remote_service:
Expand Down Expand Up @@ -116,10 +169,14 @@
self._validate_env()
self.event_loop.run_until_complete(self._async_run_forever())

def start(self, in_single_process=False):
    """Start the microservice, either in a child process or inline.

    Args:
        in_single_process: When True, run the server in the current process
            (resolves HPU segmentation faults and tokenizer issues seen with
            forked processes). When False (default), spawn a dedicated
            multiprocessing.Process and store it on ``self.process``.
    """
    self._validate_env()
    if not in_single_process:
        worker = multiprocessing.Process(target=self.run, daemon=False, name=self.name)
        self.process = worker
        worker.start()
        return
    # Limit everything to this process to avoid HPU segfault / tokenizer issues.
    self.run()

async def _async_teardown(self):
"""Shutdown the server."""
Expand Down Expand Up @@ -155,6 +212,9 @@
provider: Optional[str] = None,
provider_endpoint: Optional[str] = None,
methods: List[str] = ["POST"],
dynamic_batching: bool = False,
dynamic_batching_timeout: int = 1,
dynamic_batching_max_batch_size: int = 32,
):
def decorator(func):
if name not in opea_microservices:
Expand All @@ -172,6 +232,9 @@
output_datatype=output_datatype,
provider=provider,
provider_endpoint=provider_endpoint,
dynamic_batching=dynamic_batching,
dynamic_batching_timeout=dynamic_batching_timeout,
dynamic_batching_max_batch_size=dynamic_batching_max_batch_size,
)
opea_microservices[name] = micro_service
opea_microservices[name].app.router.add_api_route(endpoint, func, methods=methods)
Expand Down
28 changes: 28 additions & 0 deletions comps/embeddings/tei/langchain/Dockerfile.dynamic_batching
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

# Image for the local embedding/reranking microservice with dynamic batching on Gaudi HPU.
# FROM opea/habanalabs:1.16.1-pytorch-installer-2.2.2 as hpu
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest as hpu

# System libraries; libgl1-mesa-glx is presumably needed by an image-processing
# dependency in requirements.txt — TODO confirm.
RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
    libgl1-mesa-glx \
    libjemalloc-dev

RUN useradd -m -s /bin/bash user && \
    mkdir -p /home/user && \
    chown -R user /home/user/

# Disable user for now
# NOTE(review): the container currently runs as root; re-enable the non-root
# user once HPU device access under it is verified.
# USER user

COPY comps /home/user/comps

# optimum-habana is installed from git; consider pinning a release tag for
# reproducible builds.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r /home/user/comps/embeddings/tei/langchain/requirements.txt && \
    pip install git+https://github.com/huggingface/optimum-habana.git

ENV PYTHONPATH=$PYTHONPATH:/home/user

WORKDIR /home/user/comps/embeddings/tei/langchain

ENTRYPOINT ["python", "local_embedding_reranking.py"]
Loading
Loading