Add Megaservice support for MMRAG VideoRAGQnA usecase (#603)
* add videoragqna gateway

Signed-off-by: BaoHuiling <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add test script for gateway

Signed-off-by: BaoHuiling <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rm ip

Signed-off-by: BaoHuiling <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix existing bug

Signed-off-by: BaoHuiling <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: BaoHuiling <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: kevinintel <[email protected]>
3 people authored Sep 9, 2024
1 parent 23cc3ea commit 2c48bc8
Showing 7 changed files with 135 additions and 3 deletions.
1 change: 1 addition & 0 deletions comps/__init__.py
@@ -45,6 +45,7 @@
     AudioQnAGateway,
     RetrievalToolGateway,
     FaqGenGateway,
+    VideoRAGQnAGateway,
     VisualQnAGateway,
     MultimodalRAGWithVideosGateway,
 )
1 change: 1 addition & 0 deletions comps/cores/mega/constants.py
@@ -38,6 +38,7 @@ class MegaServiceEndpoint(Enum):
     CHAT_QNA = "/v1/chatqna"
     AUDIO_QNA = "/v1/audioqna"
     VISUAL_QNA = "/v1/visualqna"
+    VIDEO_RAG_QNA = "/v1/videoragqna"
     CODE_GEN = "/v1/codegen"
     CODE_TRANS = "/v1/codetrans"
     DOC_SUMMARY = "/v1/docsum"
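Each member maps a megaservice to its REST route. The new gateway (see gateway.py below) passes str(MegaServiceEndpoint.VIDEO_RAG_QNA) as its endpoint path, which implies the enum stringifies to its value. A minimal sketch of that pattern — the __str__ override here is an assumption about the real class, not shown in this diff:

from enum import Enum

class MegaServiceEndpoint(Enum):
    # excerpt; the real enum defines many more routes
    VISUAL_QNA = "/v1/visualqna"
    VIDEO_RAG_QNA = "/v1/videoragqna"

    def __str__(self):
        # assumed: stringify to the route value so callers can pass
        # str(MegaServiceEndpoint.VIDEO_RAG_QNA) directly as a path
        return self.value

assert str(MegaServiceEndpoint.VIDEO_RAG_QNA) == "/v1/videoragqna"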
49 changes: 49 additions & 0 deletions comps/cores/mega/gateway.py
@@ -548,6 +548,55 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)
 
 
+class VideoRAGQnAGateway(Gateway):
+    def __init__(self, megaservice, host="0.0.0.0", port=8888):
+        super().__init__(
+            megaservice,
+            host,
+            port,
+            str(MegaServiceEndpoint.VIDEO_RAG_QNA),
+            ChatCompletionRequest,
+            ChatCompletionResponse,
+        )
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", False)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        prompt = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            streaming=stream_opt,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"text": prompt}, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # This assumes the last microservice in the megaservice is the LVM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LVM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="videoragqna", choices=choices, usage=usage)
+
+
 class RetrievalToolGateway(Gateway):
     """embed+retrieve+rerank."""
 
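For reference, a client call against the new endpoint might look like the following. This is a sketch only: it assumes the megaservice is deployed behind this gateway on the default host and port, and that the payload uses the OpenAI-style ChatCompletionRequest fields read by handle_request above.

import requests

# Hypothetical local deployment; host and port depend on your setup.
url = "http://localhost:8888/v1/videoragqna"
payload = {
    "messages": [{"role": "user", "content": "What is happening in the video?"}],
    "stream": False,  # set True to receive the LVM's streaming response directly
    "max_tokens": 512,
    "temperature": 0.01,
}
resp = requests.post(url, json=payload, timeout=120)
print(resp.json())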
3 changes: 2 additions & 1 deletion comps/embeddings/multimodal_clip/embeddings_clip.py
@@ -27,7 +27,8 @@ def embed_query(self, texts):
         return text_features
 
     def get_embedding_length(self):
-        return len(self.embed_query("sample_text"))
+        text_features = self.embed_query("sample_text")
+        return text_features.shape[1]
 
     def get_image_embeddings(self, images):
         """Input is list of images."""
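The change above fixes the embedding-length probe: assuming embed_query returns a 2-D (batch, dim) tensor, as CLIP text encoders typically do, len() reports the batch size (always 1 for a single query) rather than the embedding width. A small illustration with a stand-in tensor:

import torch

# One query embedded by CLIP ViT-B/32 would be a (1, 512) tensor.
features = torch.randn(1, 512)
print(len(features))      # 1   -- number of rows, the old, wrong value
print(features.shape[1])  # 512 -- the actual embedding length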
2 changes: 1 addition & 1 deletion comps/retrievers/langchain/vdms/retriever_vdms.py
@@ -89,7 +89,7 @@ def retrieve(input: EmbedDoc) -> SearchedMultimodalDoc:
 # Create vectorstore
 
 if use_clip:
-    embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 4})
+    embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 64})
     dimensions = embeddings.get_embedding_length()
 elif tei_embedding_endpoint:
     embeddings = HuggingFaceEndpointEmbeddings(model=tei_embedding_endpoint, huggingfacehub_api_token=hf_token)
2 changes: 1 addition & 1 deletion comps/retrievers/langchain/vdms/vdms_config.py
@@ -77,4 +77,4 @@ def get_boolean_env_var(var_name, default_value=False):
 # VDMS_SCHEMA = os.getenv("VDMS_SCHEMA", "vdms_schema.yml")
 # INDEX_SCHEMA = os.path.join(parent_dir, VDMS_SCHEMA)
 SEARCH_ENGINE = "FaissFlat"
-DISTANCE_STRATEGY = "L2"
+DISTANCE_STRATEGY = "IP"
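Switching VDMS from L2 to IP (inner product) makes retrieval score by cosine similarity, provided the indexed embeddings are unit-normalized — the usual convention for CLIP features, though normalization is an assumption not shown in this diff. A quick check of the identity:

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(512)
b = rng.standard_normal(512)

# Cosine similarity of the raw vectors.
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# Inner product of the unit-normalized vectors, as an IP index would score them.
a_n = a / np.linalg.norm(a)
b_n = b / np.linalg.norm(b)
inner = a_n @ b_n

assert np.isclose(cosine, inner)  # IP over unit vectors == cosine similarity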
80 changes: 80 additions & 0 deletions (new test script for the VideoRAGQnA gateway)
@@ -0,0 +1,80 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest

from fastapi.responses import StreamingResponse

from comps import (
    ServiceOrchestrator,
    ServiceType,
    TextDoc,
    VideoRAGQnAGateway,
    opea_microservices,
    register_microservice,
)
from comps.cores.proto.docarray import LLMParams


@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add")
async def s1_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]
    text += "opea "
    return {"text": text}


@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add", service_type=ServiceType.LVM)
async def s2_add(request: TextDoc) -> TextDoc:
    req = request.model_dump_json()
    req_dict = json.loads(req)
    text = req_dict["text"]

    def streamer(text):
        yield f"{text}".encode("utf-8")
        for i in range(3):
            yield "project!".encode("utf-8")

    return StreamingResponse(streamer(text), media_type="text/event-stream")


class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
    def setUp(self):
        self.s1 = opea_microservices["s1"]
        self.s2 = opea_microservices["s2"]
        self.s1.start()
        self.s2.start()

        self.service_builder = ServiceOrchestrator()

        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
        self.service_builder.flow_to(self.s1, self.s2)
        self.gateway = VideoRAGQnAGateway(self.service_builder, port=9898)

    def tearDown(self):
        self.s1.stop()
        self.s2.stop()
        self.gateway.stop()

    async def test_schedule(self):
        result_dict, _ = await self.service_builder.schedule(
            initial_inputs={"text": "hello, "}, llm_parameters=LLMParams(streaming=True)
        )
        streaming_response = result_dict[self.s2.name]

        if isinstance(streaming_response, StreamingResponse):
            content = b""
            async for chunk in streaming_response.body_iterator:
                content += chunk
            final_text = content.decode("utf-8")

        print("Streamed content from s2: ", final_text)

        expected_result = "hello, opea project!project!project!"
        self.assertEqual(final_text, expected_result)


if __name__ == "__main__":
    unittest.main()
