diff --git a/.github/workflows/docker/compose/embeddings-compose-cd.yaml b/.github/workflows/docker/compose/embeddings-compose-cd.yaml
index 53243cfc5..a9d76fa0e 100644
--- a/.github/workflows/docker/compose/embeddings-compose-cd.yaml
+++ b/.github/workflows/docker/compose/embeddings-compose-cd.yaml
@@ -22,3 +22,7 @@ services:
     build:
       dockerfile: comps/embeddings/predictionguard/Dockerfile
     image: ${REGISTRY:-opea}/embedding-predictionguard:${TAG:-latest}
+  embedding-reranking-local:
+    build:
+      dockerfile: comps/embeddings/tei/langchain/Dockerfile.dynamic_batching
+    image: ${REGISTRY:-opea}/embedding-reranking-local:${TAG:-latest}
diff --git a/comps/cores/mega/micro_service.py b/comps/cores/mega/micro_service.py
index 89e4cd944..3552def42 100644
--- a/comps/cores/mega/micro_service.py
+++ b/comps/cores/mega/micro_service.py
@@ -3,14 +3,21 @@
 
 import asyncio
 import multiprocessing
+import os
+from collections import defaultdict, deque
+from enum import Enum
 from typing import Any, List, Optional, Type
 
 from ..proto.docarray import TextDoc
 from .constants import ServiceRoleType, ServiceType
+from .logger import CustomLogger
 from .utils import check_ports_availability
 
 opea_microservices = {}
+logger = CustomLogger("micro_service")
+logflag = os.getenv("LOGFLAG", False)
+
 
 
 class MicroService:
     """MicroService class to create a microservice."""
@@ -31,6 +38,9 @@ def __init__(
         provider: Optional[str] = None,
         provider_endpoint: Optional[str] = None,
         use_remote_service: Optional[bool] = False,
+        dynamic_batching: bool = False,
+        dynamic_batching_timeout: int = 1,
+        dynamic_batching_max_batch_size: int = 32,
     ):
         """Init the microservice."""
         self.name = f"{name}/{self.__class__.__name__}" if name else self.__class__.__name__
@@ -43,6 +53,9 @@ def __init__(
         self.input_datatype = input_datatype
         self.output_datatype = output_datatype
         self.use_remote_service = use_remote_service
+        self.dynamic_batching = dynamic_batching
+        self.dynamic_batching_timeout = dynamic_batching_timeout
+        self.dynamic_batching_max_batch_size = dynamic_batching_max_batch_size
         self.uvicorn_kwargs = {}
 
         if ssl_keyfile:
@@ -58,10 +71,50 @@ def __init__(
         self.server = self._get_server()
         self.app = self.server.app
 
+        # create a batch request processor loop if using dynamic batching
+        if self.dynamic_batching:
+            self.buffer_lock = asyncio.Lock()
+            self.request_buffer = defaultdict(deque)
+
+            @self.app.on_event("startup")
+            async def startup_event():
+                asyncio.create_task(self._dynamic_batch_processor())
+
         self.event_loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.event_loop)
         self.event_loop.run_until_complete(self._async_setup())
 
+    async def _dynamic_batch_processor(self):
+        if logflag:
+            logger.info("dynamic batch processor looping...")
+        while True:
+            await asyncio.sleep(self.dynamic_batching_timeout)
+            runtime_batch: dict[Enum, list[dict]] = {}  # {ServiceType.Embedding: [{"request": xx, "response": yy}, {}]}
+
+            async with self.buffer_lock:
+                # prepare the runtime batch, access to buffer is locked
+                if self.request_buffer:
+                    for service_type, request_lst in self.request_buffer.items():
+                        batch = []
+                        # grab min(MAX_BATCH_SIZE, REQUEST_SIZE) requests from buffer
+                        for _ in range(min(self.dynamic_batching_max_batch_size, len(request_lst))):
+                            batch.append(request_lst.popleft())
+
+                        runtime_batch[service_type] = batch
+
+            # Run batched inference on the batch and set results
+            for service_type, batch in runtime_batch.items():
+                if not batch:
+                    continue
+                results = await self.dynamic_batching_infer(service_type, batch)
+
+                for req, result in zip(batch, results):
+                    req["response"].set_result(result)
+
+    async def dynamic_batching_infer(self, service_type: Enum, batch: list[dict]):
+        """Override with the service's batched inference implementation."""
+        raise NotImplementedError("Unimplemented dynamic batching inference!")
+
     def _validate_env(self):
         """Check whether to use the microservice locally."""
         if self.use_remote_service:
@@ -116,10 +169,14 @@ def run(self):
         self._validate_env()
         self.event_loop.run_until_complete(self._async_run_forever())
 
-    def start(self):
+    def start(self, in_single_process=False):
         self._validate_env()
-        self.process = multiprocessing.Process(target=self.run, daemon=False, name=self.name)
-        self.process.start()
+        if in_single_process:
+            # Resolve HPU segmentation fault and potential tokenizer issues by running in the same process
+            self.run()
+        else:
+            self.process = multiprocessing.Process(target=self.run, daemon=False, name=self.name)
+            self.process.start()
 
     async def _async_teardown(self):
         """Shutdown the server."""
@@ -155,6 +212,9 @@ def register_microservice(
     provider: Optional[str] = None,
     provider_endpoint: Optional[str] = None,
     methods: List[str] = ["POST"],
+    dynamic_batching: bool = False,
+    dynamic_batching_timeout: int = 1,
+    dynamic_batching_max_batch_size: int = 32,
 ):
     def decorator(func):
         if name not in opea_microservices:
@@ -172,6 +232,9 @@ def decorator(func):
             output_datatype=output_datatype,
             provider=provider,
             provider_endpoint=provider_endpoint,
+            dynamic_batching=dynamic_batching,
+            dynamic_batching_timeout=dynamic_batching_timeout,
+            dynamic_batching_max_batch_size=dynamic_batching_max_batch_size,
         )
         opea_microservices[name] = micro_service
         opea_microservices[name].app.router.add_api_route(endpoint, func, methods=methods)
diff --git a/comps/embeddings/tei/langchain/Dockerfile.dynamic_batching b/comps/embeddings/tei/langchain/Dockerfile.dynamic_batching
new file mode 100644
index 000000000..56148f320
--- /dev/null
+++ b/comps/embeddings/tei/langchain/Dockerfile.dynamic_batching
@@ -0,0 +1,28 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# FROM opea/habanalabs:1.16.1-pytorch-installer-2.2.2 as hpu
+FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest as hpu
+
+RUN apt-get update -y && apt-get install -y --no-install-recommends --fix-missing \
+    libgl1-mesa-glx \
+    libjemalloc-dev
+
+RUN useradd -m -s /bin/bash user && \
+    mkdir -p /home/user && \
+    chown -R user /home/user/
+
+# Disable user for now
+# USER user
+
+COPY comps /home/user/comps
+
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir -r /home/user/comps/embeddings/tei/langchain/requirements.txt && \
+    pip install git+https://github.com/huggingface/optimum-habana.git
+
+ENV PYTHONPATH=$PYTHONPATH:/home/user
+
+WORKDIR /home/user/comps/embeddings/tei/langchain
+
+ENTRYPOINT ["python", "local_embedding_reranking.py"]
diff --git a/comps/embeddings/tei/langchain/local_embedding_reranking.py b/comps/embeddings/tei/langchain/local_embedding_reranking.py
new file mode 100644
index 000000000..a29677744
--- /dev/null
+++ b/comps/embeddings/tei/langchain/local_embedding_reranking.py
@@ -0,0 +1,250 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import math
+import os
+from enum import Enum
+from pathlib import Path
+from typing import Union
+
+import torch
+from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+from sentence_transformers.models import Pooling
+from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer
+
+from comps import (
+    CustomLogger,
+    EmbedDoc,
+    LLMParamsDoc,
+    SearchedDoc,
+    ServiceType,
+    TextDoc,
+    opea_microservices,
+    register_microservice,
+)
+from comps.cores.proto.api_protocol import ChatCompletionRequest, EmbeddingRequest, EmbeddingResponse
+
+logger = CustomLogger("local_embedding_reranking")
+logflag = os.getenv("LOGFLAG", False)
+
+# keep it consistent for different routers for now
+DYNAMIC_BATCHING_TIMEOUT = float(os.getenv("DYNAMIC_BATCHING_TIMEOUT", 0.01))
+DYNAMIC_BATCHING_MAX_BATCH_SIZE = int(os.getenv("DYNAMIC_BATCHING_MAX_BATCH_SIZE", 32))
+PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
+EMBEDDING_MODEL_ID = os.environ.get("EMBEDDING_MODEL_ID", "BAAI/bge-base-en-v1.5")
+RERANK_MODEL_ID = os.environ.get("RERANK_MODEL_ID", "BAAI/bge-reranker-base")
+
+
+def round_up(number, k):
+    return (number + k - 1) // k * k
+
+
+class EmbeddingModel:
+    def __init__(
+        self,
+        model_path: Path,
+        device: torch.device,
+        dtype: torch.dtype,
+        trust_remote: bool = False,
+    ):
+        if device == torch.device("hpu"):
+            adapt_transformers_to_gaudi()
+        model = AutoModel.from_pretrained(model_path, trust_remote_code=trust_remote).to(dtype).to(device)
+        if device == torch.device("hpu"):
+            logger.info("Use graph mode for HPU")
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+        self.hidden_size = model.config.hidden_size
+        self.pooling = Pooling(self.hidden_size, pooling_mode="cls")
+        self.model = model
+
+    def embed(self, batch):
+        output = self.model(**batch)
+        # sentence_embeddings = output[0][:, 0]
+        # sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
+        pooling_features = {
+            "token_embeddings": output[0],
+            "attention_mask": batch.attention_mask,
+        }
+        embedding = self.pooling.forward(pooling_features)["sentence_embedding"]
+        ## normalize
+        embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
+        cpu_results = embedding.reshape(-1).tolist()
+        return [cpu_results[i * self.hidden_size : (i + 1) * self.hidden_size] for i in range(len(batch.input_ids))]
+
+
+class RerankingModel:
+    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+        if device == torch.device("hpu"):
+            adapt_transformers_to_gaudi()
+
+        model = AutoModelForSequenceClassification.from_pretrained(model_path)
+        model = model.to(dtype).to(device)
+
+        if device == torch.device("hpu"):
+            logger.info("Use graph mode for HPU")
+            model = wrap_in_hpu_graph(model, disable_tensor_cache=True)
+        self.model = model
+
+    def predict(self, batch):
+        scores = (
+            self.model(**batch, return_dict=True)
+            .logits.view(
+                -1,
+            )
+            .float()
+        )
+        scores = torch.sigmoid(scores)
+        return scores
+
+
+def pad_batch(inputs: dict, max_input_len: int):
+    # pad seq_len to MULTIPLE OF, pad bs
+    batch_size, concrete_length = inputs["input_ids"].size()[0], inputs["input_ids"].size()[1]
+    max_length = round_up(concrete_length, PAD_SEQUENCE_TO_MULTIPLE_OF)
+    max_length = min(max_length, max_input_len)  # should not exceed max input len
+    new_bs = 2 ** math.ceil(math.log2(batch_size))
+    for x in inputs:
+        inputs[x] = torch.nn.functional.pad(
+            inputs[x], (0, max_length - concrete_length, 0, new_bs - batch_size), value=0
+        )
+    return inputs
+
+
+async def dynamic_batching_infer(service_type: Enum, batch: list[dict]):
+    if logflag:
+        logger.info(f"{service_type} {len(batch)} request inference begin >>>")
+
+    if service_type == ServiceType.EMBEDDING:
+        sentences = [req["request"].text for req in batch]
+
+        with torch.no_grad():
+            encoded_input = embedding_tokenizer(
+                sentences,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+            ).to(device="hpu")
+            encoded_input = pad_batch(encoded_input, embedding_tokenizer.model_max_length)
+            # with torch.autocast("hpu", dtype=torch.bfloat16):
+            results = embedding_model.embed(encoded_input)
+
+        return [EmbedDoc(text=txt, embedding=embed_vector) for txt, embed_vector in zip(sentences, results)]
+    elif service_type == ServiceType.RERANK:
+        pairs = []
+        doc_lengths = []
+        for req in batch:
+            doc_len = len(req["request"].retrieved_docs)
+            doc_lengths.append(doc_len)
+            for idx in range(doc_len):
+                pairs.append([req["request"].initial_query, req["request"].retrieved_docs[idx].text])
+
+        with torch.no_grad():
+            inputs = reranking_tokenizer(
+                pairs,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+            ).to("hpu")
+            inputs = pad_batch(inputs, reranking_tokenizer.model_max_length)
+            scores = reranking_model.predict(inputs)
+
+        # select each request's top_n related docs from the flat score list
+        final_results = []
+        start = 0
+        for idx, doc_len in enumerate(doc_lengths):
+            req_scores = scores[start : start + doc_len]
+            cur_req = batch[idx]["request"]
+            docs: list[TextDoc] = cur_req.retrieved_docs[0:doc_len]
+            docs = [doc.text for doc in docs]
+            # sort and select top n docs
+            top_n_docs = sorted(list(zip(docs, req_scores)), key=lambda x: x[1], reverse=True)[: cur_req.top_n]
+            top_n_docs: list[str] = [tupl[0] for tupl in top_n_docs]
+            final_results.append(LLMParamsDoc(query=cur_req.initial_query, documents=top_n_docs))
+
+            start += doc_len
+
+        return final_results
+
+
+@register_microservice(
+    name="opea_service@local_embedding_reranking",
+    service_type=ServiceType.EMBEDDING,
+    endpoint="/v1/embeddings",
+    host="0.0.0.0",
+    port=6001,
+    dynamic_batching=True,
+    dynamic_batching_timeout=DYNAMIC_BATCHING_TIMEOUT,
+    dynamic_batching_max_batch_size=DYNAMIC_BATCHING_MAX_BATCH_SIZE,
+)
+async def embedding(
+    input: Union[TextDoc, EmbeddingRequest, ChatCompletionRequest]
+) -> Union[EmbedDoc, EmbeddingResponse, ChatCompletionRequest]:
+
+    # if logflag:
+    #     logger.info(input)
+    # Create a future for this specific request
+    response_future = asyncio.get_event_loop().create_future()
+
+    cur_microservice = opea_microservices["opea_service@local_embedding_reranking"]
+    cur_microservice.dynamic_batching_infer = dynamic_batching_infer
+    async with cur_microservice.buffer_lock:
+        cur_microservice.request_buffer[ServiceType.EMBEDDING].append({"request": input, "response": response_future})
+
+    # Wait for batch inference to complete and return results
+    result = await response_future
+
+    return result
+
+
+@register_microservice(
+    name="opea_service@local_embedding_reranking",
+    service_type=ServiceType.RERANK,
+    endpoint="/v1/reranking",
+    host="0.0.0.0",
+    port=6001,
+    input_datatype=SearchedDoc,
+    output_datatype=LLMParamsDoc,
+    dynamic_batching=True,
+    dynamic_batching_timeout=DYNAMIC_BATCHING_TIMEOUT,
+    dynamic_batching_max_batch_size=DYNAMIC_BATCHING_MAX_BATCH_SIZE,
+)
+async def reranking(input: SearchedDoc) -> LLMParamsDoc:
+
+    # if logflag:
+    #     logger.info(input)
+
+    if len(input.retrieved_docs) == 0:
+        return LLMParamsDoc(query=input.initial_query)
+
+    # Create a future for this specific request
+    response_future = asyncio.get_event_loop().create_future()
+
+    cur_microservice = opea_microservices["opea_service@local_embedding_reranking"]
+    cur_microservice.dynamic_batching_infer = dynamic_batching_infer
+    async with cur_microservice.buffer_lock:
+        cur_microservice.request_buffer[ServiceType.RERANK].append({"request": input, "response": response_future})
+
+    # Wait for batch inference to complete and return results
+    result = await response_future
+
+    return result
+
+
+if __name__ == "__main__":
+    embedding_model = EmbeddingModel(model_path=EMBEDDING_MODEL_ID, device=torch.device("hpu"), dtype=torch.bfloat16)
+    embedding_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_ID)
+    # sentences = ["sample-1", "sample-2"]
+    # encoded_input = embedding_tokenizer(sentences, padding=True, truncation=True, return_tensors='pt').to(device="hpu")
+    # results = embedding_model.embed(encoded_input)
+    # print(results)
+    reranking_model = RerankingModel(model_path=RERANK_MODEL_ID, device=torch.device("hpu"), dtype=torch.bfloat16)
+    reranking_tokenizer = AutoTokenizer.from_pretrained(RERANK_MODEL_ID)
+
+    # pairs = [['what is panda?', 'hi'], ['what is panda?', 'The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.']]
+    # with torch.no_grad():
+    #     inputs = reranking_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512).to("hpu")
+    #     scores = reranking_model.predict(inputs)
+    #     print(scores)
+    opea_microservices["opea_service@local_embedding_reranking"].start(in_single_process=True)
diff --git a/tests/cores/mega/test_dynamic_batching.py b/tests/cores/mega/test_dynamic_batching.py
new file mode 100644
index 000000000..945054fb0
--- /dev/null
+++ b/tests/cores/mega/test_dynamic_batching.py
@@ -0,0 +1,91 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import asyncio
+import unittest
+from enum import Enum
+
+import aiohttp
+
+from comps import ServiceType, TextDoc, opea_microservices, register_microservice
+
+
+async def dynamic_batching_infer(service_type: Enum, batch: list[dict]):
+    # simulate batch inference time
+    await asyncio.sleep(5)
+    assert len(batch) == 2
+    return [{"result": "processed: " + i["request"].text} for i in batch]
+
+
+@register_microservice(
+    name="s1",
+    host="0.0.0.0",
+    port=8080,
+    endpoint="/v1/add1",
+    dynamic_batching=True,
+    dynamic_batching_timeout=2,
+    dynamic_batching_max_batch_size=32,
+)
+async def add(request: TextDoc) -> dict:
+    response_future = asyncio.get_event_loop().create_future()
+
+    cur_microservice = opea_microservices["s1"]
+    cur_microservice.dynamic_batching_infer = dynamic_batching_infer
+
+    async with cur_microservice.buffer_lock:
+        cur_microservice.request_buffer[ServiceType.EMBEDDING].append({"request": request, "response": response_future})
+    result = await response_future
+    return result
+
+
+@register_microservice(
+    name="s1",
+    host="0.0.0.0",
+    port=8080,
+    endpoint="/v1/add2",
+    dynamic_batching=True,
+    dynamic_batching_timeout=3,
+    dynamic_batching_max_batch_size=32,
+)
+async def add2(request: TextDoc) -> dict:
+    response_future = asyncio.get_event_loop().create_future()
+
+    cur_microservice = opea_microservices["s1"]
+    cur_microservice.dynamic_batching_infer = dynamic_batching_infer
+
+    async with cur_microservice.buffer_lock:
+        cur_microservice.request_buffer[ServiceType.EMBEDDING].append({"request": request, "response": response_future})
+    result = await response_future
+    return result
+
+
+async def fetch(session, url, data):
+    async with session.post(url, json=data) as response:
+        # Await the response and return the JSON data
+        return await response.json()
+
+
+class TestMicroService(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        opea_microservices["s1"].start()
+
+    def tearDown(self):
+        opea_microservices["s1"].stop()
+
+    async def test_dynamic_batching(self):
+        url1 = "http://localhost:8080/v1/add1"
+        url2 = "http://localhost:8080/v1/add2"
+
+        # Data for the requests
+        data1 = {"text": "Hello, "}
+        data2 = {"text": "OPEA Project!"}
+
+        async with aiohttp.ClientSession() as session:
+            response1, response2 = await asyncio.gather(fetch(session, url1, data1), fetch(session, url2, data2))
+
+        self.assertEqual(response1["result"], "processed: Hello, ")
+        self.assertEqual(response2["result"], "processed: OPEA Project!")
+
+
+if __name__ == "__main__":
+    unittest.main()
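
Below is a minimal client sketch for exercising the two dynamically batched routes added in local_embedding_reranking.py. It assumes the service from this diff is running on localhost:6001 with the default settings; the payload fields (text, initial_query, retrieved_docs, top_n) follow the TextDoc and SearchedDoc usage above and are illustrative rather than authoritative. Sending both requests concurrently lets them be picked up within the same batching window.

# client_sketch.py - illustrative only; assumes the embedding/reranking service above is up on localhost:6001
import asyncio

import aiohttp


async def post_json(session, url, payload):
    async with session.post(url, json=payload) as resp:
        return await resp.json()


async def main():
    async with aiohttp.ClientSession() as session:
        # Fire both requests concurrently so they can land in the same dynamic batching window
        embed_result, rerank_result = await asyncio.gather(
            post_json(session, "http://localhost:6001/v1/embeddings", {"text": "What is a panda?"}),
            post_json(
                session,
                "http://localhost:6001/v1/reranking",
                {
                    "initial_query": "What is a panda?",
                    "retrieved_docs": [
                        {"text": "The giant panda is a bear species endemic to China."},
                        {"text": "hi"},
                    ],
                    "top_n": 1,
                },
            ),
        )
        print(embed_result)  # EmbedDoc-shaped response containing the embedding vector
        print(rerank_result)  # LLMParamsDoc-shaped response containing the top_n documents


if __name__ == "__main__":
    asyncio.run(main())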