[serve] immediately send ping in router when receiving new replica set #47053

Merged
12 commits merged on Aug 14, 2024
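
In short, the scheduler now treats replicas it has never seen before specially: when running on a proxy it pushes the proxy's own actor handle to each new replica, and it immediately probes the new replicas' queue lengths instead of waiting for the first request. Below is a minimal, self-contained sketch of that flow; FakeReplica and EagerPingScheduler are illustrative stand-ins for the real ReplicaWrapper and PowerOfTwoChoicesReplicaScheduler, not the actual implementation.

import asyncio
from typing import Dict, List, Optional


class FakeReplica:
    """Illustrative stand-in for a replica wrapper; not the real class."""

    def __init__(self, replica_id: str):
        self.replica_id = replica_id
        self.pinged = False
        self.proxy_handle = None

    def push_proxy_handle(self, handle) -> None:
        # The real wrapper forwards this to the replica actor:
        # self._actor_handle.push_proxy_handle.remote(handle)
        self.proxy_handle = handle

    async def get_queue_len(self, *, deadline_s: float) -> int:
        self.pinged = True
        return 0


class EagerPingScheduler:
    """Minimal model of the update_replicas() behavior added in this PR."""

    def __init__(self, loop: asyncio.AbstractEventLoop, on_proxy: bool, self_handle=None):
        self._loop = loop
        self._on_proxy = on_proxy
        self._self_handle = self_handle
        self._replicas: Dict[str, FakeReplica] = {}
        self._probe_task: Optional[asyncio.Task] = None

    async def _probe_queue_lens(self, replicas: List[FakeReplica], backoff_index: int) -> None:
        # Populate the queue-length cache for the given replicas.
        await asyncio.gather(*(r.get_queue_len(deadline_s=0.1) for r in replicas))

    def update_replicas(self, replicas: List[FakeReplica]) -> None:
        new_replicas = {r.replica_id: r for r in replicas}
        new_ids = set(new_replicas) - set(self._replicas)
        for rid in new_ids:
            # On the proxy, push the proxy's handle to each new replica so the
            # replica never has to resolve it through the GCS later.
            if self._on_proxy:
                new_replicas[rid].push_proxy_handle(self._self_handle)
        self._replicas = new_replicas
        # Immediately ping the new replicas instead of waiting for the first request.
        self._probe_task = self._loop.create_task(
            self._probe_queue_lens([new_replicas[rid] for rid in new_ids], 0)
        )


async def main() -> None:
    sched = EagerPingScheduler(asyncio.get_running_loop(), on_proxy=True, self_handle="proxy-handle")
    sched.update_replicas([FakeReplica("r1"), FakeReplica("r2")])
    await sched._probe_task
    assert all(r.pinged and r.proxy_handle == "proxy-handle" for r in sched._replicas.values())


if __name__ == "__main__":
    asyncio.run(main())
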
5 changes: 4 additions & 1 deletion python/ray/serve/_private/replica.py
@@ -16,7 +16,7 @@
import ray
from ray import cloudpickle
from ray._private.utils import get_or_create_event_loop
from ray.actor import ActorClass
from ray.actor import ActorClass, ActorHandle
from ray.remote_function import RemoteFunction
from ray.serve import metrics
from ray.serve._private.common import (
@@ -321,6 +321,9 @@ def _configure_logger_and_profilers(
component_id=self._component_id,
)

def push_proxy_handle(self, handle: ActorHandle):
pass

def get_num_ongoing_requests(self) -> int:
"""Fetch the number of ongoing requests at this replica (queue length).

8 changes: 8 additions & 0 deletions python/ray/serve/_private/replica_scheduler/common.py
@@ -8,6 +8,7 @@

import ray
from ray import ObjectRef, ObjectRefGenerator
from ray.actor import ActorHandle
from ray.serve._private.common import (
ReplicaID,
ReplicaQueueLengthInfo,
@@ -58,6 +59,10 @@ def max_ongoing_requests(self) -> int:
"""Max concurrent requests that can be sent to this replica."""
pass

def push_proxy_handle(self, handle: ActorHandle):
"""When on proxy, push proxy's self handle to replica"""
pass

async def get_queue_len(self, *, deadline_s: float) -> int:
"""Returns current queue len for the replica.

@@ -120,6 +125,9 @@ def max_ongoing_requests(self) -> int:
def is_cross_language(self) -> bool:
return self._replica_info.is_cross_language

def push_proxy_handle(self, handle: ActorHandle):
self._actor_handle.push_proxy_handle.remote(handle)

async def get_queue_len(self, *, deadline_s: float) -> int:
# NOTE(edoakes): the `get_num_ongoing_requests` method name is shared by
# the Python and Java replica implementations. If you change it, you need to
23 changes: 23 additions & 0 deletions python/ray/serve/_private/replica_scheduler/pow_2_scheduler.py
@@ -17,8 +17,11 @@
Tuple,
)

import ray
from ray.actor import ActorHandle
from ray.exceptions import ActorDiedError, ActorUnavailableError
from ray.serve._private.common import (
DeploymentHandleSource,
DeploymentID,
ReplicaID,
RequestMetadata,
@@ -89,19 +92,23 @@ def __init__(
self,
event_loop: asyncio.AbstractEventLoop,
deployment_id: DeploymentID,
handle_source: DeploymentHandleSource,
prefer_local_node_routing: bool = False,
prefer_local_az_routing: bool = False,
self_node_id: Optional[str] = None,
self_actor_id: Optional[str] = None,
self_actor_handle: Optional[ActorHandle] = None,
self_availability_zone: Optional[str] = None,
use_replica_queue_len_cache: bool = False,
get_curr_time_s: Optional[Callable[[], float]] = None,
):
self._loop = event_loop
self._deployment_id = deployment_id
self._handle_source = handle_source
self._prefer_local_node_routing = prefer_local_node_routing
self._prefer_local_az_routing = prefer_local_az_routing
self._self_node_id = self_node_id
self._self_actor_handle = self_actor_handle
self._self_availability_zone = self_availability_zone
self._use_replica_queue_len_cache = use_replica_queue_len_cache

@@ -240,7 +247,17 @@ def update_replicas(self, replicas: List[ReplicaWrapper]):
new_replica_id_set = set()
new_colocated_replica_ids = defaultdict(set)
new_multiplexed_model_id_to_replica_ids = defaultdict(set)

for r in replicas:
# If this router runs on the proxy, the replica needs to call back into the
# proxy with `receive_asgi_messages`, which can block when the GCS is down.
# To prevent that, eagerly push the proxy handle to new replicas.
if (
self._handle_source == DeploymentHandleSource.PROXY
and r.replica_id not in self._replicas
):
r.push_proxy_handle(self._self_actor_handle)

new_replicas[r.replica_id] = r
new_replica_id_set.add(r.replica_id)
if self._self_node_id is not None and r.node_id == self._self_node_id:
@@ -263,6 +280,10 @@ def update_replicas(self, replicas: List[ReplicaWrapper]):
extra={"log_to_stderr": False},
)

# Get list of new replicas
new_ids = new_replica_id_set - self._replica_id_set
replicas_to_ping = [new_replicas.get(id) for id in new_ids]

self._replicas = new_replicas
self._replica_id_set = new_replica_id_set
self._colocated_replica_ids = new_colocated_replica_ids
@@ -272,6 +293,8 @@ def update_replicas(self, replicas: List[ReplicaWrapper]):
self._replica_queue_len_cache.remove_inactive_replicas(
active_replica_ids=new_replica_id_set
)
# Populate cache for new replicas
self._loop.create_task(self._probe_queue_lens(replicas_to_ping, 0))
self._replicas_updated_event.set()
self.maybe_start_scheduling_tasks()

4 changes: 4 additions & 0 deletions python/ray/serve/_private/router.py
@@ -358,10 +358,14 @@ def __init__(
replica_scheduler = PowerOfTwoChoicesReplicaScheduler(
self._event_loop,
deployment_id,
handle_source,
_prefer_local_node_routing,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
self_node_id,
self_actor_id,
ray.get_runtime_context().current_actor
if ray.get_runtime_context().get_actor_id()
else None,
self_availability_zone,
use_replica_queue_len_cache=enable_queue_len_cache,
)
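
The ternary above exists because the Router can be constructed either inside an actor (such as the proxy) or from a driver-side deployment handle, where there is no current actor. Below is a small standalone illustration of that runtime-context behavior; the Probe actor is a hypothetical example, not part of the PR.

import ray

ray.init()


@ray.remote
class Probe:
    def actor_context(self):
        ctx = ray.get_runtime_context()
        # Inside an actor, an actor ID is set and `current_actor` returns a
        # handle to the running actor.
        return ctx.get_actor_id(), ctx.current_actor is not None


probe = Probe.remote()
print(ray.get(probe.actor_context.remote()))  # (<actor id string>, True)

# On the driver there is no actor context, which is why the router guards the
# `current_actor` lookup with `get_actor_id()` before passing it to the scheduler.
print(ray.get_runtime_context().get_actor_id())  # None
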
10 changes: 10 additions & 0 deletions python/ray/serve/_private/test_utils.py
@@ -1,5 +1,6 @@
import asyncio
import datetime
import os
import threading
import time
from copy import copy, deepcopy
@@ -197,6 +198,15 @@ def __init__(
self._soft_target_node_id = _soft_target_node_id


@serve.deployment
class GetPID:
def __call__(self):
return os.getpid()


get_pid_entrypoint = GetPID.bind()


def check_ray_stopped():
try:
requests.get("http://localhost:52365/api/ray/version")
138 changes: 138 additions & 0 deletions python/ray/serve/tests/test_gcs_failure.py
@@ -8,9 +8,17 @@
import ray
from ray import serve
from ray._private.test_utils import wait_for_condition
from ray.serve._private.common import DeploymentID, ReplicaState
from ray.serve._private.constants import SERVE_DEFAULT_APP_NAME
from ray.serve._private.storage.kv_store import KVStoreError, RayInternalKVStore
from ray.serve._private.test_utils import (
GetPID,
check_apps_running,
check_replica_counts,
)
from ray.serve.context import _get_global_client
from ray.serve.handle import DeploymentHandle
from ray.serve.schema import ServeDeploySchema
from ray.tests.conftest import external_redis # noqa: F401


@@ -27,6 +35,8 @@ def serve_ha(external_redis, monkeypatch): # noqa: F811
serve.start()
yield (address_info, _get_global_client())
ray.shutdown()
# Clear cache and global serve client
serve.shutdown()


@pytest.mark.skipif(
@@ -105,6 +115,134 @@ def call():
assert pid == call()


def router_populated_with_replicas(handle: DeploymentHandle, threshold: int = 1):
replicas = handle._router._replica_scheduler._replica_id_set
assert len(replicas) >= threshold
return True


@pytest.mark.parametrize("use_proxy", [True, False])
def test_new_router_on_gcs_failure(serve_ha, use_proxy: bool):
"""Test that a new router can send requests to replicas when GCS is down.

Specifically, if a proxy was just brought up or a deployment handle
was just created, and the GCS goes down BEFORE the router is able to
send its first request, new incoming requests should successfully get
sent to replicas during GCS downtime.
"""

@serve.deployment
class Dummy:
def __call__(self):
return os.getpid()

h = serve.run(Dummy.options(num_replicas=2).bind())
# TODO(zcin): We want to test the behavior for when the router
# didn't get a chance to send even a single request yet. However on
# the very first request we record telemetry for whether the
# deployment handle API was used, which will hang when the GCS is
# down. As a workaround for now, avoid recording telemetry so we
# can properly test router behavior when GCS is down. We should look
# into adding a timeout on the kv cache operation. For now, the proxy
# doesn't run into this because we don't record telemetry on proxy
h._recorded_telemetry = True
# Eagerly create router so it receives the replica set instead of
# waiting for the first request
h._get_or_create_router()

wait_for_condition(router_populated_with_replicas, handle=h)

# Kill GCS server before a single request is sent.
ray.worker._global_node.kill_gcs_server()

returned_pids = set()
if use_proxy:
for _ in range(10):
returned_pids.add(
int(requests.get("http://localhost:8000", timeout=0.1).text)
)
else:
for _ in range(10):
returned_pids.add(int(h.remote().result(timeout_s=0.1)))

print("Returned pids:", returned_pids)
assert len(returned_pids) == 2


def test_handle_router_updated_replicas_then_gcs_failure(serve_ha):
_, client = serve_ha

config = {
"name": "default",
"import_path": "ray.serve._private.test_utils:get_pid_entrypoint",
"route_prefix": "/",
"deployments": [{"name": "GetPID", "num_replicas": 1}],
}
client.deploy_apps(ServeDeploySchema(**{"applications": [config]}))
wait_for_condition(check_apps_running, apps=["default"])

h = serve.get_app_handle("default")
print(h.remote().result())

config["deployments"][0]["num_replicas"] = 2
client.deploy_apps(ServeDeploySchema(**{"applications": [config]}))

wait_for_condition(router_populated_with_replicas, handle=h, threshold=2)

# Kill GCS server before router gets to send request to second replica
ray.worker._global_node.kill_gcs_server()

returned_pids = set()
for _ in range(10):
returned_pids.add(int(h.remote().result(timeout_s=0.1)))

print("Returned pids:", returned_pids)
assert len(returned_pids) == 2


def test_proxy_router_updated_replicas_then_gcs_failure(serve_ha):
_, client = serve_ha

config = {
"name": "default",
"import_path": "ray.serve._private.test_utils:get_pid_entrypoint",
"route_prefix": "/",
"deployments": [{"name": "GetPID", "num_replicas": 1}],
}
client.deploy_apps(ServeDeploySchema(**{"applications": [config]}))
wait_for_condition(check_apps_running, apps=["default"])

r = requests.post("http://localhost:8000")
assert r.status_code == 200, r.text
print(r.text)

config["deployments"][0]["num_replicas"] = 2
client.deploy_apps(ServeDeploySchema(**{"applications": [config]}))

# There is no way to directly check whether the proxy has received the updated
# replicas, so check the deployment status instead. After the controller updates
# the status with the new replicas, the proxy should receive the update almost
# immediately via long polling.
wait_for_condition(
check_replica_counts,
controller=client._controller,
deployment_id=DeploymentID("GetPID", "default"),
total=2,
by_state=[(ReplicaState.RUNNING, 2, None)],
)

# Kill GCS server before router gets to send request to second replica
ray.worker._global_node.kill_gcs_server()

returned_pids = set()
for _ in range(10):
r = requests.post("http://localhost:8000")
assert r.status_code == 200
returned_pids.add(int(r.text))

print("Returned pids:", returned_pids)
assert len(returned_pids) == 2


if __name__ == "__main__":
# When GCS is down, right now some core worker members are not cleared
# properly in ray.shutdown. Given that this is not hi-pri issue,