[serve] Add replica queue length caching to replica scheduler (#42943)
Adds caching logic to avoid actively probing replicas for every request. This is integrated into the existing PowerOfTwoChoicesReplicaScheduler so it can reuse much of the same policy and mechanism (e.g., locality-aware and model multiplexing-aware candidate selection).


The benefits of this change are:

- Enables strict enforcement of `max_concurrent_queries`.
- Reduces proxy-side overhead for scheduling requests.
- Reduces latency for scheduling requests (in the "happy path," there's no extra RTT).


The changes are as follows:

- All calls to replicas are now streaming calls, and the first message returned is a system message. The replica uses this message to return its current queue length and to reject the request if it is already at capacity (`max_concurrent_queries`). If the replica rejects, the request scheduling procedure is retried.
- The replica scheduler maintains a local cache of replica queue lengths. Entries in this cache have a timeout (currently set to 10 seconds). The cache is updated by (1) actively probing replicas and (2) the system response messages mentioned above.
- When scheduling a request, we first attempt to choose the best replica based on the queue lengths in the cache. If none of the candidates have cache entries below `max_concurrent_queries`, we fall back to active probing (as before this PR). A simplified sketch of this flow follows the list.
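
Below is a minimal sketch of the resulting scheduling flow. It is illustrative only, not the actual `PowerOfTwoChoicesReplicaScheduler` code: the `schedule_with_cache` helper and the `replica_id` attribute on the wrapper are assumptions made for the example, while `ReplicaQueueLengthCache`, `get_queue_len`, and `send_request_with_rejection` come from the diff below.

```python
import asyncio
import random
from typing import List


async def schedule_with_cache(
    candidates: List["ReplicaWrapper"],
    cache: "ReplicaQueueLengthCache",
    pending_request: "PendingRequest",
):
    """Prefer cached queue lengths; fall back to probing; retry on rejection."""
    while True:
        # Power-of-two choices: sample two candidate replicas.
        sampled = random.sample(candidates, min(2, len(candidates)))

        # Keep candidates whose cached queue length is fresh and below the limit.
        eligible = []
        for replica in sampled:
            queue_len = cache.get(replica.replica_id)
            if queue_len is not None and queue_len < replica.max_concurrent_requests:
                eligible.append((replica, queue_len))

        if eligible:
            # Happy path: pick the least-loaded candidate with no extra RTT.
            replica = min(eligible, key=lambda pair: pair[1])[0]
        else:
            # Cache miss (or all entries stale/at capacity): actively probe, as before.
            queue_lens = await asyncio.gather(
                *(r.get_queue_len(deadline_s=0.1) for r in sampled)
            )
            for r, queue_len in zip(sampled, queue_lens):
                cache.update(r.replica_id, queue_len)
            replica = min(zip(sampled, queue_lens), key=lambda pair: pair[1])[0]

        # The first (system) message from the replica also refreshes the cache.
        obj_ref, info = await replica.send_request_with_rejection(pending_request)
        cache.update(replica.replica_id, info.num_ongoing_requests)
        if info.accepted:
            return obj_ref
        # Rejected: the replica was already at max_concurrent_queries; retry.
```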


Two feature flags are introduced to control this behavior (both currently off by default):

- `RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE`
- `RAY_SERVE_ENABLE_STRICT_MAX_CONCURRENT_QUERIES` (implicitly set by the above)
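
To try this out, the flags can be exported in the environment of the process that starts Serve; for example (a sketch; the flag names come from `constants.py` below, which reads them once at import time):

```python
import os

# Set before importing ray.serve so constants.py picks the values up at import time.
os.environ["RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE"] = "1"
# Implied by the cache flag, but can also be enabled on its own:
os.environ["RAY_SERVE_ENABLE_STRICT_MAX_CONCURRENT_QUERIES"] = "1"
# Optionally shorten the cache staleness timeout (default is 10 seconds):
os.environ["RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S"] = "5"

from ray import serve  # noqa: E402
```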

---------

Signed-off-by: Edward Oakes <[email protected]>
edoakes authored Feb 7, 2024
1 parent 372e71e commit d8b0fe9
Showing 11 changed files with 1,024 additions and 288 deletions.
6 changes: 6 additions & 0 deletions python/ray/serve/_private/common.py
@@ -713,3 +713,9 @@ class TargetCapacityDirection(str, Enum):

UP = "UP"
DOWN = "DOWN"


@dataclass(frozen=True)
class ReplicaQueueLengthInfo:
accepted: bool
num_ongoing_requests: int
18 changes: 18 additions & 0 deletions python/ray/serve/_private/constants.py
@@ -272,6 +272,24 @@
os.environ.get("RAY_SERVE_MAX_QUEUE_LENGTH_RESPONSE_DEADLINE_S", 1.0)
)

# Feature flag for caching queue lengths for faster routing in each handle.
RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE = (
os.environ.get("RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE", "0") == "1"
)

# Feature flag for strictly enforcing max_concurrent_queries (replicas will reject
# requests).
RAY_SERVE_ENABLE_STRICT_MAX_CONCURRENT_QUERIES = (
os.environ.get("RAY_SERVE_ENABLE_STRICT_MAX_CONCURRENT_QUERIES", "0") == "1"
# Strict enforcement path must be enabled for the queue length cache.
or RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE
)

# Length of time to respect entries in the queue length cache when scheduling requests.
RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S = float(
os.environ.get("RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S", 10.0)
)

# The default autoscaling policy to use if none is specified.
DEFAULT_AUTOSCALING_POLICY = "ray.serve.autoscaling_policy:default_autoscaling_policy"

58 changes: 57 additions & 1 deletion python/ray/serve/_private/replica.py
@@ -22,6 +22,7 @@
from ray.serve._private.common import (
DeploymentID,
ReplicaName,
ReplicaQueueLengthInfo,
ReplicaTag,
RequestMetadata,
ServeComponentType,
@@ -407,7 +408,7 @@ async def handle_request(
*request_args,
**request_kwargs,
) -> Tuple[bytes, Any]:
"""Entrypoint for all `stream=False` calls."""
"""Entrypoint for `stream=False` calls."""
request_metadata = pickle.loads(pickled_request_metadata)
with self._wrap_user_method_call(request_metadata):
return await self._user_callable_wrapper.call_user_method(
@@ -497,6 +498,61 @@ async def handle_request_streaming(
):
yield result

async def handle_request_with_rejection(
self,
pickled_request_metadata: bytes,
*request_args,
**request_kwargs,
) -> AsyncGenerator[Any, None]:
"""Entrypoint for all requests with strict max_concurrent_queries enforcement.
The first response from this generator is always a system message indicating
if the request was accepted (the replica has capacity for the request) or
rejected (the replica is already at max_concurrent_queries).
For non-streaming requests, there will only be one more message, the unary
result of the user request handler.
For streaming requests, the subsequent messages will be the results of the
user request handler (which must be a generator).
"""
request_metadata = pickle.loads(pickled_request_metadata)
limit = self._deployment_config.max_concurrent_queries
num_ongoing_requests = self.get_num_ongoing_requests()
if num_ongoing_requests >= limit:
logger.warning(
f"Replica at capacity of max_concurrent_queries={limit}, "
f"rejecting request {request_metadata.request_id}."
)
yield pickle.dumps(
ReplicaQueueLengthInfo(
accepted=False, num_ongoing_requests=num_ongoing_requests
)
)
return

with self._wrap_user_method_call(request_metadata):
yield pickle.dumps(
ReplicaQueueLengthInfo(
accepted=True,
# NOTE(edoakes): `_wrap_user_method_call` will increment the number
# of ongoing requests to include this one, so re-fetch the value.
num_ongoing_requests=self.get_num_ongoing_requests(),
)
)

if request_metadata.is_streaming:
async for result in self._call_user_generator(
request_metadata,
request_args,
request_kwargs,
):
yield result
else:
yield await self._user_callable_wrapper.call_user_method(
request_metadata, request_args, request_kwargs
)

async def handle_request_from_java(
self,
proto_request_metadata: bytes,
2 changes: 1 addition & 1 deletion python/ray/serve/_private/replica_scheduler/__init__.py
@@ -1,5 +1,5 @@
from ray.serve._private.replica_scheduler.common import ( # noqa: F401
Query,
PendingRequest,
ReplicaScheduler,
ReplicaWrapper,
)
183 changes: 144 additions & 39 deletions python/ray/serve/_private/replica_scheduler/common.py
@@ -1,24 +1,42 @@
import asyncio
import logging
import pickle
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import ray
from ray.serve._private.common import RequestMetadata, RunningReplicaInfo
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray import ObjectRef, ObjectRefGenerator
from ray.serve._private.common import (
ReplicaQueueLengthInfo,
RequestMetadata,
RunningReplicaInfo,
)
from ray.serve._private.constants import (
RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S,
SERVE_LOGGER_NAME,
)
from ray.serve._private.utils import JavaActorHandleProxy
from ray.serve.generated.serve_pb2 import RequestMetadata as RequestMetadataProto

logger = logging.getLogger(SERVE_LOGGER_NAME)


@dataclass
class Query:
@dataclass(frozen=True)
class PendingRequest:
args: List[Any]
kwargs: Dict[Any, Any]
metadata: RequestMetadata
created_at: float = field(default_factory=time.time)
future: asyncio.Future = field(default_factory=lambda: asyncio.Future())

def __eq__(self, other: Any) -> bool:
"""Request ID is expected to be unique."""
if isinstance(other, PendingRequest):
return self.metadata.request_id == other.metadata.request_id

return False


class ReplicaWrapper(ABC):
@@ -37,17 +55,20 @@ def multiplexed_model_ids(self) -> Set[str]:
"""Set of model IDs on this replica."""
pass

async def get_queue_state(self, *, deadline_s: float) -> Tuple[int, bool]:
"""Returns tuple of (queue_len, accepted).
@property
def max_concurrent_requests(self) -> int:
"""Max concurrent requests that can be sent to this replica."""
pass

async def get_queue_len(self, *, deadline_s: float) -> int:
"""Returns current queue len for the replica.
`deadline_s` is passed to verify backoff for testing.
"""
pass

def send_query(
self, query: Query
) -> Union[ray.ObjectRef, "ray._raylet.ObjectRefGenerator"]:
"""Send query to this replica."""
def send_request(self, pr: PendingRequest) -> Union[ObjectRef, ObjectRefGenerator]:
"""Send request to this replica."""
pass


@@ -77,71 +98,150 @@ def availability_zone(self) -> Optional[str]:
def multiplexed_model_ids(self) -> Set[str]:
return self._multiplexed_model_ids

async def get_queue_state(self, *, deadline_s: float) -> Tuple[int, bool]:
@property
def max_concurrent_requests(self) -> int:
return self._replica_info.max_concurrent_queries

@property
def is_cross_language(self) -> bool:
return self._replica_info.is_cross_language

async def get_queue_len(self, *, deadline_s: float) -> int:
# NOTE(edoakes): the `get_num_ongoing_requests` method name is shared by
# the Python and Java replica implementations. If you change it, you need to
# change both (or introduce a branch here).
obj_ref = self._actor_handle.get_num_ongoing_requests.remote()
try:
queue_len = await obj_ref
accepted = queue_len < self._replica_info.max_concurrent_queries
return queue_len, accepted
return await obj_ref
except asyncio.CancelledError:
ray.cancel(obj_ref)
raise

def _send_query_java(self, query: Query) -> ray.ObjectRef:
"""Send the query to a Java replica.
def _send_request_java(self, pr: PendingRequest) -> ObjectRef:
"""Send the request to a Java replica.
Does not currently support streaming.
"""
if query.metadata.is_streaming:
if pr.metadata.is_streaming:
raise RuntimeError("Streaming not supported for Java.")

if len(query.args) != 1:
if len(pr.args) != 1:
raise ValueError("Java handle calls only support a single argument.")

return self._actor_handle.handle_request.remote(
RequestMetadataProto(
request_id=query.metadata.request_id,
endpoint=query.metadata.endpoint,
request_id=pr.metadata.request_id,
endpoint=pr.metadata.endpoint,
# Default call method in java is "call," not "__call__" like Python.
call_method="call"
if query.metadata.call_method == "__call__"
else query.metadata.call_method,
if pr.metadata.call_method == "__call__"
else pr.metadata.call_method,
).SerializeToString(),
query.args,
pr.args,
)

def _send_query_python(
self, query: Query
) -> Union[ray.ObjectRef, "ray._raylet.ObjectRefGenerator"]:
"""Send the query to a Python replica."""
if query.metadata.is_streaming:
def _send_request_python(
self, pr: PendingRequest, *, with_rejection: bool
) -> Union[ray.ObjectRef, ObjectRefGenerator]:
"""Send the request to a Python replica."""
if with_rejection:
# Call a separate handler that may reject the request.
# This handler is *always* a streaming call and the first message will
# be a system message that accepts or rejects.
method = self._actor_handle.handle_request_with_rejection.options(
num_returns="streaming"
)
elif pr.metadata.is_streaming:
method = self._actor_handle.handle_request_streaming.options(
num_returns="streaming"
)
else:
method = self._actor_handle.handle_request

return method.remote(pickle.dumps(query.metadata), *query.args, **query.kwargs)
return method.remote(pickle.dumps(pr.metadata), *pr.args, **pr.kwargs)

def send_query(
self, query: Query
) -> Union[ray.ObjectRef, "ray._raylet.ObjectRefGenerator"]:
def send_request(self, pr: PendingRequest) -> Union[ObjectRef, ObjectRefGenerator]:
if self._replica_info.is_cross_language:
return self._send_query_java(query)
return self._send_request_java(pr)
else:
return self._send_request_python(pr, with_rejection=False)

async def send_request_with_rejection(
self,
pr: PendingRequest,
) -> Tuple[Optional[Union[ObjectRef, ObjectRefGenerator]], ReplicaQueueLengthInfo]:
assert (
not self._replica_info.is_cross_language
), "Request rejection not supported for Java."
obj_ref_gen = self._send_request_python(pr, with_rejection=True)

first_ref = await obj_ref_gen.__anext__()
queue_len_info: ReplicaQueueLengthInfo = pickle.loads(await first_ref)

if not queue_len_info.accepted:
return None, queue_len_info
elif pr.metadata.is_streaming:
return obj_ref_gen, queue_len_info
else:
return self._send_query_python(query)
# For non-streaming requests, resolve the generator to its next
# object ref, which will contain the unary response.
return await obj_ref_gen.__anext__(), queue_len_info


@dataclass(frozen=True)
class ReplicaQueueLengthCacheEntry:
queue_len: int
timestamp: float


class ReplicaQueueLengthCache:
def __init__(
self,
*,
staleness_timeout_s: float = RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S,
get_curr_time_s: Optional[Callable[[], float]] = None,
):
self._cache: Dict[str, ReplicaQueueLengthCacheEntry] = {}
self._staleness_timeout_s = staleness_timeout_s
self._get_curr_time_s = (
get_curr_time_s if get_curr_time_s is not None else time.time
)

def _is_timed_out(self, timestamp_s: int) -> bool:
return self._get_curr_time_s() - timestamp_s > self._staleness_timeout_s

def get(self, replica_id: str) -> Optional[int]:
"""Get the queue length for a replica ID.
Returns `None` if the replica ID is not present or the entry is timed out.
"""
entry = self._cache.get(replica_id)
if entry is None or self._is_timed_out(entry.timestamp):
return None

return entry.queue_len

def update(self, replica_id: str, queue_len: int):
"""Set (or update) the queue length for a replica ID."""
self._cache[replica_id] = ReplicaQueueLengthCacheEntry(
queue_len, self._get_curr_time_s()
)

def remove_inactive_replicas(self, *, active_replica_ids: Set[str]):
"""Removes entries for all replica IDs not in the provided active set."""
# NOTE: the size of the cache dictionary changes during this loop.
for replica_id in list(self._cache.keys()):
if replica_id not in active_replica_ids:
self._cache.pop(replica_id)
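
A short usage sketch of the cache above, using the injectable clock to illustrate the staleness timeout (replica IDs and timestamps are made up for the example):

```python
now = 0.0
cache = ReplicaQueueLengthCache(staleness_timeout_s=10, get_curr_time_s=lambda: now)

cache.update("replica-1", 3)
assert cache.get("replica-1") == 3      # fresh entry is returned

now = 11.0                              # more than 10s later
assert cache.get("replica-1") is None   # stale entry is treated as missing

cache.update("replica-1", 5)
cache.update("replica-2", 1)
cache.remove_inactive_replicas(active_replica_ids={"replica-2"})
assert cache.get("replica-1") is None   # dropped: replica no longer active
assert cache.get("replica-2") == 1
```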


class ReplicaScheduler(ABC):
"""Abstract interface for a replica scheduler (how the router calls it)."""

@abstractmethod
async def assign_replica(
self, query: Query
) -> Union[ray.ObjectRef, "ray._raylet.ObjectRefGenerator"]:
async def choose_replica_for_request(
self, pending_request: PendingRequest, *, is_retry: bool = False
) -> ReplicaWrapper:
pass

@abstractmethod
@@ -152,6 +252,11 @@ def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
"""Compatibility shim for RunningReplicaInfo datatype."""
return self.update_replicas([ActorReplicaWrapper(r) for r in running_replicas])

@property
@abstractmethod
def replica_queue_len_cache(self) -> ReplicaQueueLengthCache:
pass

@property
@abstractmethod
def curr_replicas(self) -> Dict[str, ReplicaWrapper]: