Optimize mega flow by removing microservice wrapper (#582)
* refactor orchestrator

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove no_wrapper

* fix

* fix

* add align_gen

* add retriever and rerank params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add fake test for customize params

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix dep

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Spycsh and pre-commit-ci[bot] committed Sep 4, 2024
1 parent 3367b76 commit 0bb69ac
Showing 4 changed files with 168 additions and 15 deletions.
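In short, this commit threads two new parameter docs (RetrieverParms, RerankerParms) from the gateway through ServiceOrchestrator.schedule() via **kwargs, swaps print() for a CustomLogger gated by a LOGFLAG environment variable, and adds three overridable align_* hooks. A minimal sketch of the logging switch, assuming any non-empty value counts as enabled (the code truth-tests os.getenv("LOGFLAG", False)):

# Sketch: turn on the orchestrator's new request/response logging.
# LOGFLAG is read once at import time, so set it before importing comps.
import os

os.environ["LOGFLAG"] = "true"  # any non-empty string is truthy

from comps import ServiceOrchestrator  # import after setting the env var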
18 changes: 16 additions & 2 deletions comps/cores/mega/gateway.py
@@ -20,7 +20,7 @@
     EmbeddingRequest,
     UsageInfo,
 )
-from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, TextDoc
+from ..proto.docarray import LLMParams, LLMParamsDoc, RerankedDoc, RerankerParms, RetrieverParms, TextDoc
 from .constants import MegaServiceEndpoint, ServiceRoleType, ServiceType
 from .micro_service import MicroService

@@ -167,8 +167,22 @@ async def handle_request(self, request: Request):
             streaming=stream_opt,
             chat_template=chat_request.chat_template if chat_request.chat_template else None,
         )
+        retriever_parameters = RetrieverParms(
+            search_type=chat_request.search_type if chat_request.search_type else "similarity",
+            k=chat_request.k if chat_request.k else 4,
+            distance_threshold=chat_request.distance_threshold if chat_request.distance_threshold else None,
+            fetch_k=chat_request.fetch_k if chat_request.fetch_k else 20,
+            lambda_mult=chat_request.lambda_mult if chat_request.lambda_mult else 0.5,
+            score_threshold=chat_request.score_threshold if chat_request.score_threshold else 0.2,
+        )
+        reranker_parameters = RerankerParms(
+            top_n=chat_request.top_n if chat_request.top_n else 1,
+        )
         result_dict, runtime_graph = await self.megaservice.schedule(
-            initial_inputs={"text": prompt}, llm_parameters=parameters
+            initial_inputs={"text": prompt},
+            llm_parameters=parameters,
+            retriever_parameters=retriever_parameters,
+            reranker_parameters=reranker_parameters,
         )
         for node, response in result_dict.items():
             if isinstance(response, StreamingResponse):
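The gateway now lifts the retriever and reranker knobs straight off the incoming chat request, falling back to the defaults shown above. A hypothetical client call against a ChatQnA-style megaservice; the host, port, and /v1/chatqna endpoint are illustrative assumptions, not taken from this diff, but the field names match what handle_request() reads off chat_request:

# Hypothetical request exercising the new per-request knobs.
import requests

payload = {
    "messages": "What is OPEA?",
    # retriever fields -> RetrieverParms
    "search_type": "similarity",
    "k": 8,
    "fetch_k": 20,
    "lambda_mult": 0.5,
    "score_threshold": 0.2,
    # reranker field -> RerankerParms
    "top_n": 3,
}
resp = requests.post("http://localhost:8888/v1/chatqna", json=payload, timeout=120)
print(resp.status_code)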
69 changes: 56 additions & 13 deletions comps/cores/mega/orchestrator.py
@@ -4,6 +4,7 @@
 import asyncio
 import copy
 import json
+import os
 import re
 from typing import Dict, List

@@ -14,6 +15,10 @@
 from ..proto.docarray import LLMParams
 from .constants import ServiceType
 from .dag import DAG
+from .logger import CustomLogger
+
+logger = CustomLogger("comps-core-orchestrator")
+LOGFLAG = os.getenv("LOGFLAG", False)


 class ServiceOrchestrator(DAG):
Expand All @@ -36,18 +41,22 @@ def flow_to(self, from_service, to_service):
self.add_edge(from_service.name, to_service.name)
return True
except Exception as e:
print(e)
logger.error(e)
return False

async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams(), **kwargs):
result_dict = {}
runtime_graph = DAG()
runtime_graph.graph = copy.deepcopy(self.graph)
if LOGFLAG:
logger.info(initial_inputs)

timeout = aiohttp.ClientTimeout(total=1000)
async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session:
pending = {
asyncio.create_task(self.execute(session, node, initial_inputs, runtime_graph, llm_parameters))
asyncio.create_task(
self.execute(session, node, initial_inputs, runtime_graph, llm_parameters, **kwargs)
)
for node in self.ind_nodes()
}
ind_nodes = self.ind_nodes()
@@ -67,11 +76,12 @@ async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMPa
                     for downstream in reversed(downstreams):
                         try:
                             if re.findall(black_node, downstream):
-                                print(f"skip forwarding to {downstream}...")
+                                if LOGFLAG:
+                                    logger.info(f"skip forwarding to {downstream}...")
                                 runtime_graph.delete_edge(node, downstream)
                                 downstreams.remove(downstream)
                         except re.error as e:
-                            print("Pattern invalid! Operation cancelled.")
+                            logger.error("Pattern invalid! Operation cancelled.")
                 if len(downstreams) == 0 and llm_parameters.streaming:
                     # turn the response to a StreamingResponse
                     # to make the response uniform to UI
@@ -90,7 +100,7 @@ def fake_stream(text):
                     inputs = self.process_outputs(runtime_graph.predecessors(d_node), result_dict)
                     pending.add(
                         asyncio.create_task(
-                            self.execute(session, d_node, inputs, runtime_graph, llm_parameters)
+                            self.execute(session, d_node, inputs, runtime_graph, llm_parameters, **kwargs)
                         )
                     )
                 nodes_to_keep = []
@@ -121,21 +131,33 @@ async def execute(
         inputs: Dict,
         runtime_graph: DAG,
         llm_parameters: LLMParams = LLMParams(),
+        **kwargs,
     ):
         # send the cur_node request/reply
         endpoint = self.services[cur_node].endpoint_path
         llm_parameters_dict = llm_parameters.dict()
-        for field, value in llm_parameters_dict.items():
-            if inputs.get(field) != value:
-                inputs[field] = value
+        if self.services[cur_node].service_type == ServiceType.LLM:
+            for field, value in llm_parameters_dict.items():
+                if inputs.get(field) != value:
+                    inputs[field] = value
+
+        # pre-process
+        inputs = self.align_inputs(inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs)
+
         if (
             self.services[cur_node].service_type == ServiceType.LLM
             or self.services[cur_node].service_type == ServiceType.LVM
         ) and llm_parameters.streaming:
             # Still leave to sync requests.post for StreamingResponse
+            if LOGFLAG:
+                logger.info(inputs)
             response = requests.post(
-                url=endpoint, data=json.dumps(inputs), proxies={"http": None}, stream=True, timeout=1000
+                url=endpoint,
+                data=json.dumps(inputs),
+                headers={"Content-type": "application/json"},
+                proxies={"http": None},
+                stream=True,
+                timeout=1000,
             )
             downstream = runtime_graph.downstream(cur_node)
             if downstream:
@@ -169,11 +191,32 @@ def generate():
                         else:
                             yield chunk

-            return StreamingResponse(generate(), media_type="text/event-stream"), cur_node
+            return (
+                StreamingResponse(self.align_generator(generate(), **kwargs), media_type="text/event-stream"),
+                cur_node,
+            )
         else:
+            if LOGFLAG:
+                logger.info(inputs)
             async with session.post(endpoint, json=inputs) as response:
-                print(f"{cur_node}: {response.status}")
-                return await response.json(), cur_node
+                # Parse as JSON
+                data = await response.json()
+                # post process
+                data = self.align_outputs(data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs)
+
+                return data, cur_node
+
+    def align_inputs(self, inputs, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return inputs
+
+    def align_outputs(self, data, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return data
+
+    def align_generator(self, gen, *args, **kwargs):
+        """Override this method in megaservice definition."""
+        return gen

     def dump_outputs(self, node, response, result_dict):
         result_dict[node] = response
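The three align_* methods are deliberate no-ops: schedule() forwards its **kwargs into execute(), which calls align_inputs() before each POST, align_outputs() on every non-streaming JSON reply, and align_generator() around the streaming generator. The test below patches them onto ServiceOrchestrator directly; subclassing works just as well. A minimal sketch, assuming the gateway passes retriever_parameters and reranker_parameters as shown earlier:

# Sketch: a megaservice customizing the new hooks by subclassing.
from comps import ServiceOrchestrator
from comps.cores.mega.constants import ServiceType


class MyOrchestrator(ServiceOrchestrator):
    def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
        # Inject the per-request k before the retriever node is called.
        if self.services[cur_node].service_type == ServiceType.RETRIEVER:
            inputs["k"] = kwargs["retriever_parameters"].k
        return inputs

    def align_outputs(self, data, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
        # Truncate the reranker's output to the requested top_n.
        if self.services[cur_node].service_type == ServiceType.RERANK:
            data["text"] = data["text"][: kwargs["reranker_parameters"].top_n]
        return data

    def align_generator(self, gen, **kwargs):
        # Pass streamed chunks through unchanged; a real override could
        # reshape each SSE line here.
        yield from gen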
13 changes: 13 additions & 0 deletions comps/cores/proto/docarray.py
@@ -173,6 +173,19 @@ class LLMParams(BaseDoc):
     )


+class RetrieverParms(BaseDoc):
+    search_type: str = "similarity"
+    k: int = 4
+    distance_threshold: Optional[float] = None
+    fetch_k: int = 20
+    lambda_mult: float = 0.5
+    score_threshold: float = 0.2
+
+
+class RerankerParms(BaseDoc):
+    top_n: int = 1
+
+
 class RAGASParams(BaseDoc):
     questions: DocList[TextDoc]
     answers: DocList[TextDoc]
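With the two docs defined, callers pass them to schedule() under the keyword names retriever_parameters and reranker_parameters, which is exactly how the hooks receive them through **kwargs. A short sketch, assuming an already-built orchestrator and an async context:

# Sketch: thread the new parameter docs through schedule().
from comps.cores.proto.docarray import RerankerParms, RetrieverParms


async def run(orchestrator):
    result_dict, runtime_graph = await orchestrator.schedule(
        initial_inputs={"text": "hello"},
        retriever_parameters=RetrieverParms(search_type="mmr", k=8, fetch_k=40),
        reranker_parameters=RerankerParms(top_n=3),
    )
    return result_dict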
@@ -0,0 +1,83 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import unittest
+
+from comps import (
+    EmbedDoc,
+    Gateway,
+    RerankedDoc,
+    ServiceOrchestrator,
+    TextDoc,
+    opea_microservices,
+    register_microservice,
+)
+from comps.cores.mega.constants import ServiceType
+from comps.cores.proto.docarray import LLMParams, RerankerParms, RetrieverParms
+
+
+@register_microservice(name="s1", host="0.0.0.0", port=8083, endpoint="/v1/add", service_type=ServiceType.RETRIEVER)
+async def s1_add(request: EmbedDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += f"opea top_k {req_dict['k']}"
+    return {"text": text}
+
+
+@register_microservice(name="s2", host="0.0.0.0", port=8084, endpoint="/v1/add", service_type=ServiceType.RERANK)
+async def s2_add(request: TextDoc) -> TextDoc:
+    req = request.model_dump_json()
+    req_dict = json.loads(req)
+    text = req_dict["text"]
+    text += "project!"
+    return {"text": text}
+
+
+def align_inputs(self, inputs, cur_node, runtime_graph, llm_parameters_dict, **kwargs):
+    if self.services[cur_node].service_type == ServiceType.RETRIEVER:
+        inputs["k"] = kwargs["retriever_parameters"].k
+
+    return inputs
+
+
+def align_outputs(self, outputs, cur_node, inputs, runtime_graph, llm_parameters_dict, **kwargs):
+    if self.services[cur_node].service_type == ServiceType.RERANK:
+        top_n = kwargs["reranker_parameters"].top_n
+        outputs["text"] = outputs["text"][:top_n]
+    return outputs
+
+
+class TestServiceOrchestratorParams(unittest.IsolatedAsyncioTestCase):
+    def setUp(self):
+        self.s1 = opea_microservices["s1"]
+        self.s2 = opea_microservices["s2"]
+        self.s1.start()
+        self.s2.start()
+
+        ServiceOrchestrator.align_inputs = align_inputs
+        ServiceOrchestrator.align_outputs = align_outputs
+        self.service_builder = ServiceOrchestrator()
+
+        self.service_builder.add(opea_microservices["s1"]).add(opea_microservices["s2"])
+        self.service_builder.flow_to(self.s1, self.s2)
+        self.gateway = Gateway(self.service_builder, port=9898)
+
+    def tearDown(self):
+        self.s1.stop()
+        self.s2.stop()
+        self.gateway.stop()
+
+    async def test_retriever_schedule(self):
+        result_dict, _ = await self.service_builder.schedule(
+            initial_inputs={"text": "hello, ", "embedding": [1.0, 2.0, 3.0]},
+            retriever_parameters=RetrieverParms(k=8),
+            reranker_parameters=RerankerParms(top_n=20),
+        )
+        self.assertEqual(len(result_dict[self.s2.name]["text"]), 20)  # Check reranker top_n is accessed
+        self.assertTrue("8" in result_dict[self.s2.name]["text"])  # Check retriever k is accessed
+
+
+if __name__ == "__main__":
+    unittest.main()