[serve] prevent in-memory metrics store in handles from growing unboundedly
There are two potential sources of memory leaks for the `InMemoryMetricsStore` in the handles that is used to record and report autoscaling metrics:
1. Old replica ID keys are never removed. We remove old replica keys from `num_queries_sent_to_replicas` when we get an updated list of running replicas from the long poll update, but we don't do any such cleaning for the in-memory metrics store. This means there is leftover, uncleaned data for replicas that are no longer running.
2. We don't delete data points recorded more than `look_back_period_s` ago for a replica, except during window-average queries. This should mostly be solved once (1) is solved, because it should only be a problem for replicas that are no longer running.

This PR addresses (1) and (2) by periodically pruning keys that haven't had new data points recorded in the past `look_back_period_s`; a simplified sketch of the idea follows.
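
A minimal sketch of the pruning idea, using a hypothetical toy store (illustrative only; this is not the actual `InMemoryMetricsStore` implementation, and the names and thresholds are made up):

import time
from collections import defaultdict

class SimpleMetricsStore:
    """Toy in-memory store keyed by replica ID (illustration only)."""

    def __init__(self):
        # key -> list of (timestamp, value) pairs, appended in time order.
        self.data = defaultdict(list)

    def add_metrics_point(self, points, timestamp):
        for key, value in points.items():
            self.data[key].append((timestamp, value))

    def prune_keys_and_compact_data(self, start_timestamp_s):
        for key, datapoints in list(self.data.items()):
            if not datapoints or datapoints[-1][0] < start_timestamp_s:
                # No data newer than the cutoff: drop the key entirely.
                del self.data[key]
            else:
                # Keep only datapoints recorded at or after the cutoff.
                self.data[key] = [
                    (t, v) for t, v in datapoints if t >= start_timestamp_s
                ]

store = SimpleMetricsStore()
store.add_metrics_point({"replica-1": 1.0}, time.time() - 60.0)
store.add_metrics_point({"replica-2": 2.0}, time.time())
store.prune_keys_and_compact_data(time.time() - 30.0)
assert set(store.data) == {"replica-2"}  # stale replica-1 was pruned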

Closes #44870.

Signed-off-by: Cindy Zhang <[email protected]>
zcin committed Apr 23, 2024
1 parent ba340b7 commit f0145d9
Showing 5 changed files with 131 additions and 22 deletions.
3 changes: 3 additions & 0 deletions python/ray/serve/_private/constants.py
@@ -224,6 +224,9 @@
# How often autoscaling metrics are recorded on Serve replicas.
RAY_SERVE_REPLICA_AUTOSCALING_METRIC_RECORD_PERIOD_S = 0.5

# How often autoscaling metrics are recorded on Serve handles.
RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_PERIOD_S = 0.5

# Serve multiplexed matching timeout.
# This is the timeout for the matching process of multiplexed requests. To avoid
# a thundering herd problem, the timeout value will be randomized between this value
14 changes: 14 additions & 0 deletions python/ray/serve/_private/metrics_utils.py
@@ -136,6 +136,20 @@ def add_metrics_point(self, data_points: Dict[Hashable, float], timestamp: float
# Using in-sort to insert while maintaining sorted ordering.
bisect.insort(a=self.data[name], x=TimeStampedValue(timestamp, value))

def prune_keys_and_compact_data(self, start_timestamp_s: float):
"""Prune keys and compact data that are outdated.
For keys that haven't had new data recorded after the timestamp,
remove them from the database.
For keys that have, compact the datapoints that were recorded
before the timestamp.
"""
for key, datapoints in list(self.data.items()):
if len(datapoints) == 0 or datapoints[-1].timestamp < start_timestamp_s:
del self.data[key]
else:
self.data[key] = self._get_datapoints(key, start_timestamp_s)

def _get_datapoints(
self, key: Hashable, window_start_timestamp_s: float
) -> List[float]:
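For reference, a quick usage sketch of the new pruning hook against the `InMemoryMetricsStore` API shown above (keys and timestamps are illustrative):

from ray.serve._private.metrics_utils import InMemoryMetricsStore

store = InMemoryMetricsStore()
store.add_metrics_point({"replica-1": 1, "replica-2": 2}, timestamp=1)
store.add_metrics_point({"replica-2": 3}, timestamp=2)

# "replica-1" has no datapoints after t=1.5, so its key is dropped entirely;
# "replica-2" is kept, with its datapoints before t=1.5 compacted away.
store.prune_keys_and_compact_data(1.5)
assert set(store.data) == {"replica-2"}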
31 changes: 22 additions & 9 deletions python/ray/serve/_private/router.py
@@ -23,6 +23,7 @@
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE,
RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS,
RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_PERIOD_S,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
SERVE_LOGGER_NAME,
)
@@ -164,7 +165,7 @@ def update_running_replicas(self, running_replicas: List[RunningReplicaInfo]):
)

@property
def curr_autoscaling_config(self) -> Optional[AutoscalingConfig]:
def autoscaling_config(self) -> Optional[AutoscalingConfig]:
if self.deployment_config is None:
return None

@@ -178,7 +179,7 @@ def update_deployment_config(
self.deployment_config = deployment_config

# Start the metrics pusher if autoscaling is enabled.
autoscaling_config = self.curr_autoscaling_config
autoscaling_config = self.autoscaling_config
if autoscaling_config:
self.metrics_pusher.start()
# Optimization for autoscaling cold start time. If there are
@@ -194,7 +195,10 @@
self.metrics_pusher.register_or_update_task(
self.RECORD_METRICS_TASK_NAME,
self._add_autoscaling_metrics_point,
min(0.5, autoscaling_config.metrics_interval_s),
min(
RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_PERIOD_S,
autoscaling_config.metrics_interval_s,
),
)
# Push metrics to the controller periodically.
self.metrics_pusher.register_or_update_task(
@@ -231,7 +235,7 @@ def inc_num_running_requests_for_replica(self, replica_id: ReplicaID):
sum(self.num_requests_sent_to_replicas.values())
)

def process_finished_request(self, replica_id: ReplicaID, *args):
def dec_num_running_requests_for_replica(self, replica_id: ReplicaID, *args):
with self._queries_lock:
self.num_requests_sent_to_replicas[replica_id] -= 1
self.num_running_requests_gauge.set(
@@ -240,7 +244,7 @@ def process_finished_request(self, replica_id: ReplicaID, *args):

def should_send_scaled_to_zero_optimized_push(self, curr_num_replicas: int) -> bool:
return (
self.curr_autoscaling_config is not None
self.autoscaling_config is not None
and curr_num_replicas == 0
and self.num_queued_requests > 0
)
@@ -261,6 +265,11 @@ def push_autoscaling_metrics_to_controller(self):
)

def _add_autoscaling_metrics_point(self):
"""Adds metrics point for queued and running requests at replicas.
Also prunes keys in the in memory metrics store with outdated datapoints.
"""

timestamp = time.time()
self.metrics_store.add_metrics_point(
{QUEUED_REQUESTS_KEY: self.num_queued_requests}, timestamp
@@ -270,11 +279,14 @@ def _add_autoscaling_metrics_point(self):
self.num_requests_sent_to_replicas, timestamp
)

# Prevent the in-memory metrics store from growing unboundedly.
start_timestamp = time.time() - self.autoscaling_config.look_back_period_s
self.metrics_store.prune_keys_and_compact_data(start_timestamp)

def _get_aggregated_requests(self):
running_requests = dict()
autoscaling_config = self.curr_autoscaling_config
if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE and autoscaling_config:
look_back_period = autoscaling_config.look_back_period_s
if RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE and self.autoscaling_config:
look_back_period = self.autoscaling_config.look_back_period_s
running_requests = {
replica_id: self.metrics_store.window_average(
replica_id, time.time() - look_back_period
@@ -538,7 +550,8 @@ async def assign_request(
replica_id
)
callback = partial(
self._metrics_manager.process_finished_request, replica_id
self._metrics_manager.dec_num_running_requests_for_replica,
replica_id,
)
if isinstance(ref, (ray.ObjectRef, FakeObjectRef)):
ref._on_completed(callback)
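Taken together, a simplified sketch of what one tick of the handle-side metrics task does after this change (function and parameter names here are illustrative, not the exact Ray internals):

import time

def record_and_prune_tick(
    metrics_store, num_queued_requests, num_requests_sent_to_replicas, look_back_period_s
):
    """One tick of the handle-side autoscaling metrics task (simplified)."""
    now = time.time()
    # Record the current queue length and per-replica running request counts.
    metrics_store.add_metrics_point({"queued": num_queued_requests}, now)
    metrics_store.add_metrics_point(num_requests_sent_to_replicas, now)
    # Prune keys with no datapoints inside the look-back window and compact
    # the rest, so replica IDs that are no longer running stop accumulating.
    metrics_store.prune_keys_and_compact_data(now - look_back_period_s)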
11 changes: 11 additions & 0 deletions python/ray/serve/tests/unit/test_metrics_utils.py
@@ -199,6 +199,17 @@ def test_multiple_metrics(self):
assert s.max("m1", window_start_timestamp_s=0) == 2
assert s.max("m2", window_start_timestamp_s=0) == -1

def test_prune_keys_and_compact_data(self):
s = InMemoryMetricsStore()
s.add_metrics_point({"m1": 1, "m2": 2, "m3": 8, "m4": 5}, timestamp=1)
s.add_metrics_point({"m1": 2, "m2": 3, "m3": 8}, timestamp=2)
s.add_metrics_point({"m1": 2, "m2": 5}, timestamp=3)
s.prune_keys_and_compact_data(1.1)
assert set(s.data) == {"m1", "m2", "m3"}
assert len(s.data["m1"]) == 2 and s.data["m1"] == s._get_datapoints("m1", 1.1)
assert len(s.data["m2"]) == 2 and s.data["m2"] == s._get_datapoints("m2", 1.1)
assert len(s.data["m3"]) == 1 and s.data["m3"] == s._get_datapoints("m3", 1.1)


if __name__ == "__main__":
sys.exit(pytest.main(["-v", "-s", __file__]))
94 changes: 81 additions & 13 deletions python/ray/serve/tests/unit/test_router.py
@@ -2,11 +2,12 @@
import random
import sys
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, Union
from unittest.mock import Mock, patch

import pytest

from ray._private.test_utils import async_wait_for_condition
from ray._private.utils import get_or_create_event_loop
from ray.serve._private.common import (
DeploymentHandleSource,
@@ -17,18 +18,15 @@
RunningReplicaInfo,
)
from ray.serve._private.config import DeploymentConfig
from ray.serve._private.constants import RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE
from ray.serve._private.replica_scheduler import (
PendingRequest,
ReplicaScheduler,
ReplicaWrapper,
)
from ray.serve._private.replica_scheduler.pow_2_scheduler import ReplicaQueueLengthCache
from ray.serve._private.router import Router, RouterMetricsManager
from ray.serve._private.test_utils import ( # FakeObjectRef,; FakeObjectRefGen,
FakeCounter,
FakeGauge,
MockTimer,
)
from ray.serve._private.router import QUEUED_REQUESTS_KEY, Router, RouterMetricsManager
from ray.serve._private.test_utils import FakeCounter, FakeGauge, MockTimer
from ray.serve._private.utils import FakeObjectRef, FakeObjectRefGen, get_random_string
from ray.serve.config import AutoscalingConfig
from ray.serve.exceptions import BackPressureError
@@ -85,9 +83,6 @@ def __init__(self):
self._replica_to_return_on_retry: Optional[FakeReplica] = None
self._replica_queue_len_cache = ReplicaQueueLengthCache()

def set_should_block_requests(self, block_requests: bool):
self._block_requests = block_requests

@property
def replica_queue_len_cache(self) -> ReplicaQueueLengthCache:
return self._replica_queue_len_cache
@@ -105,7 +100,10 @@ def curr_replicas(self) -> Dict[str, ReplicaWrapper]:
return replicas

def update_replicas(self, replicas: List[ReplicaWrapper]):
raise NotImplementedError
pass

def set_should_block_requests(self, block_requests: bool):
self._block_requests = block_requests

def set_replica_to_return(self, replica: FakeReplica):
self._replica_to_return = replica
@@ -608,9 +606,9 @@ def test_track_requests_sent_to_replicas(self):

# Requests at r1 and r2 drop to 0
for _ in range(1):
metrics_manager.process_finished_request(r1, None)
metrics_manager.dec_num_running_requests_for_replica(r1, None)
for _ in range(2):
metrics_manager.process_finished_request(r2, None)
metrics_manager.dec_num_running_requests_for_replica(r2, None)
assert metrics_manager.num_requests_sent_to_replicas[r1] == 0
assert metrics_manager.num_requests_sent_to_replicas[r2] == 0

Expand Down Expand Up @@ -731,6 +729,76 @@ def test_push_autoscaling_metrics_to_controller(self):
send_timestamp=start,
)

@pytest.mark.skipif(
not RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
reason="Tests handle metrics behavior.",
)
@pytest.mark.asyncio
@patch(
"ray.serve._private.router.RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_PERIOD_S",
0.01,
)
async def test_memory_cleared(self):
deployment_id = DeploymentID(name="a", app_name="b")
metrics_manager = RouterMetricsManager(
deployment_id,
"some_handle",
"some_actor",
DeploymentHandleSource.PROXY,
Mock(),
FakeCounter(
tag_keys=(
"deployment",
"route",
"application",
"handle",
"actor_id",
)
),
FakeGauge(tag_keys=("deployment", "application", "handle", "actor_id")),
FakeGauge(tag_keys=("deployment", "application", "handle", "actor_id")),
)
metrics_manager.update_deployment_config(
deployment_config=DeploymentConfig(
autoscaling_config=AutoscalingConfig(look_back_period_s=0.01)
),
curr_num_replicas=0,
)

r1 = ReplicaID("r1", deployment_id)
r2 = ReplicaID("r2", deployment_id)
r3 = ReplicaID("r3", deployment_id)

def check_database(expected: Set[ReplicaID]):
assert set(metrics_manager.metrics_store.data) == expected
return True

# r1: 1
metrics_manager.inc_num_running_requests_for_replica(r1)
await async_wait_for_condition(
check_database, expected={r1, QUEUED_REQUESTS_KEY}
)

# r1: 1, r2: 0
metrics_manager.inc_num_running_requests_for_replica(r2)
await async_wait_for_condition(
check_database, expected={r1, r2, QUEUED_REQUESTS_KEY}
)
metrics_manager.dec_num_running_requests_for_replica(r2)

# r1: 1, r2: 0, r3: 0
metrics_manager.inc_num_running_requests_for_replica(r3)
await async_wait_for_condition(
check_database, expected={r1, r2, r3, QUEUED_REQUESTS_KEY}
)
metrics_manager.dec_num_running_requests_for_replica(r3)

# update running replicas {r2}
metrics_manager.update_running_replicas([running_replica_info(r2)])
await async_wait_for_condition(
check_database, expected={r1, r2, QUEUED_REQUESTS_KEY}
)

@patch(
"ray.serve._private.router.RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE", "1"
)
