Skip to content

Commit

Permalink
[serve] fix lightweight update max ongoing requests
Browse files Browse the repository at this point in the history
When a lightweight update occurs for a deployment and `max_ongoing_requests` is updated, two components need to be notified:
1. Deployment handles, to know not to send more requests to a replica when it's reached its maximum
2. Replicas, to know to reject requests when they've reached their maximum

Right now we handle (1), but we don't handle (2), i.e. replicas aren't notified of the updated `max_ongoing_requests` for lightweight updates. The problem is that (1) alone is not strict enforcement of `max_ongoing_requests`, because deployment handles rely on a cache that can be stale. So the current bug is that, since replicas aren't updated, the updated maximum is not fully enforced.

This PR fixes that, and updates a test to fully test this behavior.

Fixes #44975.


Signed-off-by: Cindy Zhang <[email protected]>
  • Loading branch information
zcin committed Apr 27, 2024
1 parent bac5d5c commit fe48404
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 17 deletions.
2 changes: 1 addition & 1 deletion python/ray/serve/_private/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class DeploymentConfig(BaseModel):
)
max_ongoing_requests: PositiveInt = Field(
default=DEFAULT_MAX_ONGOING_REQUESTS,
update_type=DeploymentOptionUpdateType.NeedsReconfigure,
update_type=DeploymentOptionUpdateType.NeedsActorReconfigure,
)
max_queued_requests: int = Field(
default=-1,
Expand Down
3 changes: 3 additions & 0 deletions python/ray/serve/tests/test_config_files/get_signal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import ray
from ray import serve

Expand All @@ -7,6 +9,7 @@ class A:
async def __call__(self):
signal = ray.get_actor("signal123")
await signal.wait.remote()
return os.getpid()


app = A.bind()
44 changes: 29 additions & 15 deletions python/ray/serve/tests/test_deploy_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,40 +687,54 @@ def test_update_config_max_ongoing_requests(
):
"""Check that replicas stay alive when max_ongoing_requests is updated."""

signal = SignalActor.options(name="signal123").remote()

max_ongoing_requests_field_name = (
"max_concurrent_queries"
if use_max_concurrent_queries
else "max_ongoing_requests"
)
config_template = {
"import_path": "ray.serve.tests.test_config_files.pid.node",
"deployments": [{"name": "f"}],
"import_path": "ray.serve.tests.test_config_files.get_signal.app",
"deployments": [{"name": "A"}],
}
config_template["deployments"][0][max_ongoing_requests_field_name] = 1000

# Deploy first time, max_concurrent_queries set to 1000.
client.deploy_apps(ServeDeploySchema.parse_obj({"applications": [config_template]}))
wait_for_condition(check_running, timeout=15)

all_replicas = ray.get(client._controller._all_running_replicas.remote())
assert len(all_replicas) == 1
assert all_replicas[list(all_replicas.keys())[0]][0].max_ongoing_requests == 1000

handle = serve.get_app_handle(SERVE_DEFAULT_APP_NAME)

# Send 10 requests. All of them should be sent to the replica immediately,
# but the requests should be blocked waiting for the signal
refs = [handle.remote() for _ in range(10)]
pids1 = {ref.result()[0] for ref in refs}
assert len(pids1) == 1
wait_for_condition(
lambda: ray.get(signal.cur_num_waiters.remote()) == 10, timeout=2
)

# Redeploy with max concurrent queries set to 2.
config_template["deployments"][0][max_ongoing_requests_field_name] = 2
signal.send.remote()
pids = {ref.result() for ref in refs}
assert len(pids) == 1
pid1 = pids.pop()

# Reset for redeployment
signal.send.remote(clear=True)
# Redeploy with max concurrent queries set to 5
config_template["deployments"][0][max_ongoing_requests_field_name] = 5
client.deploy_apps(ServeDeploySchema.parse_obj({"applications": [config_template]}))
wait_for_condition(check_running, timeout=15)
wait_for_condition(check_running, timeout=2)

# Verify that the PID of the replica didn't change.
# Send 10 requests. Only 5 of them should be sent to the replica
# immediately, and the remaining 5 should queue at the handle.
refs = [handle.remote() for _ in range(10)]
pids2 = {ref.result()[0] for ref in refs}
assert pids2 == pids1
with pytest.raises(RuntimeError):
wait_for_condition(
lambda: ray.get(signal.cur_num_waiters.remote()) > 5, timeout=2
)

signal.send.remote()
pids = {ref.result() for ref in refs}
assert pids == {pid1}


def test_update_config_health_check_period(client: ServeControllerClient):
Expand Down
6 changes: 5 additions & 1 deletion python/ray/serve/tests/unit/test_deployment_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,11 @@ def test_deploy_new_config_same_code_version(
)
check_counts(ds, total=1, by_state=[(ReplicaState.RUNNING, 1, v1)])

if option in ["user_config", "graceful_shutdown_wait_loop_s"]:
if option in [
"user_config",
"graceful_shutdown_wait_loop_s",
"max_ongoing_requests",
]:
dsm.update()
check_counts(ds, total=1)
check_counts(
Expand Down

0 comments on commit fe48404

Please sign in to comment.