Skip to content

Commit

Permalink
[serve] Enable deployment of functions/classes that take no parameters (
Browse files Browse the repository at this point in the history
  • Loading branch information
shrekris-anyscale authored Oct 28, 2021
1 parent ed0e2e4 commit 6e6fff8
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 3 deletions.
12 changes: 9 additions & 3 deletions python/ray/serve/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,15 @@ async def invoke_single(self, request_item: Query) -> Any:
start = time.time()
method_to_call = None
try:
method_to_call = sync_to_async(
self.get_runner_method(request_item))
result = await method_to_call(*args, **kwargs)
runner_method = self.get_runner_method(request_item)
method_to_call = sync_to_async(runner_method)
result = None
if len(inspect.signature(runner_method).parameters) > 0:
result = await method_to_call(*args, **kwargs)
else:
# The method doesn't take in anything, including the request
# information, so we pass nothing into it
result = await method_to_call()

result = await self.ensure_serializable_response(result)
self.request_counter.inc()
Expand Down
67 changes: 67 additions & 0 deletions python/ray/serve/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,73 @@ async def slow_numbers():
assert resp.status_code == 418


def test_deploy_sync_function_no_params(serve_instance):
@serve.deployment()
def sync_d():
return "sync!"

serve.start()

sync_d.deploy()
assert requests.get("http://localhost:8000/sync_d").text == "sync!"
assert ray.get(sync_d.get_handle().remote()) == "sync!"


def test_deploy_async_function_no_params(serve_instance):
@serve.deployment()
async def async_d():
await asyncio.sleep(5)
return "async!"

serve.start()

async_d.deploy()
assert requests.get("http://localhost:8000/async_d").text == "async!"
assert ray.get(async_d.get_handle().remote()) == "async!"


def test_deploy_sync_class_no_params(serve_instance):
@serve.deployment
class Counter:
def __init__(self):
self.count = 0

def __call__(self):
self.count += 1
return {"count": self.count}

serve.start()
Counter.deploy()

assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 1}
assert requests.get("http://127.0.0.1:8000/Counter").json() == {"count": 2}
assert ray.get(Counter.get_handle().remote()) == {"count": 3}


def test_deploy_async_class_no_params(serve_instance):
@serve.deployment
class AsyncCounter:
async def __init__(self):
await asyncio.sleep(5)
self.count = 0

async def __call__(self):
self.count += 1
await asyncio.sleep(5)
return {"count": self.count}

serve.start()
AsyncCounter.deploy()

assert requests.get("http://127.0.0.1:8000/AsyncCounter").json() == {
"count": 1
}
assert requests.get("http://127.0.0.1:8000/AsyncCounter").json() == {
"count": 2
}
assert ray.get(AsyncCounter.get_handle().remote()) == {"count": 3}


def test_user_config(serve_instance):
@serve.deployment(
"counter", num_replicas=2, user_config={
Expand Down

0 comments on commit 6e6fff8

Please sign in to comment.