From 6e6fff8857d2a87ee1ef50b3126eddee2b41d5c9 Mon Sep 17 00:00:00 2001 From: shrekris-anyscale <92341594+shrekris-anyscale@users.noreply.github.com> Date: Thu, 28 Oct 2021 10:53:44 -0700 Subject: [PATCH] [serve] Enable deployment of functions/classes that take no parameters (#19708) --- python/ray/serve/replica.py | 12 ++++-- python/ray/serve/tests/test_api.py | 67 ++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/python/ray/serve/replica.py b/python/ray/serve/replica.py index b23c37e7db91..364ae0364044 100644 --- a/python/ray/serve/replica.py +++ b/python/ray/serve/replica.py @@ -287,9 +287,15 @@ async def invoke_single(self, request_item: Query) -> Any: start = time.time() method_to_call = None try: - method_to_call = sync_to_async( - self.get_runner_method(request_item)) - result = await method_to_call(*args, **kwargs) + runner_method = self.get_runner_method(request_item) + method_to_call = sync_to_async(runner_method) + result = None + if len(inspect.signature(runner_method).parameters) > 0: + result = await method_to_call(*args, **kwargs) + else: + # The method doesn't take in anything, including the request + # information, so we pass nothing into it + result = await method_to_call() result = await self.ensure_serializable_response(result) self.request_counter.inc() diff --git a/python/ray/serve/tests/test_api.py b/python/ray/serve/tests/test_api.py index e6199e01d91d..fdde1f27b9dd 100644 --- a/python/ray/serve/tests/test_api.py +++ b/python/ray/serve/tests/test_api.py @@ -85,6 +85,73 @@ async def slow_numbers(): assert resp.status_code == 418 +def test_deploy_sync_function_no_params(serve_instance): + @serve.deployment() + def sync_d(): + return "sync!" + + serve.start() + + sync_d.deploy() + assert requests.get("http://localhost:8000/sync_d").text == "sync!" + assert ray.get(sync_d.get_handle().remote()) == "sync!" + + +def test_deploy_async_function_no_params(serve_instance): + @serve.deployment() + async def async_d(): + await asyncio.sleep(5) + return "async!" + + serve.start() + + async_d.deploy() + assert requests.get("http://localhost:8000/async_d").text == "async!" + assert ray.get(async_d.get_handle().remote()) == "async!" + + +def test_deploy_sync_class_no_params(serve_instance): + @serve.deployment + class Counter: + def __init__(self): + self.count = 0 + + def __call__(self): + self.count += 1 + return {"count": self.count} + + serve.start() + Counter.deploy() + + assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 1} + assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 2} + assert ray.get(Counter.get_handle().remote()) == {"count": 3} + + +def test_deploy_async_class_no_params(serve_instance): + @serve.deployment + class AsyncCounter: + async def __init__(self): + await asyncio.sleep(5) + self.count = 0 + + async def __call__(self): + self.count += 1 + await asyncio.sleep(5) + return {"count": self.count} + + serve.start() + AsyncCounter.deploy() + + assert requests.get("http://127.0.0.1:8000/AsyncCounter").json() == { + "count": 1 + } + assert requests.get("http://127.0.0.1:8000/AsyncCounter").json() == { + "count": 2 + } + assert ray.get(AsyncCounter.get_handle().remote()) == {"count": 3} + + def test_user_config(serve_instance): @serve.deployment( "counter", num_replicas=2, user_config={