Skip to content

Commit

Permalink
[serve] Enable streaming over ServeHandles (#37114)
Browse files Browse the repository at this point in the history
Adds support for `handle.options(stream=True)` to make streaming calls using the `RayServeHandle`.

This required two changes:

- Adding `stream=True` to the `.options` codepaths.
- Adding a branch in `handle_request_streaming` to handle calls from the HTTP proxy as well as other handle calls.

I also refactored some of the code around exception handling/metrics into a common context manager that's used by both codepaths.
  • Loading branch information
edoakes authored Jul 7, 2023
1 parent b56aab4 commit 95a51ce
Show file tree
Hide file tree
Showing 12 changed files with 646 additions and 159 deletions.
9 changes: 9 additions & 0 deletions python/ray/serve/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,15 @@ py_test(
)


py_test(
name = "test_handle_streaming",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)


py_test(
name = "test_kv_store",
size = "small",
Expand Down
5 changes: 0 additions & 5 deletions python/ray/serve/_private/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,6 @@ def get_handle(
missing_ok: Optional[bool] = False,
sync: bool = True,
_is_for_http_requests: bool = False,
_stream: bool = False,
) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service deployment to invoke it from Python.
Expand All @@ -500,8 +499,6 @@ def get_handle(
that's only usable in asyncio loop.
_is_for_http_requests: Indicates that this handle will be used
to send HTTP requests from the proxy to ingress deployment replicas.
_stream: Indicates that this handle should use
`num_returns="streaming"`.
Returns:
RayServeHandle
Expand All @@ -521,14 +518,12 @@ def get_handle(
self._controller,
deployment_name,
_is_for_http_requests=_is_for_http_requests,
_stream=_stream,
)
else:
handle = RayServeHandle(
self._controller,
deployment_name,
_is_for_http_requests=_is_for_http_requests,
_stream=_stream,
)

self.handle_cache[cache_key] = handle
Expand Down
8 changes: 8 additions & 0 deletions python/ray/serve/_private/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,11 @@ class MultiplexedReplicaInfo:
deployment_name: str
replica_tag: str
model_ids: List[str]


@dataclass
class StreamingHTTPRequest:
"""Sent from the HTTP proxy to replicas on the streaming codepath."""

pickled_asgi_scope: bytes
http_proxy_handle: ActorHandle
18 changes: 12 additions & 6 deletions python/ray/serve/_private/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
set_socket_reuse_port,
validate_http_proxy_callback_return,
)
from ray.serve._private.common import EndpointInfo, EndpointTag, ApplicationName, NodeId
from ray.serve._private.common import (
ApplicationName,
EndpointInfo,
EndpointTag,
NodeId,
StreamingHTTPRequest,
)
from ray.serve._private.constants import (
SERVE_LOGGER_NAME,
SERVE_MULTIPLEXED_MODEL_ID,
Expand Down Expand Up @@ -122,8 +128,7 @@ def update_routes(self, endpoints: Dict[EndpointTag, EndpointInfo]) -> None:
if endpoint in self.handles:
existing_handles.remove(endpoint)
else:
self.handles[endpoint] = self._get_handle(
endpoint,
self.handles[endpoint] = self._get_handle(endpoint).options(
# Streaming codepath isn't supported for Java.
stream=(
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING
Expand Down Expand Up @@ -224,13 +229,12 @@ def __init__(
extra={"log_to_stderr": False},
)

def get_handle(name, stream: bool = False):
def get_handle(name):
return serve.context.get_global_client().get_handle(
name,
sync=False,
missing_ok=True,
_is_for_http_requests=True,
_stream=stream,
)

self.prefix_router = LongestPrefixRouter(get_handle)
Expand Down Expand Up @@ -697,7 +701,9 @@ async def _assign_request_with_timeout(
`disconnected_task` is expected to be done if the client disconnects; in this
case, we will abort assigning a replica and return `None`.
"""
assignment_task = handle.remote(pickle.dumps(scope), self.self_actor_handle)
assignment_task = handle.remote(
StreamingHTTPRequest(pickle.dumps(scope), self.self_actor_handle)
)
done, _ = await asyncio.wait(
[assignment_task, disconnected_task],
return_when=FIRST_COMPLETED,
Expand Down
Loading

0 comments on commit 95a51ce

Please sign in to comment.