Skip to content

Commit

Permalink
[Serve] Fix the max_concurrent_queries issue (#33022)
Browse files Browse the repository at this point in the history
For an object to be hashable, both `__eq__` and `__hash__` need to be provided for correctness.  https://docs.python.org/3.9/glossary.html#term-hashable

Also add tests to make sure the long poll timeout issue does not happen again.
  • Loading branch information
sihanwang41 committed Mar 7, 2023
1 parent b16b4e1 commit d8efee2
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
11 changes: 11 additions & 0 deletions python/ray/serve/_private/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,18 @@ def __post_init__(self):
]
)
)

# RunningReplicaInfo is declared with frozen=True, so normal attribute
# assignment is not allowed; object.__setattr__ is the (hacky) workaround
# used to set the new attribute on the instance.
object.__setattr__(self, "_hash", hash_val)

def __hash__(self):
    """Return the hash value stored on this instance as ``_hash``."""
    cached_hash = self._hash
    return cached_hash

def __eq__(self, other):
    """Compare this replica info with ``other`` using the stored ``_hash``.

    Args:
        other: Any object.

    Returns:
        ``NotImplemented`` if ``other`` is not a ``RunningReplicaInfo``
        (so ``==`` falls back to Python's default handling and evaluates
        to False for unrelated objects); otherwise True when both
        instances carry the same ``_hash`` value.
    """
    # The type check must short-circuit before ``other._hash`` is read:
    # evaluating both conditions eagerly (e.g. inside a list passed to
    # ``all``) raises AttributeError for objects without ``_hash``,
    # such as ``None``.
    if not isinstance(other, RunningReplicaInfo):
        return NotImplemented
    # NOTE(review): equality is decided by the stored hash alone, so two
    # instances whose hashes collide compare equal -- confirm acceptable.
    return self._hash == other._hash
18 changes: 18 additions & 0 deletions python/ray/serve/tests/test_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
controller or the actual replica wrapper, use mock if necessary.
"""
import asyncio
import copy

import pytest

Expand Down Expand Up @@ -118,6 +119,23 @@ async def num_queries(self):
await asyncio.sleep(0.2)
assert not third_ref_pending_task.done()

# Let's make sure in flight queries is 1 for each replica.
assert len(rs.in_flight_queries[replicas[0]]) == 1
assert len(rs.in_flight_queries[replicas[1]]) == 1

# Let's copy a new RunningReplicaInfo object and update the router
cur_replicas_info = list(rs.in_flight_queries.keys())
replicas = copy.deepcopy(cur_replicas_info)
assert id(replicas[0].actor_handle) != id(cur_replicas_info[0].actor_handle)
assert replicas[0].replica_tag == cur_replicas_info[0].replica_tag
assert id(replicas[1].actor_handle) != id(cur_replicas_info[1].actor_handle)
assert replicas[1].replica_tag == cur_replicas_info[1].replica_tag
rs.update_running_replicas(replicas)

# Let's make sure in flight queries is 1 for each replica even if replicas update
assert len(rs.in_flight_queries[replicas[0]]) == 1
assert len(rs.in_flight_queries[replicas[1]]) == 1

# Let's unblock the two replicas
await signal.send.remote()
assert await first_ref == "DONE"
Expand Down
66 changes: 66 additions & 0 deletions python/ray/serve/tests/test_standalone2.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,72 @@ async def waiter(*args):
serve.shutdown()


@pytest.mark.parametrize(
    "ray_instance",
    [
        {
            "LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S_LOWER_BOUND": "1",
            "LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S_UPPER_BOUND": "2",
        },
    ],
    indirect=True,
)
def test_long_poll_timeout_with_max_concurrent_queries(ray_instance):
    """Check that max_concurrent_queries is honored across a long poll timeout.

    Regression test for https://github.com/ray-project/ray/issues/32652
    """

    signal_actor = SignalActor.remote()

    @serve.deployment(max_concurrent_queries=1)
    async def f():
        # Block until the signal actor is released so the request stays
        # in flight for the duration of the checks below.
        await signal_actor.wait.remote()
        return "hello"

    handle = serve.run(f.bind())
    first_ref = handle.remote()

    router = handle.router

    # Simulate a long poll timeout: remember the current snapshots, reset
    # the client's internal state, then wait until it holds snapshots again.
    snapshots_before = router.long_poll_client.object_snapshots
    router.long_poll_client._reset()
    wait_for_condition(
        lambda: len(router.long_poll_client.object_snapshots) > 0, timeout=10
    )
    snapshots_after = router.long_poll_client.object_snapshots

    # Both snapshots must describe the same single key.
    assert snapshots_before.keys() == snapshots_after.keys()
    assert len(snapshots_before.keys()) == 1
    snapshot_key = next(iter(snapshots_before.keys()))

    # The first entry's actor handles compare unequal before/after the
    # reset, yet they carry the same actor id.
    replica_before = snapshots_before[snapshot_key][0]
    replica_after = snapshots_after[snapshot_key][0]
    assert replica_before.actor_handle != replica_after.actor_handle
    assert (
        replica_before.actor_handle._actor_id == replica_after.actor_handle._actor_id
    )

    replica_set = router._replica_set

    # There is still exactly one tracked replica with one in-flight query.
    assert len(replica_set.in_flight_queries) == 1
    tracked_replica = list(replica_set.in_flight_queries.keys())[0]
    assert len(replica_set.in_flight_queries[tracked_replica]) == 1

    # Re-read the keys and confirm the first request is still counted.
    replicas = list(replica_set.in_flight_queries.keys())
    assert len(replica_set.in_flight_queries[replicas[0]]) == 1

    # The first request must not have completed yet.
    with pytest.raises(ray.exceptions.GetTimeoutError):
        ray.get(first_ref, timeout=1)

    # Release the blocked request and verify its result.
    signal_actor.send.remote()
    assert ray.get(first_ref) == "hello"

    serve.shutdown()


def test_shutdown_remote(start_and_shutdown_ray_cli_function):
"""Check that serve.shutdown() works on a remote Ray cluster."""

Expand Down

0 comments on commit d8efee2

Please sign in to comment.