From 0bb69ac17601aa4811caeaa0050b367191672a98 Mon Sep 17 00:00:00 2001
From: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
Date: Wed, 4 Sep 2024 19:49:49 +0800
Subject: [PATCH] Optimize mega flow by removing microservice wrapper (#582)

* refactor orchestrator

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove no_wrapper

* fix

* fix

* add align_gen

* add retriever and rerank params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add fake test for customize params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix dep

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 comps/cores/mega/gateway.py                   | 18 +++-
 comps/cores/mega/orchestrator.py              | 69 ++++++++++++---
 comps/cores/proto/docarray.py                 | 13 +++
 ...orchestrator_with_retriever_rerank_fake.py | 83 +++++++++++++++++++
 4 files changed, 168 insertions(+), 15 deletions(-)
 create mode 100644 tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py

diff --git a/comps/cores/mega/gateway.py b/comps/cores/mega/gateway.py
index cc8eaf5d2..5191d1d9a 100644
--- a/comps/cores/mega/gateway.py
+++ b/comps/cores/mega/gateway.py
@@ -20,7 +20,7 @@
     EmbeddingRequest,
     UsageInfo,
 )
-from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc
+from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
 from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType
 from .micro_service import MicroService

@@ -167,8 +167,22 @@ async def handle_request(self, request: Request):
             streaming=stream_opt,
             chat_template=chat_request.chat_template if chat_request.chat_template else None,
         )
+        retriever_parameters = RetrieverParms(
+            search_type=chat_request.search_type if chat_request.search_type else "similarity",
+            k=chat_request.k if chat_request.k else 4,
+            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+        )
+        reranker_parameters = RerankerParms(
+            top_n=chat_request.top_n if chat_request.top_n else 1,
+        )
         result_dict, runtime_graph = await self.megaservice.schedule(
-            initial_inputs={"text": prompt}, llm_parameters=parameters
+            initial_inputs={"text": prompt},
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+            reranker_parameters=reranker_parameters,
         )
         for node, response in result_dict.items():
             if isinstance(response, StreamingResponse):
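The gateway hunk above only harvests the retrieval and rerank knobs from the incoming chat request and hands them to schedule() as keyword arguments. A minimal sketch of driving the extended schedule() directly, assuming a megaservice has already been wired up (the parameter values are illustrative, not taken from this patch):

from comps import ServiceOrchestrator
from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms

async def answer(megaservice: ServiceOrchestrator, prompt: str):
    # The extra kwargs flow through execute() into the align_* hooks,
    # which decide how each node consumes them.
    result_dict, _ = await megaservice.schedule(
        initial_inputs={"text": prompt},
        llm_parameters=LLMParams(streaming=False),
        retriever_parameters=RetrieverParms(search_type="mmr", k=6),
        reranker_parameters=RerankerParms(top_n=2),
    )
    return result_dict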
diff --git a/comps/cores/mega/orchestrator.py b/comps/cores/mega/orchestrator.py
index 92063d498..2410140b8 100644
--- a/comps/cores/mega/orchestrator.py
+++ b/comps/cores/mega/orchestrator.py
@@ -4,6 +4,7 @@
 import asyncio
 import copy
 import json
+import os
 import re
 from typing import Dict, List

@@ -14,6 +15,10 @@
 from ..proto.docarray import LLMParams
 from .constants import ServiceType
 from .dag import DAG
+from .logger import CustomLogger
+
+logger = CustomLogger("comps-core-orchestrator")
+LOGFLAG = os.getenv("LOGFLAG", False)


 class ServiceOrchestrator(DAG):
@@ -36,18 +41,22 @@ def flow_to(self, from_service, to_service):
             self.add_edge(from_service.name, to_service.name)
             return True
         except Exception as e:
-            print(e)
+            logger.error(e)
             return False

-    async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
+    async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams(), **kwargs):
         result_dict = {}
         runtime_graph = DAG()
         runtime_graph.graph = copy.deepcopy(self.graph)
+        if LOGFLAG:
+            logger.info(initial_inputs)
         timeout = aiohttp.ClientTimeout(total=1000)
         async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session:
             pending = {
-                asyncio.create_task(self.execute(session, node, initial_inputs, runtime_graph, llm_parameters))
+                asyncio.create_task(
+                    self.execute(session, node, initial_inputs, runtime_graph, llm_parameters, **kwargs)
+                )
                 for node in self.ind_nodes()
             }
             ind_nodes = self.ind_nodes()
@@ -67,11 +76,12 @@ async def schedule(self, initial_inputs: Dict, llm_parameters: LLMPa
                 for downstream in reversed(downstreams):
                     try:
                         if re.findall(black_node, downstream):
-                            print(f"skip forwardding to {downstream}...")
+                            if LOGFLAG:
+                                logger.info(f"skip forwarding to {downstream}...")
                             runtime_graph.delete_edge(node, downstream)
                             downstreams.remove(downstream)
                     except re.error as e:
-                        print("Pattern invalid! Operation cancelled.")
+                        logger.error("Pattern invalid! Operation cancelled.")
                 if len(downstreams) == 0 and llm_parameters.streaming:
                     # turn the response to a StreamingResponse
                     # to make the response uniform to UI
@@ -90,7 +100,7 @@ def fake_stream(text):
                     inputs = self.process_outputs(runtime_graph.predecessors(d_node), result_dict)
                     pending.add(
                         asyncio.create_task(
-                            self.execute(session, d_node, inputs, runtime_graph, llm_parameters)
+                            self.execute(session, d_node, inputs, runtime_graph, llm_parameters, **kwargs)
                         )
                     )
             nodes_to_keep = []
@@ -121,21 +131,33 @@ async def execute(
         inputs: Dict,
         runtime_graph: DAG,
         llm_parameters: LLMParams = LLMParams(),
+        **kwargs,
     ):
         # send the cur_node request/reply
         endpoint = self.services[cur_node].endpoint_path
         llm_parameters_dict = llm_parameters.dict()
-        for field, value in llm_parameters_dict.items():
-            if inputs.get(field) != value:
-                inputs[field] = value
+        if self.services[cur_node].service_type == ServiceType.LLM:
+            for field, value in llm_parameters_dict.items():
+                if inputs.get(field) != value:
+                    inputs[field] = value
+
+        # pre-process
+        inputs = self.align_inputs(inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs)

         if (
             self.services[cur_node].service_type == ServiceType.LLM
             or self.services[cur_node].service_type == ServiceType.LVM
         ) and llm_parameters.streaming:
             # Still leave to sync requests.post for StreamingResponse
+            if LOGFLAG:
+                logger.info(inputs)
             response = requests.post(
-                url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
+                url=endpoint,
+                data=json.dumps(inputs),
+                headers={"Content-type": "application/json"},
+                proxies={"http": None},
+                stream=True,
+                timeout=1000,
             )
             downstream = runtime_graph.downstream(cur_node)
             if downstream:
@@ -169,11 +191,32 @@ def generate():
                     else:
                         yield chunk

-            return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
+            return (
+                StreamingResponse(self.align_generator(generate(), **kwargs), media_type="text/event-stream"),
+                cur_node,
+            )
         else:
+            if LOGFLAG:
+                logger.info(inputs)
             async with session.post(endpoint, json=inputs) as response:
-                print(f"{cur_node}: {response.status}")
-                return await response.json(), cur_node
+                # Parse as JSON
+                data = await response.json()
+                # post-process
+                data = self.align_outputs(data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs)
+
+                return data, cur_node
+
+    def align_inputs(self, inputs, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return inputs
+
+    def align_outputs(self, data, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return data
+
+    def align_generator(self, gen, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return gen

     def dump_outputs(self, node, response, result_dict):
         result_dict[node] = response
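The three align_* methods added above are deliberate no-ops; a megaservice overrides them to reshape payloads per node. The test below monkey-patches them onto ServiceOrchestrator, but subclassing works equally well. A sketch under that assumption, with RAGOrchestrator as a hypothetical name and the kwarg names matching those forwarded from schedule():

from comps import ServiceOrchestrator
from comps.cores.mega.constants import ServiceType

class RAGOrchestrator(ServiceOrchestrator):  # hypothetical subclass
    def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
        # pre-process: inject the requested k before the retriever is called
        if self.services[cur_node].service_type == ServiceType.RETRIEVER:
            inputs["k"] = kwargs["retriever_parameters"].k
        return inputs

    def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
        # post-process: keep only the reranker's top_n results
        if self.services[cur_node].service_type == ServiceType.RERANK:
            data["text"] = data["text"][: kwargs["reranker_parameters"].top_n]
        return data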
diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py
index 587278ec1..132b172bc 100644
--- a/comps/cores/proto/docarray.py
+++ b/comps/cores/proto/docarray.py
@@ -173,6 +173,19 @@ class LLMParams(BaseDoc):
     )


+class RetrieverParms(BaseDoc):
+    search_type: str = "similarity"
+    k: int = 4
+    distance_threshold: Optional[float] = None
+    fetch_k: int = 20
+    lambda_mult: float = 0.5
+    score_threshold: float = 0.2
+
+
+class RerankerParms(BaseDoc):
+    top_n: int = 1
+
+
 class RAGASParams(BaseDoc):
     questions: DocList[TextDoc]
     answers: DocList[TextDoc]
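RetrieverParms and RerankerParms follow the same BaseDoc pattern as LLMParams, so omitted fields fall back to the defaults declared above; for example:

from comps.cores.proto.docarray import RerankerParms, RetrieverParms

retriever_parms = RetrieverParms()       # search_type="similarity", k=4, fetch_k=20, ...
reranker_parms = RerankerParms(top_n=3)  # override only the rerank depth
assert retriever_parms.lambda_mult == 0.5 and reranker_parms.top_n == 3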
diff --git a/tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py b/tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py
new file mode 100644
index 000000000..8ee7116a7
--- /dev/null
+++ b/tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py
@@ -0,0 +1,83 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import unittest
+
+from comps import (
+    EmbedDoc,
+    Gateway,
+    RerankedDoc,
+    ServiceOrchestrator,
+    TextDoc,
+    opea_microservices,
+    register_microservice,
+)
+from comps.cores.mega.constants import ServiceType
+from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
+
+
+@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add", service_type=ServiceType.RETRIEVER)
+async def s1_add(request: EmbedDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += f"opea top_k {req_dict['k']}"
+    return {"text": text}
+
+
+@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add", service_type=ServiceType.RERANK)
+async def s2_add(request: TextDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += "project!"
+    return {"text": text}
+
+
+def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
+    if self.services[cur_node].service_type == ServiceType.RETRIEVER:
+        inputs["k"] = kwargs["retriever_parameters"].k
+
+    return inputs
+
+
+def align_outputs(self, outputs, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
+    if self.services[cur_node].service_type == ServiceType.RERANK:
+        top_n = kwargs["reranker_parameters"].top_n
+        outputs["text"] = outputs["text"][:top_n]
+    return outputs
+
+
+class TestServiceOrchestratorParams(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.s1 = opea_microservices["s1"]
+        self.s2 = opea_microservices["s2"]
+        self.s1.start()
+        self.s2.start()
+
+        ServiceOrchestrator.align_inputs = align_inputs
+        ServiceOrchestrator.align_outputs = align_outputs
+        self.service_builder = ServiceOrchestrator()
+
+        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
+        self.service_builder.flow_to(self.s1, self.s2)
+        self.gateway = Gateway(self.service_builder, port=9898)
+
+    def tearDown(self):
+        self.s1.stop()
+        self.s2.stop()
+        self.gateway.stop()
+
+    async def test_retriever_schedule(self):
+        result_dict, _ = await self.service_builder.schedule(
+            initial_inputs={"text": "hello, ", "embedding": [1.0, 2.0, 3.0]},
+            retriever_parameters=RetrieverParms(k=8),
+            reranker_parameters=RerankerParms(top_n=20),
+        )
+        self.assertEqual(len(result_dict[self.s2.name]["text"]), 20)  # Check reranker top_n is accessed
+        self.assertTrue("8" in result_dict[self.s2.name]["text"])  # Check retriever k is accessed
+
+
+if __name__ == "__main__":
+    unittest.main()
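The fake test exercises align_inputs and align_outputs but not align_generator, which wraps the generator on the streaming branch of execute(). A hypothetical override (the chunk types are an assumption; the generator yields whatever the upstream service streams):

def align_generator(self, gen, *args, **kwargs):
    # Hypothetical post-processing: normalize every streamed chunk to str
    # before it reaches the StreamingResponse.
    for chunk in gen:
        yield chunk.decode("utf-8") if isinstance(chunk, bytes) else chunk

The test itself runs standalone: python tests/cores/mega/test_service_orchestrator_with_retriever_rerank_fake.py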