[serve] Add replica queue length caching to replica scheduler #42943

Merged
39 commits merged on Feb 7, 2024
6 changes: 6 additions & 0 deletions python/ray/serve/_private/common.py
@@ -713,3 +713,9 @@ class TargetCapacityDirection(str, Enum):

    UP = "UP"
    DOWN = "DOWN"


@dataclass(frozen=True)
class ReplicaQueueLengthInfo:
    accepted: bool
    num_ongoing_requests: int
18 changes: 18 additions & 0 deletions python/ray/serve/_private/constants.py
@@ -272,6 +272,24 @@
    os.environ.get("RAY_SERVE_MAX_QUEUE_LENGTH_RESPONSE_DEADLINE_S", 1.0)
)

# Feature flag for caching queue lengths for faster routing in each handle.
RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE = (
    os.environ.get("RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE", "0") == "1"
)

# Feature flag for strictly enforcing max_concurrent_queries (replicas will reject
# requests).
RAY_SERVE_ENABLE_STRICT_MAX_CONCURRENT_QUERIES = (
    os.environ.get("RAY_SERVE_ENABLE_STRICT_MAX_CONCURRENT_QUERIES", "0") == "1"
    # Strict enforcement path must be enabled for the queue length cache.
    or RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE
)

# Length of time to respect entries in the queue length cache when scheduling requests.
RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S = float(
    os.environ.get("RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S", 10.0)
)

# The default autoscaling policy to use if none is specified.
DEFAULT_AUTOSCALING_POLICY = "ray.serve.autoscaling_policy:default_autoscaling_policy"

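Taken together, the flags behave as follows: enabling the queue length cache also turns on strict max_concurrent_queries enforcement, per the inline comment above. A minimal sketch of how the flags resolve, re-evaluating the environment variables the same way constants.py does (the value set here is illustrative only):

import os

os.environ["RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE"] = "1"

cache_enabled = os.environ.get("RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE", "0") == "1"
strict_enforcement = (
    os.environ.get("RAY_SERVE_ENABLE_STRICT_MAX_CONCURRENT_QUERIES", "0") == "1"
    # Mirrors constants.py: the cache flag implies strict enforcement.
    or cache_enabled
)
assert cache_enabled and strict_enforcement
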
58 changes: 57 additions & 1 deletion python/ray/serve/_private/replica.py
@@ -22,6 +22,7 @@
from ray.serve._private.common import (
    DeploymentID,
    ReplicaName,
    ReplicaQueueLengthInfo,
    ReplicaTag,
    RequestMetadata,
    ServeComponentType,
@@ -405,7 +406,7 @@ async def handle_request(
        *request_args,
        **request_kwargs,
    ) -> Tuple[bytes, Any]:
-        """Entrypoint for all `stream=False` calls."""
        """Entrypoint for `stream=False` calls."""
        request_metadata = pickle.loads(pickled_request_metadata)
        with self._wrap_user_method_call(request_metadata):
            return await self._user_callable_wrapper.call_user_method(
@@ -495,6 +496,61 @@ async def handle_request_streaming(
        ):
            yield result

    async def handle_request_with_rejection(
        self,
        pickled_request_metadata: bytes,
        *request_args,
        **request_kwargs,
    ) -> AsyncGenerator[Any, None]:
        """Entrypoint for all requests with strict max_concurrent_queries enforcement.

        The first response from this generator is always a system message indicating
        if the request was accepted (the replica has capacity for the request) or
        rejected (the replica is already at max_concurrent_queries).

        For non-streaming requests, there will only be one more message, the unary
        result of the user request handler.

        For streaming requests, the subsequent messages will be the results of the
        user request handler (which must be a generator).
        """
        request_metadata = pickle.loads(pickled_request_metadata)
        limit = self._deployment_config.max_concurrent_queries
        num_ongoing_requests = self.get_num_ongoing_requests()
        if num_ongoing_requests >= limit:
            logger.warning(
                f"Replica at capacity of max_concurrent_queries={limit}, "
                f"rejecting request {request_metadata.request_id}."
            )
            yield pickle.dumps(
                ReplicaQueueLengthInfo(
                    accepted=False, num_ongoing_requests=num_ongoing_requests
                )
            )
            return

        with self._wrap_user_method_call(request_metadata):
            yield pickle.dumps(
                ReplicaQueueLengthInfo(
                    accepted=True,
                    # NOTE(edoakes): `_wrap_user_method_call` will increment the number
                    # of ongoing requests to include this one, so re-fetch the value.
                    num_ongoing_requests=self.get_num_ongoing_requests(),
                )
            )

            if request_metadata.is_streaming:
                async for result in self._call_user_generator(
                    request_metadata,
                    request_args,
                    request_kwargs,
                ):
                    yield result
            else:
                yield await self._user_callable_wrapper.call_user_method(
                    request_metadata, request_args, request_kwargs
                )

    async def handle_request_from_java(
        self,
        proto_request_metadata: bytes,
2 changes: 1 addition & 1 deletion python/ray/serve/_private/replica_scheduler/__init__.py
@@ -1,5 +1,5 @@
from ray.serve._private.replica_scheduler.common import (  # noqa: F401
-    Query,
    PendingRequest,
    ReplicaScheduler,
    ReplicaWrapper,
)
183 changes: 144 additions & 39 deletions python/ray/serve/_private/replica_scheduler/common.py
@@ -1,24 +1,42 @@
import asyncio
import logging
import pickle
import time
from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Set, Tuple, Union
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import ray
-from ray.serve._private.common import RequestMetadata, RunningReplicaInfo
-from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray import ObjectRef, ObjectRefGenerator
from ray.serve._private.common import (
    ReplicaQueueLengthInfo,
    RequestMetadata,
    RunningReplicaInfo,
)
from ray.serve._private.constants import (
    RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S,
    SERVE_LOGGER_NAME,
)
from ray.serve._private.utils import JavaActorHandleProxy
from ray.serve.generated.serve_pb2 import RequestMetadata as RequestMetadataProto

logger = logging.getLogger(SERVE_LOGGER_NAME)


-@dataclass
-class Query:
@dataclass(frozen=True)
class PendingRequest:
    args: List[Any]
    kwargs: Dict[Any, Any]
    metadata: RequestMetadata
    created_at: float = field(default_factory=time.time)
    future: asyncio.Future = field(default_factory=lambda: asyncio.Future())

    def __eq__(self, other: Any) -> bool:
        """Request ID is expected to be unique."""
        if isinstance(other, PendingRequest):
            return self.metadata.request_id == other.metadata.request_id

        return False


class ReplicaWrapper(ABC):
@@ -37,17 +55,20 @@ def multiplexed_model_ids(self) -> Set[str]:
        """Set of model IDs on this replica."""
        pass

-    async def get_queue_state(self, *, deadline_s: float) -> Tuple[int, bool]:
-        """Returns tuple of (queue_len, accepted).
    @property
    def max_concurrent_requests(self) -> int:
        """Max concurrent requests that can be sent to this replica."""
        pass

    async def get_queue_len(self, *, deadline_s: float) -> int:
        """Returns current queue len for the replica.

        `deadline_s` is passed to verify backoff for testing.
        """
        pass

-    def send_query(
-        self, query: Query
-    ) -> Union[ray.ObjectRef, "ray._raylet.ObjectRefGenerator"]:
-        """Send query to this replica."""
    def send_request(self, pr: PendingRequest) -> Union[ObjectRef, ObjectRefGenerator]:
        """Send request to this replica."""
        pass


@@ -77,71 +98,150 @@ def availability_zone(self) -> Optional[str]:
    def multiplexed_model_ids(self) -> Set[str]:
        return self._multiplexed_model_ids

-    async def get_queue_state(self, *, deadline_s: float) -> Tuple[int, bool]:
    @property
    def max_concurrent_requests(self) -> int:
        return self._replica_info.max_concurrent_queries

    @property
    def is_cross_language(self) -> bool:
        return self._replica_info.is_cross_language

    async def get_queue_len(self, *, deadline_s: float) -> int:
        # NOTE(edoakes): the `get_num_ongoing_requests` method name is shared by
        # the Python and Java replica implementations. If you change it, you need to
        # change both (or introduce a branch here).
        obj_ref = self._actor_handle.get_num_ongoing_requests.remote()
        try:
-            queue_len = await obj_ref
-            accepted = queue_len < self._replica_info.max_concurrent_queries
-            return queue_len, accepted
            return await obj_ref
        except asyncio.CancelledError:
            ray.cancel(obj_ref)
            raise

-    def _send_query_java(self, query: Query) -> ray.ObjectRef:
-        """Send the query to a Java replica.
    def _send_request_java(self, pr: PendingRequest) -> ObjectRef:
        """Send the request to a Java replica.

        Does not currently support streaming.
        """
-        if query.metadata.is_streaming:
        if pr.metadata.is_streaming:
            raise RuntimeError("Streaming not supported for Java.")

-        if len(query.args) != 1:
        if len(pr.args) != 1:
            raise ValueError("Java handle calls only support a single argument.")

        return self._actor_handle.handle_request.remote(
            RequestMetadataProto(
-                request_id=query.metadata.request_id,
-                endpoint=query.metadata.endpoint,
                request_id=pr.metadata.request_id,
                endpoint=pr.metadata.endpoint,
                # Default call method in java is "call," not "__call__" like Python.
                call_method="call"
-                if query.metadata.call_method == "__call__"
-                else query.metadata.call_method,
                if pr.metadata.call_method == "__call__"
                else pr.metadata.call_method,
            ).SerializeToString(),
-            query.args,
            pr.args,
        )

-    def _send_query_python(
-        self, query: Query
-    ) -> Union[ray.ObjectRef, "ray._raylet.ObjectRefGenerator"]:
-        """Send the query to a Python replica."""
-        if query.metadata.is_streaming:
    def _send_request_python(
        self, pr: PendingRequest, *, with_rejection: bool
    ) -> Union[ray.ObjectRef, ObjectRefGenerator]:
        """Send the request to a Python replica."""
        if with_rejection:
            # Call a separate handler that may reject the request.
            # This handler is *always* a streaming call and the first message will
            # be a system message that accepts or rejects.
            method = self._actor_handle.handle_request_with_rejection.options(
                num_returns="streaming"
            )
        elif pr.metadata.is_streaming:
            method = self._actor_handle.handle_request_streaming.options(
                num_returns="streaming"
            )
        else:
            method = self._actor_handle.handle_request

-        return method.remote(pickle.dumps(query.metadata), *query.args, **query.kwargs)
        return method.remote(pickle.dumps(pr.metadata), *pr.args, **pr.kwargs)

-    def send_query(
-        self, query: Query
-    ) -> Union[ray.ObjectRef, "ray._raylet.ObjectRefGenerator"]:
    def send_request(self, pr: PendingRequest) -> Union[ObjectRef, ObjectRefGenerator]:
        if self._replica_info.is_cross_language:
-            return self._send_query_java(query)
            return self._send_request_java(pr)
        else:
-            return self._send_query_python(query)
            return self._send_request_python(pr, with_rejection=False)

    async def send_request_with_rejection(
        self,
        pr: PendingRequest,
    ) -> Tuple[Optional[Union[ObjectRef, ObjectRefGenerator]], ReplicaQueueLengthInfo]:
        assert (
            not self._replica_info.is_cross_language
        ), "Request rejection not supported for Java."
        obj_ref_gen = self._send_request_python(pr, with_rejection=True)

        first_ref = await obj_ref_gen.__anext__()
        queue_len_info: ReplicaQueueLengthInfo = pickle.loads(await first_ref)

        if not queue_len_info.accepted:
            return None, queue_len_info
        elif pr.metadata.is_streaming:
            return obj_ref_gen, queue_len_info
        else:
            # For non-streaming requests, resolve the generator to its next
            # object ref, which will contain the unary response.
            return await obj_ref_gen.__anext__(), queue_len_info
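
A rough sketch of how a caller might combine `send_request_with_rejection` with the queue length cache defined below (illustrative only, not code from this PR; `replica`, `cache`, `pending_request`, and the `replica_id` attribute are assumed):

async def send_with_cache_update(replica, cache, pending_request):
    # Illustrative sketch: send with rejection enabled, refresh the cache from the
    # replica's reported queue length, and signal the caller to retry elsewhere if
    # the replica was at capacity.
    obj_ref_or_gen, info = await replica.send_request_with_rejection(pending_request)

    # The replica reports its current number of ongoing requests whether or not it
    # accepted, so the cache can always be refreshed.
    cache.update(replica.replica_id, info.num_ongoing_requests)

    if not info.accepted:
        return None  # Caller should pick another replica and retry.

    return obj_ref_or_gen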


@dataclass(frozen=True)
class ReplicaQueueLengthCacheEntry:
    queue_len: int
    timestamp: float


class ReplicaQueueLengthCache:
    def __init__(
        self,
        *,
        staleness_timeout_s: float = RAY_SERVE_QUEUE_LENGTH_CACHE_TIMEOUT_S,
        get_curr_time_s: Optional[Callable[[], float]] = None,
    ):
        self._cache: Dict[str, ReplicaQueueLengthCacheEntry] = {}
        self._staleness_timeout_s = staleness_timeout_s
        self._get_curr_time_s = (
            get_curr_time_s if get_curr_time_s is not None else time.time
        )

    def _is_timed_out(self, timestamp_s: int) -> bool:
        return self._get_curr_time_s() - timestamp_s > self._staleness_timeout_s

    def get(self, replica_id: str) -> Optional[int]:
        """Get the queue length for a replica ID.

        Returns `None` if the replica ID is not present or the entry is timed out.
        """
        entry = self._cache.get(replica_id)
        if entry is None or self._is_timed_out(entry.timestamp):
            return None

        return entry.queue_len

    def update(self, replica_id: str, queue_len: int):
        """Set (or update) the queue length for a replica ID."""
        self._cache[replica_id] = ReplicaQueueLengthCacheEntry(
            queue_len, self._get_curr_time_s()
        )

    def remove_inactive_replicas(self, *, active_replica_ids: Set[str]):
        """Removes entries for all replica IDs not in the provided active set."""
        # NOTE: the size of the cache dictionary changes during this loop.
        for replica_id in list(self._cache.keys()):
            if replica_id not in active_replica_ids:
                self._cache.pop(replica_id)
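
A small usage sketch of the cache with an injected clock, so the staleness timeout is deterministic (the replica IDs are made up):

clock = {"now": 0.0}
cache = ReplicaQueueLengthCache(
    staleness_timeout_s=10.0, get_curr_time_s=lambda: clock["now"]
)

cache.update("replica-1", queue_len=3)
assert cache.get("replica-1") == 3

# Entries older than the staleness timeout are treated as misses.
clock["now"] += 11.0
assert cache.get("replica-1") is None

# Entries for replicas that no longer exist are dropped.
cache.update("replica-2", queue_len=1)
cache.remove_inactive_replicas(active_replica_ids={"replica-1"})
assert cache.get("replica-2") is None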


class ReplicaScheduler(ABC):
    """Abstract interface for a replica scheduler (how the router calls it)."""

    @abstractmethod
-    async def assign_replica(
-        self, query: Query
-    ) -> Union[ray.ObjectRef, "ray._raylet.ObjectRefGenerator"]:
    async def choose_replica_for_request(
        self, pending_request: PendingRequest, *, is_retry: bool = False
    ) -> ReplicaWrapper:
        pass

    @abstractmethod
@@ -152,6 +252,11 @@ def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
        """Compatibility shim for RunningReplicaInfo datatype."""
        return self.update_replicas([ActorReplicaWrapper(r) for r in running_replicas])

    @property
    @abstractmethod
    def replica_queue_len_cache(self) -> ReplicaQueueLengthCache:
        pass

    @property
    @abstractmethod
    def curr_replicas(self) -> Dict[str, ReplicaWrapper]:
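
For context, a hypothetical `choose_replica_for_request` implementation might consult the cache before probing replicas, falling back to `get_queue_len` on a miss (a sketch only; the concrete scheduler changes are not shown in this section, and `choose_among`, `candidates`, and `replica_id` are placeholder names):

async def choose_among(candidates, pending_request, cache):
    # Illustrative only: prefer a cached queue length when it is fresh; otherwise
    # probe the replica and refresh the cache with the result.
    best_replica, best_len = None, None
    for replica in candidates:
        queue_len = cache.get(replica.replica_id)
        if queue_len is None:
            queue_len = await replica.get_queue_len(deadline_s=0.1)
            cache.update(replica.replica_id, queue_len)

        if queue_len < replica.max_concurrent_requests and (
            best_len is None or queue_len < best_len
        ):
            best_replica, best_len = replica, queue_len

    return best_replica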