
[Serve] log warning to reconfigure max_ongoing_requests if it is less than max_batch_size #43840

Merged
14 commits merged on Mar 14, 2024
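For context on the behavior this PR adds, here is a hedged sketch of the mismatch the new warning targets: a deployment whose `@serve.batch` max_batch_size is larger than the deployment's max_ongoing_requests, so a replica can never accumulate a full batch. The deployment class, values, and handler logic below are hypothetical illustrations, not part of the diff.

from typing import List

from ray import serve


# Hypothetical deployment: max_ongoing_requests caps in-flight requests per
# replica at 4, but the batch handler asks for batches of 8, so a full batch
# can never form and the new warning suggests raising max_ongoing_requests.
@serve.deployment(max_ongoing_requests=4)
class Upcaser:
    @serve.batch(max_batch_size=8, batch_wait_timeout_s=0.1)
    async def handle_batch(self, requests: List[str]) -> List[str]:
        # Placeholder batch logic: process all queued requests at once.
        return [r.upper() for r in requests]

    async def __call__(self, request: str) -> str:
        return await self.handle_batch(request)


app = Upcaser.bind()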
7 changes: 7 additions & 0 deletions python/ray/serve/_private/replica.py
@@ -284,6 +284,7 @@ def _set_internal_replica_context(self, *, servable_object: Callable = None):
ray.serve.context._set_internal_replica_context(
replica_id=self._replica_id,
servable_object=servable_object,
_deployment_config=self._deployment_config,
)

def _configure_logger_and_profilers(
@@ -636,6 +637,12 @@ async def reconfigure(
deployment_config.user_config
)

# We need to update internal replica context to reflect the new
# deployment_config.
self._set_internal_replica_context(
servable_object=self._user_callable_wrapper.user_callable
)

return self._get_metadata()
except Exception:
raise RuntimeError(traceback.format_exc()) from None
25 changes: 24 additions & 1 deletion python/ray/serve/batching.py
@@ -19,6 +19,7 @@
overload,
)

from ray import serve
from ray._private.signature import extract_signature, flatten_args, recover_args
from ray._private.utils import get_or_create_event_loop
from ray.serve._private.constants import SERVE_LOGGER_NAME
@@ -111,6 +112,28 @@ def __init__(
self._handle_batch_task = self._loop.create_task(
self._process_batches(handle_batch_func)
)
self._warn_if_max_batch_size_exceeds_max_ongoing_requests()

def _warn_if_max_batch_size_exceeds_max_ongoing_requests(self):
"""Helper to check whether the max_batch_size is bounded.

Log a warning to configure `max_ongoing_requests` if it's bounded.
"""
max_ongoing_requests = (
serve.get_replica_context()._deployment_config.max_ongoing_requests
)
if max_ongoing_requests < self.max_batch_size:
logger.warning(
f"`max_batch_size` ({self.max_batch_size}) is larger than "
f"`max_ongoing_requests` ({max_ongoing_requests}). This means "
"the replica will never receive a full batch. Please update "
"`max_ongoing_requests` to be >= `max_batch_size`."
)

def set_max_batch_size(self, new_max_batch_size: int) -> None:
"""Updates queue's max_batch_size."""
self.max_batch_size = new_max_batch_size
self._warn_if_max_batch_size_exceeds_max_ongoing_requests()

def put(self, request: Tuple[_SingleRequest, asyncio.Future]) -> None:
self.queue.put_nowait(request)
@@ -345,7 +368,7 @@ def set_max_batch_size(self, new_max_batch_size: int) -> None:
self.max_batch_size = new_max_batch_size

if self._queue is not None:
self._queue.max_batch_size = new_max_batch_size
self._queue.set_max_batch_size(new_max_batch_size)

def set_batch_wait_timeout_s(self, new_batch_wait_timeout_s: float) -> None:
self.batch_wait_timeout_s = new_batch_wait_timeout_s
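The outer set_max_batch_size above now delegates to the queue's new setter, so changing the batch size at runtime re-runs the warning check. A hedged sketch of that dynamic path, assuming the documented `set_max_batch_size` setter on the `@serve.batch` wrapper; the deployment and values are hypothetical:

from typing import Dict, List

from ray import serve


@serve.deployment(max_ongoing_requests=4, user_config={"max_batch_size": 2})
class Scorer:
    @serve.batch(max_batch_size=2)
    async def handle_batch(self, requests: List[str]) -> List[int]:
        # Placeholder batch logic.
        return [len(r) for r in requests]

    def reconfigure(self, config: Dict):
        # Raising max_batch_size above max_ongoing_requests (4) at runtime now
        # triggers the same warning as at queue construction time.
        self.handle_batch.set_max_batch_size(config["max_batch_size"])

    async def __call__(self, request: str) -> int:
        return await self.handle_batch(request)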
4 changes: 4 additions & 0 deletions python/ray/serve/context.py
@@ -12,6 +12,7 @@
from ray.exceptions import RayActorError
from ray.serve._private.client import ServeControllerClient
from ray.serve._private.common import ReplicaID
from ray.serve._private.config import DeploymentConfig
from ray.serve._private.constants import SERVE_CONTROLLER_NAME, SERVE_NAMESPACE
from ray.serve.exceptions import RayServeException
from ray.serve.grpc_util import RayServegRPCContext
@@ -37,6 +38,7 @@ class ReplicaContext:

replica_id: ReplicaID
servable_object: Callable
_deployment_config: DeploymentConfig

@property
def app_name(self) -> str:
@@ -98,11 +100,13 @@ def _set_internal_replica_context(
*,
replica_id: ReplicaID,
servable_object: Callable,
_deployment_config: DeploymentConfig,
):
global _INTERNAL_REPLICA_CONTEXT
_INTERNAL_REPLICA_CONTEXT = ReplicaContext(
replica_id=replica_id,
servable_object=servable_object,
_deployment_config=_deployment_config,
)
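With `_deployment_config` now stored on the context, replica-side code can read deployment settings the same way the batching helper above does. A minimal sketch, noting that `_deployment_config` is a private attribute rather than a public API:

from ray import serve

# Only meaningful inside a running replica, where the internal context is set.
ctx = serve.get_replica_context()
max_ongoing_requests = ctx._deployment_config.max_ongoing_requests  # private field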


2 changes: 2 additions & 0 deletions python/ray/serve/tests/test_multiplex.py
@@ -10,6 +10,7 @@
from ray._private.test_utils import SignalActor, wait_for_condition
from ray._private.utils import get_or_create_event_loop
from ray.serve._private.common import DeploymentID, ReplicaID
from ray.serve._private.config import DeploymentConfig
from ray.serve._private.constants import SERVE_MULTIPLEXED_MODEL_ID
from ray.serve.context import _get_internal_replica_context
from ray.serve.handle import DeploymentHandle
@@ -25,6 +26,7 @@ def start_serve_with_context():
deployment_id=DeploymentID(name="fake_deployment", app_name="fake_app"),
),
servable_object=None,
_deployment_config=DeploymentConfig(),
)
try:
yield
69 changes: 69 additions & 0 deletions python/ray/serve/tests/unit/test_batching.py
@@ -1,13 +1,38 @@
import asyncio
import logging
import time
from typing import List

import pytest

import ray
from ray import serve
from ray._private.utils import get_or_create_event_loop
from ray.serve._private.common import DeploymentID, ReplicaID
from ray.serve._private.config import DeploymentConfig
from ray.serve._private.constants import SERVE_LOGGER_NAME
from ray.serve.batching import _BatchQueue
from ray.serve.exceptions import RayServeException

# Set up the global replica context for the tests.
default_deployment_config = DeploymentConfig()
ray.serve.context._set_internal_replica_context(
replica_id=ReplicaID(unique_id="test", deployment_id=DeploymentID(name="test")),
servable_object=None,
_deployment_config=default_deployment_config,
)


class FakeStream:
def __init__(self):
self.messages = []

def write(self, buf):
self.messages.append(buf)

def reset_message(self):
self.messages = []


# We use a single event loop for the entire test session. Without this
# fixture, the event loop is sometimes prematurely terminated by pytest.
@@ -776,6 +801,50 @@ async def yield_three_times(key1, key2):
await coro.__anext__()


def test_warn_if_max_batch_size_exceeds_max_ongoing_requests():
"""Test warn_if_max_batch_size_exceeds_max_ongoing_requests() logged the warning
message correctly.

When the queue starts with or updated `max_batch_size` to be larger than
max_ongoing_requests, log the warning to suggest configuring `max_ongoing_requests`.
When the queue starts with or updated `max_batch_size` to be smaller or equal than
max_ongoing_requests, no warning should be logged.
"""
logger = logging.getLogger(SERVE_LOGGER_NAME)
stream = FakeStream()
stream_handler = logging.StreamHandler(stream)
logger.addHandler(stream_handler)
bound = default_deployment_config.max_ongoing_requests
over_bound = bound + 1
under_bound = bound - 1
over_bound_warning_message = (
f"`max_batch_size` ({over_bound}) is larger than "
f"`max_ongoing_requests` ({bound}). This means "
"the replica will never receive a full batch. Please update "
"`max_ongoing_requests` to be >= `max_batch_size`.\n"
)

# Starting the queue with max_batch_size above the bound logs the warning;
# starting under or at the bound does not.
for max_batch_size in [over_bound, under_bound, bound]:
queue = _BatchQueue(max_batch_size=max_batch_size, batch_wait_timeout_s=1000)
if max_batch_size > bound:
assert over_bound_warning_message in stream.messages
else:
assert over_bound_warning_message not in stream.messages
stream.reset_message()

# Updating max_batch_size to above the bound logs the warning; updating to
# under or at the bound does not.
for max_batch_size in [over_bound, under_bound, bound]:
queue.set_max_batch_size(max_batch_size)
if max_batch_size > bound:
assert over_bound_warning_message in stream.messages
else:
assert over_bound_warning_message not in stream.messages
stream.reset_message()


if __name__ == "__main__":
import sys
