Skip to content

Commit

Permalink
[Serve] Decrement ray_serve_deployment_queued_queries when client d…
Browse files Browse the repository at this point in the history
…isconnects (#37965) (#38020)

The `ray_serve_deployment_queued_queries` metric tracks the number of queries that have yet to be assigned a replica. If a client disconnects before its query has been assigned a replica– but after the metric has counted their query– the query terminates, but the metric doesn't decrease.

This change decrements `ray_serve_deployment_queued_queries` when a queued request is disconnected.

Signed-off-by: Edward Oakes <[email protected]>
Co-authored-by: shrekris-anyscale <[email protected]>
  • Loading branch information
edoakes and shrekris-anyscale authored Aug 8, 2023
1 parent 23528ba commit 1945d8f
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 17 deletions.
37 changes: 21 additions & 16 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -929,20 +929,25 @@ async def assign_request(
},
)

query = Query(
args=list(request_args),
kwargs=request_kwargs,
metadata=request_meta,
)
await query.resolve_async_tasks()
result = await self._replica_scheduler.assign_replica(query)

self.num_queued_queries -= 1
self.num_queued_queries_gauge.set(
self.num_queued_queries,
tags={
"application": request_meta.app_name,
},
)
try:
query = Query(
args=list(request_args),
kwargs=request_kwargs,
metadata=request_meta,
)
await query.resolve_async_tasks()
result = await self._replica_scheduler.assign_replica(query)

return result
return result
finally:
# If the query is disconnected before assignment, this coroutine
# gets cancelled by the caller and an asyncio.CancelledError is
# raised. The finally block ensures that num_queued_queries
# is correctly decremented in this case.
self.num_queued_queries -= 1
self.num_queued_queries_gauge.set(
self.num_queued_queries,
tags={
"application": request_meta.app_name,
},
)
61 changes: 60 additions & 1 deletion python/ray/serve/tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Dict, DefaultDict

import requests
Expand Down Expand Up @@ -841,6 +843,63 @@ def verify_metrics():
)


def test_queued_queries_disconnected(serve_start_shutdown):
"""Check that queued_queries decrements when queued requests disconnect."""

signal = SignalActor.remote()

@serve.deployment(
max_concurrent_queries=1,
graceful_shutdown_timeout_s=0.0001,
)
async def hang_on_first_request():
await signal.wait.remote()

serve.run(hang_on_first_request.bind())

print("Deployed hang_on_first_request deployment.")

def queue_size() -> float:
metrics = requests.get("http://127.0.0.1:9999").text
queue_size = -1
for line in metrics.split("\n"):
if "ray_serve_deployment_queued_queries" in line:
queue_size = line.split(" ")[-1]

return float(queue_size)

def first_request_executing(request_future) -> bool:
try:
request_future.get(timeout=0.1)
except Exception:
return ray.get(signal.cur_num_waiters.remote()) == 1

url = "http://localhost:8000/"

pool = Pool()

# Make a request to block the deployment from accepting other requests
fut = pool.apply_async(partial(requests.get, url))
wait_for_condition(lambda: first_request_executing(fut), timeout=5)
print("Executed first request.")

num_requests = 5
for _ in range(num_requests):
pool.apply_async(partial(requests.get, url))
print(f"Executed {num_requests} more requests.")

# First request should be processing. All others should be queued.
wait_for_condition(lambda: queue_size() == num_requests, timeout=15)
print("ray_serve_deployment_queued_queries updated successfully.")

# Disconnect all requests by terminating the process pool.
pool.terminate()
print("Terminated all requests.")

wait_for_condition(lambda: queue_size() == 0, timeout=15)
print("ray_serve_deployment_queued_queries updated successfully.")


def test_actor_summary(serve_instance):
@serve.deployment
def f():
Expand All @@ -855,7 +914,7 @@ def f():


def get_metric_dictionaries(name: str, timeout: float = 20) -> List[Dict]:
"""Gets a list of metric's dictionaries from metrics' text output.
"""Gets a list of metric's tags from metrics' text output.
Return:
Example:
Expand Down

0 comments on commit 1945d8f

Please sign in to comment.