Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serve] Enable streaming over ServeHandles #37114

Merged
merged 16 commits into from
Jul 7, 2023
9 changes: 9 additions & 0 deletions python/ray/serve/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,15 @@ py_test(
)


# Test target named `test_handle_streaming`. Its sources come from the
# `serve_tests_srcs` variable and it depends on the `:serve_lib` target.
py_test(
name = "test_handle_streaming",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
)


py_test(
name = "test_kv_store",
size = "small",
Expand Down
5 changes: 0 additions & 5 deletions python/ray/serve/_private/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,6 @@ def get_handle(
missing_ok: Optional[bool] = False,
sync: bool = True,
_is_for_http_requests: bool = False,
_stream: bool = False,
) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service deployment to invoke it from Python.

Expand All @@ -500,8 +499,6 @@ def get_handle(
that's only usable in asyncio loop.
_is_for_http_requests: Indicates that this handle will be used
to send HTTP requests from the proxy to ingress deployment replicas.
_stream: Indicates that this handle should use
`num_returns="streaming"`.

Returns:
RayServeHandle
Expand All @@ -521,14 +518,12 @@ def get_handle(
self._controller,
deployment_name,
_is_for_http_requests=_is_for_http_requests,
_stream=_stream,
)
else:
handle = RayServeHandle(
self._controller,
deployment_name,
_is_for_http_requests=_is_for_http_requests,
_stream=_stream,
)

self.handle_cache[cache_key] = handle
Expand Down
6 changes: 6 additions & 0 deletions python/ray/serve/_private/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,3 +391,9 @@ class MultiplexedReplicaInfo:
deployment_name: str
replica_tag: str
model_ids: List[str]


# Bundles the two values the HTTP proxy sends when it assigns a request:
# `_assign_request_with_timeout` builds one as
# `StreamingHTTPRequest(pickle.dumps(scope), self.self_actor_handle)` and
# passes it as the single argument to `handle.remote(...)`.
@dataclass
class StreamingHTTPRequest:
# The request's ASGI scope, serialized by the proxy with `pickle.dumps`.
pickled_asgi_scope: bytes
# Actor handle the proxy passes for itself (`self.self_actor_handle`).
# NOTE(review): presumably used by the receiver to talk back to the proxy;
# confirm against the replica-side code.
http_proxy_handle: ActorHandle
18 changes: 12 additions & 6 deletions python/ray/serve/_private/http_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
set_socket_reuse_port,
validate_http_proxy_callback_return,
)
from ray.serve._private.common import EndpointInfo, EndpointTag, ApplicationName, NodeId
from ray.serve._private.common import (
ApplicationName,
EndpointInfo,
EndpointTag,
NodeId,
StreamingHTTPRequest,
)
from ray.serve._private.constants import (
SERVE_LOGGER_NAME,
SERVE_MULTIPLEXED_MODEL_ID,
Expand Down Expand Up @@ -122,8 +128,7 @@ def update_routes(self, endpoints: Dict[EndpointTag, EndpointInfo]) -> None:
if endpoint in self.handles:
existing_handles.remove(endpoint)
else:
self.handles[endpoint] = self._get_handle(
endpoint,
self.handles[endpoint] = self._get_handle(endpoint).options(
# Streaming codepath isn't supported for Java.
stream=(
RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING
Expand Down Expand Up @@ -224,13 +229,12 @@ def __init__(
extra={"log_to_stderr": False},
)

def get_handle(name, stream: bool = False):
def get_handle(name):
return serve.context.get_global_client().get_handle(
name,
sync=False,
missing_ok=True,
_is_for_http_requests=True,
_stream=stream,
)

self.prefix_router = LongestPrefixRouter(get_handle)
Expand Down Expand Up @@ -697,7 +701,9 @@ async def _assign_request_with_timeout(
`disconnected_task` is expected to be done if the client disconnects; in this
case, we will abort assigning a replica and return `None`.
"""
assignment_task = handle.remote(pickle.dumps(scope), self.self_actor_handle)
assignment_task = handle.remote(
StreamingHTTPRequest(pickle.dumps(scope), self.self_actor_handle)
)
done, _ = await asyncio.wait(
[assignment_task, disconnected_task],
return_when=FIRST_COMPLETED,
Expand Down
Loading