From 95a51ce1c7ee4b572b4a9d4eadc88572dabf0fff Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Fri, 7 Jul 2023 13:24:07 -0500 Subject: [PATCH] [serve] Enable streaming over `ServeHandle`s (#37114) Adds support for `handle.options(stream=True)` to make streaming calls using the `RayServeHandle`. This required two changes: - Adding `stream=True` to the `.options` codepaths. - Adding a branch in `handle_request_streaming` to handle calls from the HTTP proxy as well as other handle calls. I also refactored some of the code around exception handling/metrics into a common context manager that's used by both codepaths. --- python/ray/serve/BUILD | 9 + python/ray/serve/_private/client.py | 5 - python/ray/serve/_private/common.py | 8 + python/ray/serve/_private/http_proxy.py | 18 +- python/ray/serve/_private/replica.py | 332 +++++++++++------- python/ray/serve/_private/utils.py | 22 ++ python/ray/serve/handle.py | 56 ++- .../ray/serve/tests/test_handle_streaming.py | 239 +++++++++++++ .../serve/tests/test_http_prefix_matching.py | 12 +- python/ray/serve/tests/test_metrics.py | 1 + .../serve/tests/test_streaming_response.py | 67 ++++ python/ray/serve/tests/test_util.py | 36 ++ 12 files changed, 646 insertions(+), 159 deletions(-) create mode 100644 python/ray/serve/tests/test_handle_streaming.py diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index 5c58e6df244e..4b250657a556 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -261,6 +261,15 @@ py_test( ) +py_test( + name = "test_handle_streaming", + size = "medium", + srcs = serve_tests_srcs, + tags = ["exclusive", "team:serve"], + deps = [":serve_lib"], +) + + py_test( name = "test_kv_store", size = "small", diff --git a/python/ray/serve/_private/client.py b/python/ray/serve/_private/client.py index 00f67fe4dd44..7c9a61358f27 100644 --- a/python/ray/serve/_private/client.py +++ b/python/ray/serve/_private/client.py @@ -487,7 +487,6 @@ def get_handle( missing_ok: Optional[bool] = False, sync: bool = True, _is_for_http_requests: bool = False, - _stream: bool = False, ) -> Union[RayServeHandle, RayServeSyncHandle]: """Retrieve RayServeHandle for service deployment to invoke it from Python. @@ -500,8 +499,6 @@ def get_handle( that's only usable in asyncio loop. _is_for_http_requests: Indicates that this handle will be used to send HTTP requests from the proxy to ingress deployment replicas. - _stream: Indicates that this handle should use - `num_returns="streaming"`. Returns: RayServeHandle @@ -521,14 +518,12 @@ def get_handle( self._controller, deployment_name, _is_for_http_requests=_is_for_http_requests, - _stream=_stream, ) else: handle = RayServeHandle( self._controller, deployment_name, _is_for_http_requests=_is_for_http_requests, - _stream=_stream, ) self.handle_cache[cache_key] = handle diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py index 3acfb7d0a01c..036dc49fbf78 100644 --- a/python/ray/serve/_private/common.py +++ b/python/ray/serve/_private/common.py @@ -391,3 +391,11 @@ class MultiplexedReplicaInfo: deployment_name: str replica_tag: str model_ids: List[str] + + +@dataclass +class StreamingHTTPRequest: + """Sent from the HTTP proxy to replicas on the streaming codepath.""" + + pickled_asgi_scope: bytes + http_proxy_handle: ActorHandle diff --git a/python/ray/serve/_private/http_proxy.py b/python/ray/serve/_private/http_proxy.py index bb8b51797fc4..f8dc069651d6 100644 --- a/python/ray/serve/_private/http_proxy.py +++ b/python/ray/serve/_private/http_proxy.py @@ -32,7 +32,13 @@ set_socket_reuse_port, validate_http_proxy_callback_return, ) -from ray.serve._private.common import EndpointInfo, EndpointTag, ApplicationName, NodeId +from ray.serve._private.common import ( + ApplicationName, + EndpointInfo, + EndpointTag, + NodeId, + StreamingHTTPRequest, +) from ray.serve._private.constants import ( SERVE_LOGGER_NAME, SERVE_MULTIPLEXED_MODEL_ID, @@ -122,8 +128,7 @@ def update_routes(self, endpoints: Dict[EndpointTag, EndpointInfo]) -> None: if endpoint in self.handles: existing_handles.remove(endpoint) else: - self.handles[endpoint] = self._get_handle( - endpoint, + self.handles[endpoint] = self._get_handle(endpoint).options( # Streaming codepath isn't supported for Java. stream=( RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING @@ -224,13 +229,12 @@ def __init__( extra={"log_to_stderr": False}, ) - def get_handle(name, stream: bool = False): + def get_handle(name): return serve.context.get_global_client().get_handle( name, sync=False, missing_ok=True, _is_for_http_requests=True, - _stream=stream, ) self.prefix_router = LongestPrefixRouter(get_handle) @@ -697,7 +701,9 @@ async def _assign_request_with_timeout( `disconnected_task` is expected to be done if the client disconnects; in this case, we will abort assigning a replica and return `None`. """ - assignment_task = handle.remote(pickle.dumps(scope), self.self_actor_handle) + assignment_task = handle.remote( + StreamingHTTPRequest(pickle.dumps(scope), self.self_actor_handle) + ) done, _ = await asyncio.wait( [assignment_task, disconnected_task], return_when=FIRST_COMPLETED, diff --git a/python/ray/serve/_private/replica.py b/python/ray/serve/_private/replica.py index 8cdab9496e68..45adca7fa821 100644 --- a/python/ray/serve/_private/replica.py +++ b/python/ray/serve/_private/replica.py @@ -1,5 +1,6 @@ import aiorwlock import asyncio +from contextlib import asynccontextmanager from importlib import import_module import inspect import logging @@ -25,6 +26,7 @@ CONTROL_PLANE_CONCURRENCY_GROUP, ReplicaTag, ServeComponentType, + StreamingHTTPRequest, ) from ray.serve.config import DeploymentConfig from ray.serve._private.constants import ( @@ -55,6 +57,7 @@ from ray.serve._private.router import RequestMetadata from ray.serve._private.utils import ( parse_import_path, + wrap_generator_function_in_async_if_needed, wrap_to_ray_error, merge_dict, MetricsPusher, @@ -236,7 +239,7 @@ async def handle_request( buffered_receive = make_buffered_asgi_receive(request.body) request_args = (scope, buffered_receive, buffered_send) - result = await self.replica.handle_request( + result = await self.replica.call_user_method( request_metadata, request_args, request_kwargs ) @@ -246,38 +249,28 @@ async def handle_request( # Returns a small object for router to track request status. return b"", result - async def handle_request_streaming( + async def _handle_http_request_generator( self, - pickled_request_metadata: bytes, - pickled_asgi_scope: bytes, - http_proxy_handle: ActorHandle, + request_metadata: RequestMetadata, + request: StreamingHTTPRequest, ) -> AsyncGenerator[Message, None]: - """Handle a request and stream the results to the caller. + """Handle an HTTP request and stream ASGI messages to the caller. - This is used by the HTTP proxy for experimental StreamingResponse support. - - This generator yields ASGI-compliant messages sent via an ASGI send - interface. This allows us to return the messages back to the HTTP proxy as - they're sent by user code (e.g., the FastAPI wrapper). + This is a generator that yields ASGI-compliant messages sent by user code + via an ASGI send interface. """ - request_metadata = pickle.loads(pickled_request_metadata) - if not request_metadata.is_http_request: - raise NotImplementedError( - "Only HTTP requests are currently supported over streaming." - ) - receiver_task = None - handle_request_task = None + call_user_method_task = None wait_for_message_task = None try: receiver = ASGIReceiveProxy( - request_metadata.request_id, http_proxy_handle + request_metadata.request_id, request.http_proxy_handle ) receiver_task = self._event_loop.create_task( receiver.fetch_until_disconnect() ) - scope = pickle.loads(pickled_asgi_scope) + scope = pickle.loads(request.pickled_asgi_scope) asgi_queue_send = ASGIMessageQueue() request_args = (scope, receiver, asgi_queue_send) request_kwargs = {} @@ -286,8 +279,8 @@ async def handle_request_streaming( # this task will use the provided ASGI send interface to send its HTTP # the response. We will poll for the sent messages and yield them back # to the caller. - handle_request_task = self._event_loop.create_task( - self.replica.handle_request( + call_user_method_task = self._event_loop.create_task( + self.replica.call_user_method( request_metadata, request_args, request_kwargs ) ) @@ -297,7 +290,7 @@ async def handle_request_streaming( asgi_queue_send.wait_for_message() ) done, _ = await asyncio.wait( - [handle_request_task, wait_for_message_task], + [call_user_method_task, wait_for_message_task], return_when=asyncio.FIRST_COMPLETED, ) # Consume and yield all available messages in the queue. @@ -306,20 +299,23 @@ async def handle_request_streaming( # know it's safe for these messages containing primitive types. yield pickle.dumps(asgi_queue_send.get_messages_nowait()) - # Exit once `handle_request` has finished. In this case, all + # Exit once `call_user_method` has finished. In this case, all # messages must have already been sent. - if handle_request_task in done: + if call_user_method_task in done: break - e = handle_request_task.exception() + e = call_user_method_task.exception() if e is not None: raise e from None finally: if receiver_task is not None: receiver_task.cancel() - if handle_request_task is not None and not handle_request_task.done(): - handle_request_task.cancel() + if ( + call_user_method_task is not None + and not call_user_method_task.done() + ): + call_user_method_task.cancel() if ( wait_for_message_task is not None @@ -327,6 +323,29 @@ async def handle_request_streaming( ): wait_for_message_task.cancel() + async def handle_request_streaming( + self, + pickled_request_metadata: bytes, + *request_args, + **request_kwargs, + ) -> AsyncGenerator[Any, None]: + """Generator that is the entrypoint for all `stream=True` handle calls.""" + request_metadata = pickle.loads(pickled_request_metadata) + if request_metadata.is_http_request: + assert len(request_args) == 1 and isinstance( + request_args[0], StreamingHTTPRequest + ) + generator = self._handle_http_request_generator( + request_metadata, request_args[0] + ) + else: + generator = self.replica.call_user_method_generator( + request_metadata, request_args, request_kwargs + ) + + async for result in generator: + yield result + async def handle_request_from_java( self, proto_request_metadata: bytes, @@ -342,7 +361,7 @@ async def handle_request_from_java( proto.request_id, proto.endpoint, call_method=proto.call_method ) request_args = request_args[0] - return await self.replica.handle_request( + return await self.replica.call_user_method( request_metadata, request_args, request_kwargs ) @@ -589,79 +608,6 @@ async def send_user_result_over_asgi( else: await result(scope, receive, send) - async def invoke_single( - self, - request_metadata: RequestMetadata, - request_args: Tuple[Any], - request_kwargs: Dict[str, Any], - ) -> Tuple[Any, bool]: - """Executes the provided request on this replica. - - Returns the user-provided output and a boolean indicating if the - request succeeded (user code didn't raise an exception). - """ - logger.info( - f"Started executing request {request_metadata.request_id}", - extra={"log_to_stderr": False}, - ) - - if request_metadata.is_http_request: - # For HTTP requests we always expect (scope, receive, send) as args. - assert len(request_args) == 3 - scope, receive, send = request_args - - if isinstance(self.callable, ASGIAppReplicaWrapper): - request_args = (scope, receive, send) - else: - request_args = (Request(scope, receive, send),) - - method_to_call = None - success = True - try: - runner_method = self.get_runner_method(request_metadata) - method_to_call = sync_to_async(runner_method) - result = None - - # Edge case to support empty HTTP handlers: don't pass the Request - # argument if the callable has no parameters. - if ( - request_metadata.is_http_request - and len(inspect.signature(runner_method).parameters) == 0 - ): - request_args, request_kwargs = tuple(), {} - - result = await method_to_call(*request_args, **request_kwargs) - - except Exception as e: - logger.exception(f"Request failed due to {type(e).__name__}:") - success = False - - # If the debugger is enabled, drop into the remote pdb here. - if ray.util.pdb._is_ray_debugger_enabled(): - ray.util.pdb._post_mortem() - - function_name = "unknown" - if method_to_call is not None: - function_name = method_to_call.__name__ - result = wrap_to_ray_error(function_name, e) - if request_metadata.is_http_request: - error_message = f"Unexpected error, traceback: {result}." - result = starlette.responses.Response(error_message, status_code=500) - - if request_metadata.is_http_request and not isinstance( - self.callable, ASGIAppReplicaWrapper - ): - # For the FastAPI codepath, the response has already been sent over the ASGI - # interface, but for the vanilla deployment codepath we need to send it. - await self.send_user_result_over_asgi(result, scope, receive, send) - - if success: - self.request_counter.inc(tags={"route": request_metadata.route}) - else: - self.error_counter.inc(tags={"route": request_metadata.route}) - - return result, success - async def reconfigure(self, deployment_config: DeploymentConfig): old_user_config = self.deployment_config.user_config self.deployment_config = deployment_config @@ -673,7 +619,7 @@ async def reconfigure(self, deployment_config: DeploymentConfig): await self.update_user_config(deployment_config.user_config) async def update_user_config(self, user_config: Any): - async with self.rwlock.writer_lock: + async with self.rwlock.writer: if user_config is not None: if self.is_function: raise ValueError( @@ -692,44 +638,166 @@ async def update_user_config(self, user_config: Any): ) await reconfigure_method(user_config) - async def handle_request( + @asynccontextmanager + async def wrap_user_method_call( + self, + request_metadata: RequestMetadata, + *, + acquire_reader_lock: bool = True, + ): + """Context manager that should be used to wrap user method calls. + + This sets up the serve request context, grabs the reader lock to avoid mutating + user_config during method calls, and records metrics based on the result of the + method. + """ + # Set request context variables for subsequent handle so that + # handle can pass the correct request context to subsequent replicas. + ray.serve.context._serve_request_context.set( + ray.serve.context.RequestContext( + request_metadata.route, + request_metadata.request_id, + self.app_name, + request_metadata.multiplexed_model_id, + ) + ) + + logger.info( + f"Started executing request {request_metadata.request_id}", + extra={"log_to_stderr": False}, + ) + start_time = time.time() + user_exception = None + try: + # TODO(edoakes): this is only here because there is an issue where async + # generators in actors have the `asyncio.current_task()` change between + # iterations: https://github.com/ray-project/ray/issues/37147. `aiorwlock` + # relies on the current task being stable, so it raises an exception. + # This flag should be removed once the above issue is closed. + if acquire_reader_lock: + async with self.rwlock.reader: + yield + else: + yield + except Exception as e: + user_exception = e + logger.exception(f"Request failed due to {type(e).__name__}:") + if ray.util.pdb._is_ray_debugger_enabled(): + ray.util.pdb._post_mortem() + + latency_ms = (time.time() - start_time) * 1000 + self.processing_latency_tracker.observe( + latency_ms, tags={"route": request_metadata.route} + ) + logger.info( + access_log_msg( + method=request_metadata.call_method, + status="OK" if user_exception is None else "ERROR", + latency_ms=latency_ms, + ) + ) + if user_exception is None: + self.request_counter.inc(tags={"route": request_metadata.route}) + else: + self.error_counter.inc(tags={"route": request_metadata.route}) + raise user_exception from None + + async def call_user_method( self, request_metadata: RequestMetadata, request_args: Tuple[Any], request_kwargs: Dict[str, Any], ) -> Any: - async with self.rwlock.reader_lock: - # Set request context variables for subsequent handle so that - # handle can pass the correct request context to subsequent replicas. - ray.serve.context._serve_request_context.set( - ray.serve.context.RequestContext( - request_metadata.route, - request_metadata.request_id, - self.app_name, - request_metadata.multiplexed_model_id, - ) - ) + """Call a user method that is *not* expected to be a generator. - start_time = time.time() - result, success = await self.invoke_single( - request_metadata, - request_args, - request_kwargs, - ) - latency_ms = (time.time() - start_time) * 1000 - self.processing_latency_tracker.observe( - latency_ms, tags={"route": request_metadata.route} - ) - logger.info( - access_log_msg( - method=request_metadata.call_method, - status="OK" if success else "ERROR", - latency_ms=latency_ms, - ) - ) + Raises any exception raised by the user code so it can be propagated as a + `RayTaskError`. + """ + async with self.wrap_user_method_call(request_metadata): + if request_metadata.is_http_request: + # For HTTP requests we always expect (scope, receive, send) as args. + assert len(request_args) == 3 + scope, receive, send = request_args + + if isinstance(self.callable, ASGIAppReplicaWrapper): + request_args = (scope, receive, send) + else: + request_args = (Request(scope, receive, send),) + + runner_method = None + try: + runner_method = self.get_runner_method(request_metadata) + if inspect.isgeneratorfunction( + runner_method + ) or inspect.isasyncgenfunction(runner_method): + raise TypeError( + f"Method '{runner_method.__name__}' is a generator. You must " + "use `handle.options(stream=True)` to call generator methods " + "on a deployment." + ) + method_to_call = sync_to_async(runner_method) + + # Edge case to support empty HTTP handlers: don't pass the Request + # argument if the callable has no parameters. + if ( + request_metadata.is_http_request + and len(inspect.signature(runner_method).parameters) == 0 + ): + request_args, request_kwargs = tuple(), {} + + result = await method_to_call(*request_args, **request_kwargs) + + except Exception as e: + function_name = "unknown" + if runner_method is not None: + function_name = runner_method.__name__ + e = wrap_to_ray_error(function_name, e) + if request_metadata.is_http_request: + result = starlette.responses.Response( + f"Unexpected error, traceback: {e}.", status_code=500 + ) + await self.send_user_result_over_asgi(result, scope, receive, send) + + raise e from None + + if request_metadata.is_http_request and not isinstance( + self.callable, ASGIAppReplicaWrapper + ): + # For the FastAPI codepath, the response has already been sent over the + # ASGI interface, but for the vanilla deployment codepath we need to + # send it. + await self.send_user_result_over_asgi(result, scope, receive, send) return result + async def call_user_method_generator( + self, + request_metadata: RequestMetadata, + request_args: Tuple[Any], + request_kwargs: Dict[str, Any], + ) -> AsyncGenerator[Any, None]: + """Call a user method that is expected to be a generator. + + Raises any exception raised by the user code so it can be propagated as a + `RayTaskError`. + """ + # TODO(edoakes): this is only here because there is an issue where async + # generators in actors have the `asyncio.current_task()` change between + # iterations: https://github.com/ray-project/ray/issues/37147. `aiorwlock` + # relies on the current task being stable, so it raises an exception. + # This flag should be removed once the above issue is closed. + async with self.wrap_user_method_call( + request_metadata, acquire_reader_lock=False + ): + assert ( + not request_metadata.is_http_request + ), "HTTP requests should go through `call_user_method`." + user_method = self.get_runner_method(request_metadata) + method_to_call = wrap_generator_function_in_async_if_needed(user_method) + + async for result in method_to_call(*request_args, **request_kwargs): + yield result + async def prepare_for_shutdown(self): """Perform graceful shutdown. diff --git a/python/ray/serve/_private/utils.py b/python/ray/serve/_private/utils.py index ccf11eac599b..e5a27844aa55 100644 --- a/python/ray/serve/_private/utils.py +++ b/python/ray/serve/_private/utils.py @@ -11,6 +11,7 @@ from functools import wraps from typing import ( Any, + AsyncGenerator, Callable, Dict, Iterable, @@ -723,3 +724,24 @@ def calculate_remaining_timeout( time_since_start_s = curr_time_s - start_time_s return max(0, timeout_s - time_since_start_s) + + +def wrap_generator_function_in_async_if_needed( + f: Callable, +) -> Callable[[Any], AsyncGenerator]: + """Given a callable, make sure it returns an async generator. + + If the callable is not a generator at all, raise a `TypeError`. + """ + if inspect.isasyncgenfunction(f): + return f + elif inspect.isgeneratorfunction(f): + + @wraps(f) + async def async_gen_wrapper(*args, **kwargs): + for result in f(*args, **kwargs): + yield result + + return async_gen_wrapper + else: + raise TypeError(f"Method '{f.__name__}' is not a generator.") diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index e2c45194b4f3..d5669a6048b8 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -66,9 +66,13 @@ def _create_or_get_async_loop_in_thread(): @PublicAPI(stability="beta") @dataclass(frozen=True) class HandleOptions: - """Options for each ServeHandle instances. These fields are immutable.""" + """Options for each ServeHandle instance. + + These fields can be changed by calling `.options()` on a handle. + """ method_name: str = "__call__" + stream: bool = False @PublicAPI(stability="beta") @@ -120,14 +124,12 @@ def __init__( *, _router: Optional[Router] = None, _is_for_http_requests: bool = False, - _stream: bool = False, ): self.controller_handle = controller_handle self.deployment_name = deployment_name self.handle_options = handle_options or HandleOptions() self.handle_tag = f"{self.deployment_name}#{get_random_letters()}" self._is_for_http_requests = _is_for_http_requests - self._stream = _stream self.request_counter = metrics.Counter( "serve_handle_request_counter", @@ -169,11 +171,15 @@ def _options( *, method_name: Union[str, DEFAULT] = DEFAULT.VALUE, multiplexed_model_id: Union[str, DEFAULT] = DEFAULT.VALUE, + stream: Union[bool, DEFAULT] = DEFAULT.VALUE, ): new_options_dict = self.handle_options.__dict__.copy() user_modified_options_dict = { key: value - for key, value in zip(["method_name"], [method_name]) + for key, value in [ + ("method_name", method_name), + ("stream", stream), + ] if value != DEFAULT.VALUE } new_options_dict.update(user_modified_options_dict) @@ -192,7 +198,6 @@ def _options( new_options, _router=self.router, _is_for_http_requests=self._is_for_http_requests, - _stream=self._stream, ) def options( @@ -200,6 +205,7 @@ def options( *, method_name: Union[str, DEFAULT] = DEFAULT.VALUE, multiplexed_model_id: Union[str, DEFAULT] = DEFAULT.VALUE, + stream: Union[bool, DEFAULT] = DEFAULT.VALUE, ) -> "RayServeHandle": """Set options for this handle and return an updated copy of it. @@ -214,7 +220,9 @@ def options( multiplexed_model_id="model:v1").remote(*args) """ return self._options( - method_name=method_name, multiplexed_model_id=multiplexed_model_id + method_name=method_name, + multiplexed_model_id=multiplexed_model_id, + stream=stream, ) def _remote(self, deployment_name, handle_options, args, kwargs) -> Coroutine: @@ -227,7 +235,7 @@ def _remote(self, deployment_name, handle_options, args, kwargs) -> Coroutine: route=_request_context.route, app_name=_request_context.app_name, multiplexed_model_id=_request_context.multiplexed_model_id, - is_streaming=self._stream, + is_streaming=handle_options.stream, ) self.request_counter.inc( tags={ @@ -273,7 +281,6 @@ def __reduce__(self): "deployment_name": self.deployment_name, "handle_options": self.handle_options, "_is_for_http_requests": self._is_for_http_requests, - "_stream": self._stream, } return RayServeHandle._deserialize, (serialized_data,) @@ -326,6 +333,7 @@ def options( *, method_name: Union[str, DEFAULT] = DEFAULT.VALUE, multiplexed_model_id: Union[str, DEFAULT] = DEFAULT.VALUE, + stream: Union[bool, DEFAULT] = DEFAULT.VALUE, ) -> "RayServeSyncHandle": """Set options for this handle and return an updated copy of it. @@ -340,7 +348,9 @@ def options( """ return self._options( - method_name=method_name, multiplexed_model_id=multiplexed_model_id + method_name=method_name, + multiplexed_model_id=multiplexed_model_id, + stream=stream, ) def remote(self, *args, **kwargs) -> ray.ObjectRef: @@ -367,7 +377,6 @@ def __reduce__(self): "deployment_name": self.deployment_name, "handle_options": self.handle_options, "_is_for_http_requests": self._is_for_http_requests, - "_stream": self._stream, } return RayServeSyncHandle._deserialize, (serialized_data,) @@ -388,17 +397,34 @@ def __init__( # requirement of serve.start; Thus handle is fulfilled at runtime. self.handle: RayServeHandle = None - def options(self, *, method_name: str) -> "RayServeDeploymentHandle": - return self.__class__( - self.deployment_name, HandleOptions(method_name=method_name) - ) + def options( + self, + *, + method_name: Union[str, DEFAULT] = DEFAULT.VALUE, + stream: Union[bool, DEFAULT] = DEFAULT.VALUE, + ) -> "RayServeDeploymentHandle": + new_options_dict = self.handle_options.__dict__.copy() + user_modified_options_dict = { + key: value + for key, value in [ + ("method_name", method_name), + ("stream", stream), + ] + if value != DEFAULT.VALUE + } + new_options_dict.update(user_modified_options_dict) + new_options = HandleOptions(**new_options_dict) + return self.__class__(self.deployment_name, new_options) def remote(self, *args, _ray_cache_refs: bool = False, **kwargs) -> asyncio.Task: if not self.handle: handle = serve._private.api.get_deployment( self.deployment_name )._get_handle(sync=FLAG_SERVE_DEPLOYMENT_HANDLE_IS_SYNC) - self.handle = handle.options(method_name=self.handle_options.method_name) + self.handle = handle.options( + method_name=self.handle_options.method_name, + stream=self.handle_options.stream, + ) return self.handle.remote(*args, **kwargs) @classmethod diff --git a/python/ray/serve/tests/test_handle_streaming.py b/python/ray/serve/tests/test_handle_streaming.py new file mode 100644 index 000000000000..78274a240127 --- /dev/null +++ b/python/ray/serve/tests/test_handle_streaming.py @@ -0,0 +1,239 @@ +import sys + +import pytest + +import ray + +from ray import serve +from ray.serve import Deployment +from ray.serve.handle import RayServeHandle +from ray.serve._private.constants import ( + RAY_SERVE_ENABLE_NEW_ROUTING, +) + + +@serve.deployment +class AsyncStreamer: + async def __call__(self, n: int, should_error: bool = False): + if should_error: + raise RuntimeError("oopsies") + + for i in range(n): + yield i + + async def other_method(self, n: int): + for i in range(n): + yield i + + async def unary(self, n: int): + return n + + +@serve.deployment +class SyncStreamer: + def __call__(self, n: int, should_error: bool = False): + if should_error: + raise RuntimeError("oopsies") + + for i in range(n): + yield i + + def other_method(self, n: int): + for i in range(n): + yield i + + def unary(self, n: int): + return n + + +@pytest.mark.skipif( + not RAY_SERVE_ENABLE_NEW_ROUTING, reason="Routing FF must be enabled." +) +@pytest.mark.parametrize("deployment", [AsyncStreamer, SyncStreamer]) +class TestAppHandleStreaming: + def test_basic(self, serve_instance, deployment: Deployment): + h = serve.run(deployment.bind()).options(stream=True) + + # Test calling __call__ generator. + obj_ref_gen = ray.get(h.remote(5)) + assert ray.get(list(obj_ref_gen)) == list(range(5)) + + # Test calling another method name. + obj_ref_gen = ray.get(h.other_method.remote(5)) + assert ray.get(list(obj_ref_gen)) == list(range(5)) + + # Test calling another method name via `.options`. + obj_ref_gen = ray.get(h.options(method_name="other_method").remote(5)) + assert ray.get(list(obj_ref_gen)) == list(range(5)) + + # Test calling a unary method on the same deployment. + assert ray.get(h.options(stream=False).unary.remote(5)) == 5 + + def test_call_gen_without_stream_flag(self, serve_instance, deployment: Deployment): + h = serve.run(deployment.bind()) + + with pytest.raises( + TypeError, + match=( + "Method '__call__' is a generator. You must use " + "`handle.options\(stream=True\)` to call generator " + "methods on a deployment." + ), + ): + ray.get(h.remote()) + + def test_call_no_gen_with_stream_flag(self, serve_instance, deployment: Deployment): + h = serve.run(deployment.bind()).options(stream=True) + + obj_ref_gen = ray.get(h.unary.remote(0)) + with pytest.raises(TypeError, match="Method 'unary' is not a generator."): + ray.get(next(obj_ref_gen)) + + def test_generator_yields_no_results(self, serve_instance, deployment: Deployment): + h = serve.run(deployment.bind()).options(stream=True) + + obj_ref_gen = ray.get(h.remote(0)) + with pytest.raises(StopIteration): + ray.get(next(obj_ref_gen)) + + def test_exception_raised_in_gen(self, serve_instance, deployment: Deployment): + h = serve.run(deployment.bind()).options(stream=True) + + obj_ref_gen = ray.get(h.remote(0, should_error=True)) + with pytest.raises(RuntimeError, match="oopsies"): + ray.get(next(obj_ref_gen)) + + +@pytest.mark.skipif( + not RAY_SERVE_ENABLE_NEW_ROUTING, reason="Routing FF must be enabled." +) +@pytest.mark.parametrize("deployment", [AsyncStreamer, SyncStreamer]) +class TestDeploymentHandleStreaming: + def test_basic(self, serve_instance, deployment: Deployment): + @serve.deployment + class Delegate: + def __init__(self, streamer: RayServeHandle): + self._h = streamer + + async def __call__(self): + h = self._h.options(stream=True) + + # Test calling __call__ generator. + obj_ref_gen = await h.remote(5) + assert [await obj_ref async for obj_ref in obj_ref_gen] == list( + range(5) + ) + + # Test calling another method name. + obj_ref_gen = await h.other_method.remote(5) + assert [await obj_ref for obj_ref in obj_ref_gen] == list(range(5)) + + # Test calling another method name via `.options`. + obj_ref_gen = await h.options(method_name="other_method").remote(5) + assert [await obj_ref for obj_ref in obj_ref_gen] == list(range(5)) + + # Test calling a unary method on the same deployment. + assert await (await h.options(stream=False).unary.remote(5)) == 5 + + h = serve.run(Delegate.bind(deployment.bind())) + ray.get(h.remote()) + + def test_call_gen_without_stream_flag(self, serve_instance, deployment: Deployment): + @serve.deployment + class Delegate: + def __init__(self, streamer: RayServeHandle): + self._h = streamer + + async def __call__(self): + with pytest.raises( + TypeError, + match=( + "Method '__call__' is a generator. You must use " + "`handle.options\(stream=True\)` to call generator " + "methods on a deployment." + ), + ): + await (await self._h.remote()) + + h = serve.run(Delegate.bind(deployment.bind())) + ray.get(h.remote()) + + def test_call_no_gen_with_stream_flag(self, serve_instance, deployment: Deployment): + @serve.deployment + class Delegate: + def __init__(self, streamer: RayServeHandle): + self._h = streamer + + async def __call__(self): + h = self._h.options(stream=True) + + obj_ref_gen = await h.unary.remote(0) + with pytest.raises( + TypeError, match="Method 'unary' is not a generator." + ): + await (await obj_ref_gen.__anext__()) + + h = serve.run(Delegate.bind(deployment.bind())) + ray.get(h.remote()) + + def test_generator_yields_no_results(self, serve_instance, deployment: Deployment): + @serve.deployment + class Delegate: + def __init__(self, streamer: RayServeHandle): + self._h = streamer + + async def __call__(self): + h = self._h.options(stream=True) + + obj_ref_gen = await h.remote(0) + with pytest.raises(StopAsyncIteration): + await (await obj_ref_gen.__anext__()) + + h = serve.run(Delegate.bind(deployment.bind())) + ray.get(h.remote()) + + def test_exception_raised_in_gen(self, serve_instance, deployment: Deployment): + @serve.deployment + class Delegate: + def __init__(self, streamer: RayServeHandle): + self._h = streamer + + async def __call__(self): + h = self._h.options(stream=True) + + obj_ref_gen = await h.remote(0, should_error=True) + with pytest.raises(RuntimeError, match="oopsies"): + await (await obj_ref_gen.__anext__()) + + h = serve.run(Delegate.bind(deployment.bind())) + ray.get(h.remote()) + + def test_call_multiple_downstreams(self, serve_instance, deployment: Deployment): + @serve.deployment + class Delegate: + def __init__(self, streamer1: RayServeHandle, streamer2: RayServeHandle): + self._h1 = streamer1.options(stream=True) + self._h2 = streamer2.options(stream=True) + + async def __call__(self): + obj_ref_gen1 = await self._h1.remote(1) + obj_ref_gen2 = await self._h2.remote(2) + + assert await (await obj_ref_gen1.__anext__()) == 0 + assert await (await obj_ref_gen2.__anext__()) == 0 + + with pytest.raises(StopAsyncIteration): + assert await (await obj_ref_gen1.__anext__()) + assert await (await obj_ref_gen2.__anext__()) == 1 + + with pytest.raises(StopAsyncIteration): + assert await (await obj_ref_gen1.__anext__()) + with pytest.raises(StopAsyncIteration): + assert await (await obj_ref_gen2.__anext__()) + + h = serve.run(Delegate.bind(deployment.bind(), deployment.bind())) + ray.get(h.remote()) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", "-s", __file__])) diff --git a/python/ray/serve/tests/test_http_prefix_matching.py b/python/ray/serve/tests/test_http_prefix_matching.py index a4f16c88f901..d42a53248adf 100644 --- a/python/ray/serve/tests/test_http_prefix_matching.py +++ b/python/ray/serve/tests/test_http_prefix_matching.py @@ -6,8 +6,18 @@ @pytest.fixture def mock_longest_prefix_router() -> LongestPrefixRouter: + class MockHandle: + def __init__(self, name: str): + self._name = name + + def options(self, *args, **kwargs): + return self + + def __eq__(self, other_name: str): + return self._name == other_name + def mock_get_handle(name, *args, **kwargs): - return name + return MockHandle(name) yield LongestPrefixRouter(mock_get_handle) diff --git a/python/ray/serve/tests/test_metrics.py b/python/ray/serve/tests/test_metrics.py index e83890e09c14..f7e2137e0527 100644 --- a/python/ray/serve/tests/test_metrics.py +++ b/python/ray/serve/tests/test_metrics.py @@ -906,6 +906,7 @@ def metric_available() -> bool: metric_dict_str = f"dict({line[dict_body_start:dict_body_end]})" metric_dicts.append(eval(metric_dict_str)) + print(metric_dicts) return metric_dicts diff --git a/python/ray/serve/tests/test_streaming_response.py b/python/ray/serve/tests/test_streaming_response.py index f443c4b40324..c88ad9b9064e 100644 --- a/python/ray/serve/tests/test_streaming_response.py +++ b/python/ray/serve/tests/test_streaming_response.py @@ -12,6 +12,7 @@ from ray._private.test_utils import SignalActor from ray import serve +from ray.serve.handle import RayServeHandle from ray.serve._private.constants import RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING @@ -251,6 +252,72 @@ def __call__(self, request: Request) -> StreamingResponse: next(stream_iter) +@pytest.mark.skipif( + not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, + reason="Streaming feature flag is disabled.", +) +@pytest.mark.parametrize("use_fastapi", [False, True]) +@pytest.mark.parametrize("use_async", [False, True]) +def test_proxy_from_streaming_handle( + serve_instance, use_async: bool, use_fastapi: bool +): + @serve.deployment + class Streamer: + async def hi_gen_async(self): + for i in range(10): + yield f"hi_{i}" + + def hi_gen_sync(self): + for i in range(10): + yield f"hi_{i}" + + if use_fastapi: + app = FastAPI() + + @serve.deployment + @serve.ingress(app) + class SimpleGenerator: + def __init__(self, handle: RayServeHandle): + self._h = handle.options(stream=True) + + @app.get("/") + def stream_hi(self, request: Request) -> StreamingResponse: + async def consume_obj_ref_gen(): + if use_async: + obj_ref_gen = await self._h.hi_gen_async.remote() + else: + obj_ref_gen = await self._h.hi_gen_sync.remote() + async for obj_ref in obj_ref_gen: + yield await obj_ref + + return StreamingResponse(consume_obj_ref_gen(), media_type="text/plain") + + else: + + @serve.deployment + class SimpleGenerator: + def __init__(self, handle: RayServeHandle): + self._h = handle.options(stream=True) + + def __call__(self, request: Request) -> StreamingResponse: + async def consume_obj_ref_gen(): + if use_async: + obj_ref_gen = await self._h.hi_gen_async.remote() + else: + obj_ref_gen = await self._h.hi_gen_sync.remote() + async for obj_ref in obj_ref_gen: + yield await obj_ref + + return StreamingResponse(consume_obj_ref_gen(), media_type="text/plain") + + serve.run(SimpleGenerator.bind(Streamer.bind())) + + r = requests.get("http://localhost:8000", stream=True) + r.raise_for_status() + for i, chunk in enumerate(r.iter_content(chunk_size=None, decode_unicode=True)): + assert chunk == f"hi_{i}" + + @pytest.mark.skipif( not RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING, reason="Streaming feature flag is disabled.", diff --git a/python/ray/serve/tests/test_util.py b/python/ray/serve/tests/test_util.py index fb95492bb752..1069a3d0e1ad 100644 --- a/python/ray/serve/tests/test_util.py +++ b/python/ray/serve/tests/test_util.py @@ -23,6 +23,7 @@ snake_to_camel_case, dict_keys_snake_to_camel_case, get_head_node_id, + wrap_generator_function_in_async_if_needed, ) from ray._private.resource_spec import HEAD_NODE_RESOURCE_NAME @@ -618,6 +619,41 @@ def test_calculate_remaining_timeout(): ) +@pytest.mark.asyncio +async def test_wrap_generator_function_in_async_if_needed(): + def regular_function(): + return "hi" + + with pytest.raises(TypeError): + wrap_generator_function_in_async_if_needed(regular_function) + + def sync_gen(): + for i in range(5): + yield i + + wrapped = wrap_generator_function_in_async_if_needed(sync_gen) + assert wrapped.__name__ == sync_gen.__name__ + + nums = [] + async for i in wrapped(): + nums.append(i) + + assert nums == list(range(5)) + + async def async_gen(): + for i in range(5): + yield i + + not_wrapped = wrap_generator_function_in_async_if_needed(async_gen) + assert not_wrapped == async_gen + + nums = [] + async for i in not_wrapped(): + nums.append(i) + + assert nums == list(range(5)) + + if __name__ == "__main__": import sys