diff --git a/comps/__init__.py b/comps/__init__.py
index 85c8456c0..a828cc642 100644
--- a/comps/__init__.py
+++ b/comps/__init__.py
@@ -45,6 +45,7 @@
     AudioQnAGateway,
     RetrievalToolGateway,
     FaqGenGateway,
+    VideoRAGQnAGateway,
     VisualQnAGateway,
     MultimodalRAGWithVideosGateway,
 )
diff --git a/comps/cores/mega/constants.py b/comps/cores/mega/constants.py
index b95a56b08..a23fdaf55 100644
--- a/comps/cores/mega/constants.py
+++ b/comps/cores/mega/constants.py
@@ -38,6 +38,7 @@ class MegaServiceEndpoint(Enum):
     CHAT_QNA = "/v1/chatqna"
     AUDIO_QNA = "/v1/audioqna"
     VISUAL_QNA = "/v1/visualqna"
+    VIDEO_RAG_QNA = "/v1/videoragqna"
     CODE_GEN = "/v1/codegen"
     CODE_TRANS = "/v1/codetrans"
     DOC_SUMMARY = "/v1/docsum"
diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index 97f169c42..2958a245f 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -548,6 +548,55 @@ async def handle_request(self, request: Request):
         return ChatCompletionResponse(model="visualqna", choices=choices, usage=usage)
 
 
+class VideoRAGQnAGateway(Gateway):
+    def __init__(self, megaservice, host="0.0.0.0", port=8888):
+        super().__init__(
+            megaservice,
+            host,
+            port,
+            str(MegaServiceEndpoint.VIDEO_RAG_QNA),
+            ChatCompletionRequest,
+            ChatCompletionResponse,
+        )
+
+    async def handle_request(self, request: Request):
+        data = await request.json()
+        stream_opt = data.get("stream", False)
+        chat_request = ChatCompletionRequest.parse_obj(data)
+        prompt = self._handle_message(chat_request.messages)
+        parameters = LLMParams(
+            max_new_tokens=chat_request.max_tokens if chat_request.max_tokens else 1024,
+            top_k=chat_request.top_k if chat_request.top_k else 10,
+            top_p=chat_request.top_p if chat_request.top_p else 0.95,
+            temperature=chat_request.temperature if chat_request.temperature else 0.01,
+            repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
+            streaming=stream_opt,
+        )
+        result_dict, runtime_graph = await self.megaservice.schedule(
+            initial_inputs={"text": prompt}, llm_parameters=parameters
+        )
+        for node, response in result_dict.items():
+            # Here it is assumed that the last microservice in the megaservice is the LVM.
+            if (
+                isinstance(response, StreamingResponse)
+                and node == list(self.megaservice.services.keys())[-1]
+                and self.megaservice.services[node].service_type == ServiceType.LVM
+            ):
+                return response
+        last_node = runtime_graph.all_leaves()[-1]
+        response = result_dict[last_node]["text"]
+        choices = []
+        usage = UsageInfo()
+        choices.append(
+            ChatCompletionResponseChoice(
+                index=0,
+                message=ChatMessage(role="assistant", content=response),
+                finish_reason="stop",
+            )
+        )
+        return ChatCompletionResponse(model="videoragqna", choices=choices, usage=usage)
+
+
 class RetrievalToolGateway(Gateway):
     """embed+retrieve+rerank."""
 
diff --git a/comps/embeddings/multimodal_clip/embeddings_clip.py b/comps/embeddings/multimodal_clip/embeddings_clip.py
index 39db85b6e..f010245dd 100644
--- a/comps/embeddings/multimodal_clip/embeddings_clip.py
+++ b/comps/embeddings/multimodal_clip/embeddings_clip.py
@@ -27,7 +27,8 @@ def embed_query(self, texts):
         return text_features
 
     def get_embedding_length(self):
-        return len(self.embed_query("sample_text"))
+        text_features = self.embed_query("sample_text")
+        return text_features.shape[1]
 
     def get_image_embeddings(self, images):
         """Input is list of images."""
diff --git a/comps/retrievers/langchain/vdms/retriever_vdms.py b/comps/retrievers/langchain/vdms/retriever_vdms.py
index 8dae6f8f7..5eaa29ad6 100644
--- a/comps/retrievers/langchain/vdms/retriever_vdms.py
+++ b/comps/retrievers/langchain/vdms/retriever_vdms.py
@@ -89,7 +89,7 @@ def retrieve(input: EmbedDoc) -> SearchedMultimodalDoc:
 
 # Create vectorstore
 if use_clip:
-    embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 4})
+    embeddings = vCLIP({"model_name": "openai/clip-vit-base-patch32", "num_frm": 64})
     dimensions = embeddings.get_embedding_length()
 elif tei_embedding_endpoint:
     embeddings = HuggingFaceEndpointEmbeddings(model=tei_embedding_endpoint, huggingfacehub_api_token=hf_token)
diff --git a/comps/retrievers/langchain/vdms/vdms_config.py b/comps/retrievers/langchain/vdms/vdms_config.py
index d388add9a..5b6a85213 100644
--- a/comps/retrievers/langchain/vdms/vdms_config.py
+++ b/comps/retrievers/langchain/vdms/vdms_config.py
@@ -77,4 +77,4 @@ def get_boolean_env_var(var_name, default_value=False):
 # VDMS_SCHEMA = os.getenv("VDMS_SCHEMA", "vdms_schema.yml")
 # INDEX_SCHEMA = os.path.join(parent_dir, VDMS_SCHEMA)
 SEARCH_ENGINE = "FaissFlat"
-DISTANCE_STRATEGY = "L2"
+DISTANCE_STRATEGY = "IP"
diff --git a/tests/cores/mega/test_service_orchestrator_with_videoragqnagateway.py b/tests/cores/mega/test_service_orchestrator_with_videoragqnagateway.py
new file mode 100644
index 000000000..a9bdcdb33
--- /dev/null
+++ b/tests/cores/mega/test_service_orchestrator_with_videoragqnagateway.py
@@ -0,0 +1,80 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import unittest
+
+from fastapi.responses import StreamingResponse
+
+from comps import (
+    ServiceOrchestrator,
+    ServiceType,
+    TextDoc,
+    VideoRAGQnAGateway,
+    opea_microservices,
+    register_microservice,
+)
+from comps.cores.proto.docarray import LLMParams
+
+
+@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add")
+async def s1_add(request: TextDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += "opea "
+    return {"text": text}
+
+
+@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add", service_type=ServiceType.LVM)
+async def s2_add(request: TextDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+
+    def streamer(text):
+        yield f"{text}".encode("utf-8")
+        for i in range(3):
+            yield "project!".encode("utf-8")
+
+    return StreamingResponse(streamer(text), media_type="text/event-stream")
+
+
+class TestServiceOrchestrator(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.s1 = opea_microservices["s1"]
+        self.s2 = opea_microservices["s2"]
+        self.s1.start()
+        self.s2.start()
+
+        self.service_builder = ServiceOrchestrator()
+
+        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
+        self.service_builder.flow_to(self.s1, self.s2)
+        self.gateway = VideoRAGQnAGateway(self.service_builder, port=9898)
+
+    def tearDown(self):
+        self.s1.stop()
+        self.s2.stop()
+        self.gateway.stop()
+
+    async def test_schedule(self):
+        result_dict, _ = await self.service_builder.schedule(
+            initial_inputs={"text": "hello, "}, llm_parameters=LLMParams(streaming=True)
+        )
+        streaming_response = result_dict[self.s2.name]
+
+        if isinstance(streaming_response, StreamingResponse):
+            content = b""
+            async for chunk in streaming_response.body_iterator:
+                content += chunk
+            final_text = content.decode("utf-8")
+
+        print("Streamed content from s2: ", final_text)
+
+        expected_result = "hello, opea project!project!project!"
+        self.assertEqual(final_text, expected_result)
+
+
+if __name__ == "__main__":
+    unittest.main()
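
Usage note (not part of the patch): a minimal sketch of how the new gateway could be wired into a VideoRAGQnA megaservice. The ServiceOrchestrator add/flow_to calls and the VideoRAGQnAGateway(megaservice, host, port) constructor come from this diff; the MicroService names, hosts, ports, endpoints, and the embedding -> retriever -> reranker -> LVM topology are illustrative assumptions.

# Sketch only: remote-service hosts, ports, and endpoints below are assumptions.
from comps import MicroService, ServiceOrchestrator, ServiceType, VideoRAGQnAGateway


class VideoRAGQnAService:
    def __init__(self, host="0.0.0.0", port=8888):
        self.host = host
        self.port = port
        self.megaservice = ServiceOrchestrator()

    def add_remote_service(self):
        # Assumed deployment: embedding -> retriever -> reranker -> LVM, each running remotely.
        embedding = MicroService(
            name="embedding",
            host="127.0.0.1",  # assumed host
            port=6000,  # assumed port
            endpoint="/v1/embeddings",
            use_remote_service=True,
            service_type=ServiceType.EMBEDDING,
        )
        retriever = MicroService(
            name="retriever",
            host="127.0.0.1",
            port=7000,
            endpoint="/v1/retrieval",
            use_remote_service=True,
            service_type=ServiceType.RETRIEVER,
        )
        reranker = MicroService(
            name="reranker",
            host="127.0.0.1",
            port=8000,
            endpoint="/v1/reranking",
            use_remote_service=True,
            service_type=ServiceType.RERANK,
        )
        lvm = MicroService(
            name="lvm",
            host="127.0.0.1",
            port=9399,
            endpoint="/v1/lvm",
            use_remote_service=True,
            # LVM must be the last node so that streaming responses pass through handle_request.
            service_type=ServiceType.LVM,
        )
        self.megaservice.add(embedding).add(retriever).add(reranker).add(lvm)
        self.megaservice.flow_to(embedding, retriever)
        self.megaservice.flow_to(retriever, reranker)
        self.megaservice.flow_to(reranker, lvm)
        # Serves POST /v1/videoragqna and forwards ChatCompletionRequest payloads through the graph.
        self.gateway = VideoRAGQnAGateway(megaservice=self.megaservice, host=self.host, port=self.port)


if __name__ == "__main__":
    service = VideoRAGQnAService(port=8888)
    service.add_remote_service()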