Add dynamic DAG (#317)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Spycsh and pre-commit-ci[bot] authored Jul 17, 2024
1 parent 876ca50 commit f2995ab
Showing 10 changed files with 181 additions and 36 deletions.
36 changes: 22 additions & 14 deletions comps/cores/mega/gateway.py
@@ -119,7 +119,9 @@ async def handle_request(self, request: Request):
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
)
result_dict = await self.megaservice.schedule(initial_inputs={"text": prompt}, llm_parameters=parameters)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
@@ -128,7 +130,7 @@ async def handle_request(self, request: Request):
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = self.megaservice.all_leaves()[-1]
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
@@ -161,7 +163,9 @@ async def handle_request(self, request: Request):
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
)
result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"query": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
@@ -170,7 +174,7 @@ async def handle_request(self, request: Request):
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = self.megaservice.all_leaves()[-1]
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
@@ -208,7 +212,7 @@ async def handle_request(self, request: Request):
### Translated codes:
"""
prompt = prompt_template.format(language_from=language_from, language_to=language_to, source_code=source_code)
result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt})
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt})
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
@@ -217,7 +221,7 @@ async def handle_request(self, request: Request):
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = self.megaservice.all_leaves()[-1]
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
@@ -253,7 +257,7 @@ async def handle_request(self, request: Request):
prompt = prompt_template.format(
language_from=language_from, language_to=language_to, source_language=source_language
)
result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt})
result_dict, runtime_graph = await self.megaservice.schedule(initial_inputs={"query": prompt})
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
@@ -262,7 +266,7 @@ async def handle_request(self, request: Request):
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = self.megaservice.all_leaves()[-1]
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
@@ -295,7 +299,9 @@ async def handle_request(self, request: Request):
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
)
result_dict = await self.megaservice.schedule(initial_inputs={"query": prompt}, llm_parameters=parameters)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"query": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
@@ -304,7 +310,7 @@ async def handle_request(self, request: Request):
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = self.megaservice.all_leaves()[-1]
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
@@ -342,11 +348,11 @@ async def handle_request(self, request: Request):
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=False, # TODO add streaming LLM output as input to TTS
)
result_dict = await self.megaservice.schedule(
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"byte_str": chat_request.audio}, llm_parameters=parameters
)

last_node = self.megaservice.all_leaves()[-1]
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["byte_str"]

return response
@@ -371,7 +377,9 @@ async def handle_request(self, request: Request):
repetition_penalty=chat_request.presence_penalty if chat_request.presence_penalty else 1.03,
streaming=stream_opt,
)
result_dict = await self.megaservice.schedule(initial_inputs={"text": prompt}, llm_parameters=parameters)
result_dict, runtime_graph = await self.megaservice.schedule(
initial_inputs={"text": prompt}, llm_parameters=parameters
)
for node, response in result_dict.items():
# Here it suppose the last microservice in the megaservice is LLM.
if (
@@ -380,7 +388,7 @@ async def handle_request(self, request: Request):
and self.megaservice.services[node].service_type == ServiceType.LLM
):
return response
last_node = self.megaservice.all_leaves()[-1]
last_node = runtime_graph.all_leaves()[-1]
response = result_dict[last_node]["text"]
choices = []
usage = UsageInfo()
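
Note on the gateway changes above: schedule() now returns a (result_dict, runtime_graph) tuple, and the final node must be read from the pruned runtime graph rather than from the static megaservice DAG. A minimal sketch of the new caller-side pattern (not part of the commit; the stub classes and node name below are hypothetical stand-ins for the real comps orchestrator):

import asyncio


class StubRuntimeGraph:
    # Stands in for the pruned DAG copy that schedule() now returns.
    def all_leaves(self):
        return ["llm/MicroService"]  # hypothetical leaf node name


async def stub_schedule(initial_inputs, llm_parameters=None):
    # The real ServiceOrchestrator.schedule() now returns a tuple, not a bare dict.
    result_dict = {"llm/MicroService": {"text": "generated answer"}}
    return result_dict, StubRuntimeGraph()


async def main():
    # New pattern: unpack two values and take leaves from the runtime graph,
    # since nodes may have been pruned via downstream_black_list at runtime.
    result_dict, runtime_graph = await stub_schedule({"text": "prompt"})
    last_node = runtime_graph.all_leaves()[-1]
    print(result_dict[last_node]["text"])


asyncio.run(main())
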
58 changes: 47 additions & 11 deletions comps/cores/mega/orchestrator.py
@@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

import asyncio
import copy
import json
import re
from typing import Dict, List

import aiohttp
@@ -39,10 +41,16 @@ def flow_to(self, from_service, to_service):

async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
result_dict = {}
runtime_graph = DAG()
runtime_graph.graph = copy.deepcopy(self.graph)

timeout = aiohttp.ClientTimeout(total=1000)
async with aiohttp.ClientSession(trust_env=True, timeout=timeout) as session:
pending = {asyncio.create_task(self.execute(session, node, initial_inputs)) for node in self.ind_nodes()}
pending = {
asyncio.create_task(self.execute(session, node, initial_inputs, runtime_graph))
for node in self.ind_nodes()
}
ind_nodes = self.ind_nodes()

while pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
@@ -51,13 +59,40 @@ async def schedule(self, initial_inputs: Dict, llm_parameters: LLMParams = LLMParams()):
self.dump_outputs(node, response, result_dict)

# traverse the current node's downstream nodes and execute if all one's predecessors are finished
downstreams = self.downstream(node)
for d_node in downstreams:
if all(i in result_dict for i in self.predecessors(d_node)):
inputs = self.process_outputs(self.predecessors(d_node), result_dict)
pending.add(asyncio.create_task(self.execute(session, d_node, inputs, llm_parameters)))
downstreams = runtime_graph.downstream(node)

# remove all the black nodes that are skipped to be forwarded to
if not isinstance(response, StreamingResponse) and "downstream_black_list" in response:
for black_node in response["downstream_black_list"]:
for downstream in reversed(downstreams):
try:
if re.findall(black_node, downstream):
print(f"skip forwardding to {downstream}...")
runtime_graph.delete_edge(node, downstream)
downstreams.remove(downstream)
except re.error as e:
print("Pattern invalid! Operation cancelled.")

return result_dict
for d_node in downstreams:
if all(i in result_dict for i in runtime_graph.predecessors(d_node)):
inputs = self.process_outputs(runtime_graph.predecessors(d_node), result_dict)
pending.add(
asyncio.create_task(
self.execute(session, d_node, inputs, runtime_graph, llm_parameters)
)
)
nodes_to_keep = []
for i in ind_nodes:
nodes_to_keep.append(i)
nodes_to_keep.extend(runtime_graph.all_downstreams(i))

all_nodes = list(runtime_graph.graph.keys())

for node in all_nodes:
if node not in nodes_to_keep:
runtime_graph.delete_node_if_exists(node)

return result_dict, runtime_graph

def process_outputs(self, prev_nodes: List, result_dict: Dict) -> Dict:
all_outputs = {}
@@ -72,6 +107,7 @@ async def execute(
session: aiohttp.client.ClientSession,
cur_node: str,
inputs: Dict,
runtime_graph: DAG,
llm_parameters: LLMParams = LLMParams(),
):
# send the cur_node request/reply
@@ -97,8 +133,8 @@ def generate():
else:
if (
self.services[cur_node].service_type == ServiceType.LLM
and self.predecessors(cur_node)
and "asr" in self.predecessors(cur_node)[0]
and runtime_graph.predecessors(cur_node)
and "asr" in runtime_graph.predecessors(cur_node)[0]
):
inputs["query"] = inputs["text"]
del inputs["text"]
@@ -109,8 +145,8 @@ def generate():
def dump_outputs(self, node, response, result_dict):
result_dict[node] = response

def get_all_final_outputs(self, result_dict):
def get_all_final_outputs(self, result_dict, runtime_graph):
final_output_dict = {}
for leaf in self.all_leaves():
for leaf in runtime_graph.all_leaves():
final_output_dict[leaf] = result_dict[leaf]
return final_output_dict
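
The core of the dynamic-DAG feature lives in schedule() above: the static graph is deep-copied into runtime_graph, each response's downstream_black_list entries are applied as regex patterns to prune matching downstream edges, and any node no longer reachable from the root (independent) nodes is dropped before the graph is returned. A framework-free sketch of that pruning idea, assuming a plain adjacency dict in place of the real DAG class (node names borrowed from test_runtime_graph.py):

import copy
import re

# Hypothetical static graph: s1 -> s2, s1 -> s3, s3 -> s4
static_graph = {"s1": ["s2", "s3"], "s2": [], "s3": ["s4"], "s4": []}


def prune(graph, node, downstream_black_list, roots=("s1",)):
    runtime = copy.deepcopy(graph)  # never mutate the static DAG
    for pattern in downstream_black_list:
        for downstream in list(runtime[node]):
            try:
                if re.findall(pattern, downstream):
                    runtime[node].remove(downstream)  # skip forwarding to this node
            except re.error:
                print("Pattern invalid! Operation cancelled.")
    # keep only nodes still reachable from the root nodes
    reachable, stack = set(), list(roots)
    while stack:
        cur = stack.pop()
        if cur not in reachable:
            reachable.add(cur)
            stack.extend(runtime[cur])
    return {n: [c for c in cs if c in reachable] for n, cs in runtime.items() if n in reachable}


print(prune(static_graph, "s1", ["s3"]))  # {'s1': ['s2'], 's2': []}
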
8 changes: 7 additions & 1 deletion comps/cores/proto/docarray.py
@@ -10,7 +10,13 @@
from pydantic import Field, conint, conlist


class TextDoc(BaseDoc):
class TopologyInfo:
# will not keep forwarding to the downstream nodes in the black list
# should be a pattern string
downstream_black_list: Optional[list] = []


class TextDoc(BaseDoc, TopologyInfo):
text: str


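
TopologyInfo is a small mixin that any forwardable document can carry: downstream_black_list is an optional list of regex pattern strings naming downstream nodes that should not receive the document. A rough stand-alone sketch of the idea, using pydantic.BaseModel instead of docarray's BaseDoc so it runs on its own:

from typing import List, Optional

from pydantic import BaseModel


class TopologyInfo(BaseModel):
    # will not keep forwarding to the downstream nodes in the black list;
    # each entry is treated as a regex pattern string
    downstream_black_list: Optional[List[str]] = []


class TextDoc(TopologyInfo):
    text: str


# A node that wants to halt the whole pipeline can blacklist everything:
doc = TextDoc(text="Violated policies: ...", downstream_black_list=[".*"])
print(doc.downstream_black_list)
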
6 changes: 4 additions & 2 deletions comps/guardrails/langchain/guardrails_tgi_gaudi.py
@@ -70,9 +70,11 @@ def safety_guard(input: TextDoc) -> TextDoc:
policy_violation_level = response_input_guard.split("\n")[1].strip()
policy_violations = unsafe_dict[policy_violation_level]
print(f"Violated policies: {policy_violations}")
res = TextDoc(text=f"Violated policies: {policy_violations}, please check your input.")
res = TextDoc(
text=f"Violated policies: {policy_violations}, please check your input.", downstream_black_list=[".*"]
)
else:
res = TextDoc(text="safe")
res = TextDoc(text=input.text)

return res

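
The net effect of the guardrails change: on a policy violation the returned doc blacklists every downstream node (".*"), so the warning text goes straight back to the caller instead of being forwarded to the LLM, while safe input is now passed through unchanged rather than replaced with the literal string "safe". A self-contained sketch of that branch, with the TGI/LlamaGuard call replaced by a precomputed list of violated policies (hypothetical simplification):

from dataclasses import dataclass, field
from typing import List


@dataclass
class TextDoc:
    text: str
    downstream_black_list: List[str] = field(default_factory=list)


def safety_guard(doc: TextDoc, violated_policies: List[str]) -> TextDoc:
    if violated_policies:
        # ".*" matches every downstream node name, so the pipeline stops here
        return TextDoc(
            text=f"Violated policies: {violated_policies}, please check your input.",
            downstream_black_list=[".*"],
        )
    # forward the original user text to the next node (e.g. the LLM)
    return TextDoc(text=doc.text)


print(safety_guard(TextDoc(text="hello"), violated_policies=[]).text)
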
6 changes: 3 additions & 3 deletions tests/cores/mega/test_aio.py
@@ -76,13 +76,13 @@ async def test_schedule(self):
task2 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hi, "}))
await asyncio.gather(task1, task2)

result_dict1 = task1.result()
result_dict2 = task2.result()
result_dict1, runtime_graph1 = task1.result()
result_dict2, runtime_graph2 = task2.result()
self.assertEqual(result_dict1[self.s2.name]["text"], "hello, opea project1!")
self.assertEqual(result_dict1[self.s3.name]["text"], "hello, opea project2!")
self.assertEqual(result_dict2[self.s2.name]["text"], "hi, opea project1!")
self.assertEqual(result_dict2[self.s3.name]["text"], "hi, opea project2!")
self.assertEqual(len(self.service_builder.get_all_final_outputs(result_dict1).keys()), 2)
self.assertEqual(len(self.service_builder.get_all_final_outputs(result_dict1, runtime_graph1).keys()), 2)
self.assertEqual(int(time.time() - t), 15)


2 changes: 1 addition & 1 deletion tests/cores/mega/test_base_statistics.py
@@ -46,7 +46,7 @@ async def test_base_statistics(self):
for _ in range(2):
task1 = asyncio.create_task(self.service_builder.schedule(initial_inputs={"text": "hello, "}))
await asyncio.gather(task1)
result_dict1 = task1.result()
result_dict1, _ = task1.result()

response = requests.get("http://localhost:8083/v1/statistics")
res = response.json()
92 changes: 92 additions & 0 deletions tests/cores/mega/test_runtime_graph.py
@@ -0,0 +1,92 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import unittest

from fastapi.testclient import TestClient

from comps import ServiceOrchestrator, TextDoc, opea_microservices, register_microservice


@register_microservice(name="s1", host="0.0.0.0", port=8080, endpoint="/v1/add")
async def add_s1(request: TextDoc) -> TextDoc:
text = request.text
if "Hi" in text:
text += "OPEA Project!"
return TextDoc(text=text, downstream_black_list=[])
elif "Bye" in text:
text += "OPEA Project!"
return TextDoc(text=text, downstream_black_list=[".*"])
elif "Hola" in text:
text += "OPEA Project!"
return TextDoc(text=text, downstream_black_list=["s2"])
else:
text += "OPEA Project!"
return TextDoc(text=text, downstream_black_list=["s3"])


@register_microservice(name="s2", host="0.0.0.0", port=8081, endpoint="/v1/add")
async def add_s2(request: TextDoc) -> TextDoc:
text = request.text
text += "add s2!"
return TextDoc(text=text)


@register_microservice(name="s3", host="0.0.0.0", port=8082, endpoint="/v1/add")
async def add_s3(request: TextDoc) -> TextDoc:
text = request.text
text += "add s3!"
return TextDoc(text=text)


@register_microservice(name="s4", host="0.0.0.0", port=8083, endpoint="/v1/add")
async def add_s4(request: TextDoc) -> TextDoc:
text = request.text
text += "add s4!"
return TextDoc(text=text)


class TestMicroService(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.client1 = TestClient(opea_microservices["s1"].app)
self.s1 = opea_microservices["s1"]
self.s2 = opea_microservices["s2"]
self.s3 = opea_microservices["s3"]
self.s4 = opea_microservices["s4"]

self.s1.start()
self.s2.start()
self.s3.start()
self.s4.start()

self.service_builder = ServiceOrchestrator()
self.service_builder.add(self.s1).add(self.s2).add(self.s3).add(self.s4)
self.service_builder.flow_to(self.s1, self.s2)
self.service_builder.flow_to(self.s1, self.s3)
self.service_builder.flow_to(self.s3, self.s4)

def tearDown(self):
self.s1.stop()
self.s2.stop()
self.s3.stop()
self.s4.stop()

async def test_add_route(self):
result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Hi!"})
assert len(result_dict) == 4
assert len(runtime_graph.all_leaves()) == 2
result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Bye!"})
assert len(result_dict) == 1
assert len(runtime_graph.all_leaves()) == 1
result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Hola!"})
assert len(result_dict) == 3
assert len(runtime_graph.all_leaves()) == 1
result_dict, runtime_graph = await self.service_builder.schedule(initial_inputs={"text": "Other!"})
print(runtime_graph.graph)
assert len(result_dict) == 2
assert len(runtime_graph.all_leaves()) == 1


if __name__ == "__main__":
unittest.main()
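
For reference, the graph built in setUp is s1 -> s2 and s1 -> s3 -> s4, which explains the assertions above: "Hi!" blacklists nothing, so all four nodes run and the runtime graph has two leaves (s2 and s4); "Bye!" blacklists ".*", so only s1 runs and is itself the sole leaf; "Hola!" blacklists "s2", leaving s1 -> s3 -> s4 with leaf s4; any other input blacklists "s3", leaving s1 -> s2 with leaf s2.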
2 changes: 1 addition & 1 deletion tests/cores/mega/test_service_orchestrator.py
@@ -42,7 +42,7 @@ def tearDown(self):
self.s2.stop()

async def test_schedule(self):
result_dict = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
result_dict, _ = await self.service_builder.schedule(initial_inputs={"text": "hello, "})
self.assertEqual(result_dict[self.s2.name]["text"], "hello, opea project!")

