Skip to content

Commit

Permalink
[Serve] Support Multiple DAG Entrypoints in DAGDriver (ray-project#26573
Browse files Browse the repository at this point in the history
)

Signed-off-by: Stefan van der Kleij <[email protected]>
  • Loading branch information
sihanwang41 authored and Stefan van der Kleij committed Aug 18, 2022
1 parent e591624 commit f975ab9
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 104 deletions.
57 changes: 50 additions & 7 deletions python/ray/serve/drivers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import inspect
from abc import abstractmethod
from typing import Any, Callable, Optional, Type, Union
from typing import Any, Callable, Optional, Type, Union, Dict
from pydantic import BaseModel
from ray.serve._private.utils import install_serve_encoders_to_fastapi
from ray.util.annotations import DeveloperAPI, PublicAPI
Expand All @@ -11,6 +12,8 @@
from ray._private.utils import import_attr
from ray.serve.deployment_graph import RayServeDAGHandle
from ray.serve._private.http_util import ASGIHTTPSender
from ray.serve.handle import RayServeLazySyncHandle
from ray.serve.exceptions import RayServeException
from ray import serve

DEFAULT_HTTP_ADAPTER = "ray.serve.http_adapters.starlette_request"
Expand Down Expand Up @@ -83,16 +86,56 @@ async def __call__(self, request: starlette.requests.Request):

@PublicAPI(stability="beta")
@serve.deployment(route_prefix="/")
class DAGDriver(SimpleSchemaIngress):
class DAGDriver:

MATCH_ALL_ROUTE_PREFIX = "/{path:path}"

def __init__(
self,
dag_handle: RayServeDAGHandle,
*,
dags: Union[RayServeDAGHandle, Dict[str, RayServeDAGHandle]],
http_adapter: Optional[Union[str, Callable]] = None,
):
self.dag_handle = dag_handle
super().__init__(http_adapter)
install_serve_encoders_to_fastapi()
http_adapter = _load_http_adapter(http_adapter)
self.app = FastAPI()

if isinstance(dags, dict):
self.dags = dags
for route, handle in dags.items():

def endpoint_create(handle):
@self.app.get(f"{route}")
@self.app.post(f"{route}")
async def handle_request(inp=Depends(http_adapter)):
return await handle.remote(inp)

# bind current handle with endpoint creation function
endpoint_create_func = functools.partial(endpoint_create, handle)
endpoint_create_func()

else:
assert isinstance(dags, (RayServeDAGHandle, RayServeLazySyncHandle))
self.dags = {self.MATCH_ALL_ROUTE_PREFIX: dags}

# Single dag case, we will receive all prefix route
@self.app.get(self.MATCH_ALL_ROUTE_PREFIX)
@self.app.post(self.MATCH_ALL_ROUTE_PREFIX)
async def handle_request(inp=Depends(http_adapter)):
return await self.predict(inp)

async def __call__(self, request: starlette.requests.Request):
# NOTE(simon): This is now duplicated from ASGIAppWrapper because we need to
# generate FastAPI on the fly, we should find a way to unify the two.
sender = ASGIHTTPSender()
await self.app(request.scope, receive=request.receive, send=sender)
return sender.build_asgi_response()

async def predict(self, *args, **kwargs):
"""Perform inference directly without HTTP."""
return await self.dag_handle.remote(*args, **kwargs)
return await self.dags[self.MATCH_ALL_ROUTE_PREFIX].remote(*args, **kwargs)

async def predict_with_route(self, route_path, *args, **kwargs):
"""Perform inference directly without HTTP for multi dags."""
if route_path not in self.dags:
raise RayServeException(f"{route_path} does not exist in dags routes")
return await self.dags[route_path].remote(*args, **kwargs)
11 changes: 7 additions & 4 deletions python/ray/serve/tests/test_deploy_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,20 @@ def check_switched():

@pytest.mark.parametrize("prefixes", [[None, "/f", None], ["/f", None, "/f"]])
def test_deploy_nullify_route_prefix(serve_instance, prefixes):
# With multi dags support, dag driver will receive all route
# prefix when route_prefix is "None", since "None" will be converted
# to "/" internally.
# Note: the expose http endpoint will still be removed for internal
# dag node by setting "None" to route_prefix
@serve.deployment
def f(*args):
return "got me"

for prefix in prefixes:
dag = DAGDriver.options(route_prefix=prefix).bind(f.bind())
handle = serve.run(dag)
if prefix is None:
assert requests.get("http://localhost:8000/f").status_code == 404
else:
assert requests.get("http://localhost:8000/f").text == '"got me"'
assert requests.get("http://localhost:8000/f").status_code == 200
assert requests.get("http://localhost:8000/f").text == '"got me"'
assert ray.get(handle.predict.remote()) == "got me"


Expand Down
158 changes: 68 additions & 90 deletions python/ray/serve/tests/test_deployment_graph_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import contextlib
import io
import sys
import numpy as np
from pydantic import BaseModel

import pytest
Expand All @@ -10,11 +7,9 @@
from starlette.testclient import TestClient

from ray.serve.drivers import DAGDriver, SimpleSchemaIngress, _load_http_adapter
from ray.serve.http_adapters import json_request
from ray.serve.dag import InputNode
from ray import serve
import ray
from ray._private.test_utils import wait_for_condition


def my_resolver(a: int):
Expand Down Expand Up @@ -95,107 +90,90 @@ def echo(inp):
return inp


def test_dag_driver_default(serve_instance):
with InputNode() as inp:
dag = echo.bind(inp)
async def json_resolver(request: starlette.requests.Request):
return await request.json()

handle = serve.run(DAGDriver.bind(dag))
assert ray.get(handle.predict.remote(42)) == 42

resp = requests.post("http://127.0.0.1:8000/", json={"array": [1]})
print(resp.text)
def test_multi_dag(serve_instance):
"""Test multi dags within dag deployment"""

resp.raise_for_status()
assert resp.json() == "starlette!"
@serve.deployment
class D1:
def forward(self):
return "D1"

@serve.deployment
class D2:
def forward(self):
return "D2"

async def resolver(my_custom_param: int):
return my_custom_param
@serve.deployment
class D3:
def __call__(self):
return "D3"

@serve.deployment
def D4():
return "D4"

def test_dag_driver_custom_schema(serve_instance):
with InputNode() as inp:
dag = echo.bind(inp)

handle = serve.run(DAGDriver.bind(dag, http_adapter=resolver))
assert ray.get(handle.predict.remote(42)) == 42

resp = requests.get("http://127.0.0.1:8000/?my_custom_param=100")
print(resp.text)
resp.raise_for_status()
assert resp.json() == 100


def test_dag_driver_custom_pydantic_schema(serve_instance):
with InputNode() as inp:
dag = echo.bind(inp)

handle = serve.run(DAGDriver.bind(dag, http_adapter=MyType))
assert ray.get(handle.predict.remote(MyType(a=1, b="str"))) == MyType(a=1, b="str")

resp = requests.post("http://127.0.0.1:8000/", json={"a": 1, "b": "str"})
print(resp.text)
resp.raise_for_status()
assert resp.json() == {"a": 1, "b": "str"}


@serve.deployment
def combine(*args):
return list(args)


def test_dag_driver_partial_input(serve_instance):
with InputNode() as inp:
dag = DAGDriver.bind(
combine.bind(echo.bind(inp[0]), echo.bind(inp[1]), echo.bind(inp[2])),
http_adapter=json_request,
)
d1 = D1.bind()
d2 = D2.bind()
d3 = D3.bind()
d4 = D4.bind()
dag = DAGDriver.bind(
{
"/my_D1": d1.forward.bind(),
"/my_D2": d2.forward.bind(),
"/my_D3": d3,
"/my_D4": d4,
}
)
handle = serve.run(dag)
assert ray.get(handle.predict.remote([1, 2, [3, 4]])) == [1, 2, [3, 4]]
assert ray.get(handle.predict.remote(1, 2, [3, 4])) == [1, 2, [3, 4]]

resp = requests.post("http://127.0.0.1:8000/", json=[1, 2, [3, 4]])
print(resp.text)
resp.raise_for_status()
assert resp.json() == [1, 2, [3, 4]]


@serve.deployment
def return_np_int(_):
return [np.int64(42)]


def test_driver_np_serializer(serve_instance):
# https://github.com/ray-project/ray/pull/24215#issuecomment-1115237058
with InputNode() as inp:
dag = DAGDriver.bind(return_np_int.bind(inp))
serve.run(dag)
assert requests.get("http://127.0.0.1:8000/").json() == [42]
for i in range(1, 5):
assert ray.get(handle.predict_with_route.remote(f"/my_D{i}")) == f"D{i}"
assert requests.post(f"http://127.0.0.1:8000/my_D{i}", json=1).json() == f"D{i}"
assert requests.get(f"http://127.0.0.1:8000/my_D{i}", json=1).json() == f"D{i}"


def test_dag_driver_sync_warning(serve_instance):
with InputNode() as inp:
dag = echo.bind(inp)
def test_multi_dag_with_inputs(serve_instance):
@serve.deployment
class D1:
def forward(self, input):
return input

log_file = io.StringIO()
with contextlib.redirect_stderr(log_file):
@serve.deployment
class D2:
def forward(self, input1, input2):
return input1 + input2

handle = serve.run(DAGDriver.bind(dag))
assert ray.get(handle.predict.remote(42)) == 42
@serve.deployment
def D3(input):
return input

def wait_for_request_success_log():
lines = log_file.getvalue().splitlines()
for line in lines:
if "DAGDriver" in line and "HANDLE predict OK" in line:
return True
return False
d1 = D1.bind()
d2 = D2.bind()

wait_for_condition(wait_for_request_success_log)

assert (
"You are retrieving a sync handle inside an asyncio loop."
not in log_file.getvalue()
with InputNode() as dag_input:
dag = DAGDriver.bind(
{
"/my_D1": d1.forward.bind(dag_input),
"/my_D2": d2.forward.bind(dag_input[0], dag_input[1]),
"/my_D3": D3.bind(dag_input),
},
http_adapter=json_resolver,
)
handle = serve.run(dag)

assert ray.get(handle.predict_with_route.remote("/my_D1", 1)) == 1
assert ray.get(handle.predict_with_route.remote("/my_D2", 10, 2)) == 12
assert ray.get(handle.predict_with_route.remote("/my_D3", 100)) == 100
assert requests.post("http://127.0.0.1:8000/my_D1", json=1).json() == 1
assert requests.post("http://127.0.0.1:8000/my_D2", json=[1, 2]).json() == 3
assert requests.post("http://127.0.0.1:8000/my_D3", json=100).json() == 100
assert requests.get("http://127.0.0.1:8000/my_D1", json=1).json() == 1
assert requests.get("http://127.0.0.1:8000/my_D2", json=[1, 2]).json() == 3
assert requests.get("http://127.0.0.1:8000/my_D3", json=100).json() == 100


if __name__ == "__main__":
Expand Down
53 changes: 50 additions & 3 deletions python/ray/serve/tests/test_http_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
import requests
from ray.serve.drivers import DAGDriver
from fastapi import FastAPI, Request
from starlette.responses import RedirectResponse

Expand Down Expand Up @@ -47,7 +48,7 @@ def test_routes_healthz(serve_instance):
assert resp.content == b"success"


def test_routes_endpoint(serve_instance):
def test_routes_endpoint_legacy(serve_instance):
@serve.deployment
class D1:
pass
Expand Down Expand Up @@ -93,6 +94,31 @@ class D3:
assert routes["/hello"] == "D3", routes


def test_routes_endpoint(serve_instance):
@serve.deployment
class D1:
def __call__(self):
return "D1"

@serve.deployment
class D2:
def __call__(self):
return "D2"

dag = DAGDriver.bind({"/D1": D1.bind(), "/hello/world": D2.bind()})
serve.run(dag)

routes = requests.get("http://localhost:8000/-/routes").json()

assert len(routes) == 1, routes
assert "/" in routes, routes

assert requests.get("http://localhost:8000/D1").json() == "D1"
assert requests.get("http://localhost:8000/D1").status_code == 200
assert requests.get("http://localhost:8000/hello/world").json() == "D2"
assert requests.get("http://localhost:8000/hello/world").status_code == 200


def test_deployment_without_route(serve_instance):
@serve.deployment(route_prefix=None)
class D:
Expand Down Expand Up @@ -193,6 +219,27 @@ def subpath(self, p: str):
check_req("/hello/world/again/hi") == '"hi"'


def test_multi_dag_with_wrong_route(serve_instance):
@serve.deployment
class D1:
def __call__(self):
return "D1"

@serve.deployment
class D2:
def __call__(self):
return "D2"

dag = DAGDriver.bind({"/D1": D1.bind(), "/hello/world": D2.bind()})

serve.run(dag)

assert requests.get("http://localhost:8000/D1").status_code == 200
assert requests.get("http://localhost:8000/hello/world").status_code == 200
assert requests.get("http://localhost:8000/not_exist").status_code == 404
assert requests.get("http://localhost:8000/").status_code == 404


@pytest.mark.parametrize("base_path", ["", "subpath"])
def test_redirect(serve_instance, base_path):
app = FastAPI()
Expand Down Expand Up @@ -241,7 +288,7 @@ def test_default_error_handling(serve_instance):
def f():
1 / 0

f.deploy()
serve.run(f.bind())
r = requests.get("http://localhost:8000/f")
assert r.status_code == 500
assert "ZeroDivisionError" in r.text, r.text
Expand All @@ -255,7 +302,7 @@ def h():
ray.get(intentional_kill.remote(ray.get_runtime_context().current_actor))
time.sleep(100) # Don't return here to leave time for actor exit.

h.deploy()
serve.run(h.bind())
r = requests.get("http://localhost:8000/h")
assert r.status_code == 500
assert "retries" in r.text, r.text
Expand Down

0 comments on commit f975ab9

Please sign in to comment.