diff --git a/python/ray/serve/_private/common.py b/python/ray/serve/_private/common.py
index 69f0f570c769..c14493b488d3 100644
--- a/python/ray/serve/_private/common.py
+++ b/python/ray/serve/_private/common.py
@@ -311,7 +311,18 @@ def __post_init__(self):
                 ]
             )
         )
+
+        # RunningReplicaInfo sets frozen=True; this is the hacky way to set a
+        # new attribute on the class.
         object.__setattr__(self, "_hash", hash_val)
 
     def __hash__(self):
         return self._hash
+
+    def __eq__(self, other):
+        return all(
+            [
+                isinstance(other, RunningReplicaInfo),
+                self._hash == other._hash,
+            ]
+        )
diff --git a/python/ray/serve/tests/test_router.py b/python/ray/serve/tests/test_router.py
index 80520a064a84..422a245acfd7 100644
--- a/python/ray/serve/tests/test_router.py
+++ b/python/ray/serve/tests/test_router.py
@@ -3,6 +3,7 @@
 controller or the actual replica wrapper, use mock if necessary.
 """
 import asyncio
+import copy
 
 import pytest
 
@@ -118,6 +119,23 @@ async def num_queries(self):
     await asyncio.sleep(0.2)
     assert not third_ref_pending_task.done()
 
+    # Let's make sure in-flight queries is 1 for each replica.
+    assert len(rs.in_flight_queries[replicas[0]]) == 1
+    assert len(rs.in_flight_queries[replicas[1]]) == 1
+
+    # Let's copy a new RunningReplicaInfo object and update the router
+    cur_replicas_info = list(rs.in_flight_queries.keys())
+    replicas = copy.deepcopy(cur_replicas_info)
+    assert id(replicas[0].actor_handle) != id(cur_replicas_info[0].actor_handle)
+    assert replicas[0].replica_tag == cur_replicas_info[0].replica_tag
+    assert id(replicas[1].actor_handle) != id(cur_replicas_info[1].actor_handle)
+    assert replicas[1].replica_tag == cur_replicas_info[1].replica_tag
+    rs.update_running_replicas(replicas)
+
+    # Let's make sure in-flight queries is 1 for each replica even after the replicas update
+    assert len(rs.in_flight_queries[replicas[0]]) == 1
+    assert len(rs.in_flight_queries[replicas[1]]) == 1
+
     # Let's unblock the two replicas
     await signal.send.remote()
     assert await first_ref == "DONE"
diff --git a/python/ray/serve/tests/test_standalone2.py b/python/ray/serve/tests/test_standalone2.py
index 35958d36670a..816b6242a0a5 100644
--- a/python/ray/serve/tests/test_standalone2.py
+++ b/python/ray/serve/tests/test_standalone2.py
@@ -828,6 +828,72 @@ async def waiter(*args):
 
     serve.shutdown()
 
 
+@pytest.mark.parametrize(
+    "ray_instance",
+    [
+        {
+            "LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S_LOWER_BOUND": "1",
+            "LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S_UPPER_BOUND": "2",
+        },
+    ],
+    indirect=True,
+)
+def test_long_poll_timeout_with_max_concurrent_queries(ray_instance):
+    """Test that max_concurrent_queries is honored even with long poll timeout
+
+    issue: https://github.com/ray-project/ray/issues/32652
+    """
+
+    signal_actor = SignalActor.remote()
+
+    @serve.deployment(max_concurrent_queries=1)
+    async def f():
+        await signal_actor.wait.remote()
+        return "hello"
+
+    handle = serve.run(f.bind())
+    first_ref = handle.remote()
+
+    # Clear all the internal long poll client objects within the handle so the
+    # long poll client will receive new updates from the long poll host;
+    # this is to simulate the long poll timeout
+    object_snapshots1 = handle.router.long_poll_client.object_snapshots
+    handle.router.long_poll_client._reset()
+    wait_for_condition(
+        lambda: len(handle.router.long_poll_client.object_snapshots) > 0, timeout=10
+    )
+    object_snapshots2 = handle.router.long_poll_client.object_snapshots
+
+    # Check object snapshots between the timeout interval
+    assert object_snapshots1.keys() == object_snapshots2.keys()
+    assert len(object_snapshots1.keys()) == 1
+    key = list(object_snapshots1.keys())[0]
+    assert (
+        object_snapshots1[key][0].actor_handle != object_snapshots2[key][0].actor_handle
+    )
+    assert (
+        object_snapshots1[key][0].actor_handle._actor_id
+        == object_snapshots2[key][0].actor_handle._actor_id
+    )
+
+    # Make sure the in-flight queries count is still one
+    assert len(handle.router._replica_set.in_flight_queries) == 1
+    key = list(handle.router._replica_set.in_flight_queries.keys())[0]
+    assert len(handle.router._replica_set.in_flight_queries[key]) == 1
+
+    # Make sure the first request is being run.
+    replicas = list(handle.router._replica_set.in_flight_queries.keys())
+    assert len(handle.router._replica_set.in_flight_queries[replicas[0]]) == 1
+    # First ref should still be ongoing
+    with pytest.raises(ray.exceptions.GetTimeoutError):
+        ray.get(first_ref, timeout=1)
+    # Unblock the first request.
+    signal_actor.send.remote()
+    assert ray.get(first_ref) == "hello"
+
+    serve.shutdown()
+
+
 def test_shutdown_remote(start_and_shutdown_ray_cli_function):
     """Check that serve.shutdown() works on a remote Ray cluster."""