[core] fix ray status pending requests leaked after driver is killed (#31155)

Signed-off-by: Clarence Ng <[email protected]>

Fixes pending requests reported by ray status being leaked: the backlog requests were not cleaned up when the driver was killed.

The fix is to make sure the backlog is cleared whether the disconnecting client is a worker or a driver.
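
For context, here is a minimal standalone sketch (not Ray source) of the control-flow change in NodeManager::DisconnectClient: ClearWorkerBacklog is moved out of the worker-only branch so that it also runs when a driver disconnects. The stub structs and printouts below are illustrative assumptions; only the method names mirror the diff further down.

#include <iostream>
#include <string>

// Stub managers standing in for Ray's local_task_manager_ / cluster_task_manager_.
struct LocalTaskManager {
  void ReleaseWorkerResources(const std::string &worker_id) {
    std::cout << "release resources held by " << worker_id << "\n";
  }
  void ClearWorkerBacklog(const std::string &worker_id) {
    std::cout << "clear backlog reported by " << worker_id << "\n";
  }
};

struct ClusterTaskManager {
  void ScheduleAndDispatchTasks() { std::cout << "schedule and dispatch tasks\n"; }
};

// Simplified shape of the disconnect path after the fix.
void DisconnectClient(bool is_worker,
                      bool is_driver,
                      const std::string &worker_id,
                      LocalTaskManager &local_task_manager,
                      ClusterTaskManager &cluster_task_manager) {
  if (is_worker) {
    local_task_manager.ReleaseWorkerResources(worker_id);
    // Before the fix, ClearWorkerBacklog was called only in this branch, so a
    // killed driver's backlog was never cleared.
    cluster_task_manager.ScheduleAndDispatchTasks();
  } else if (is_driver) {
    // Driver-specific cleanup happens here in the real code.
  }
  // After the fix: the backlog is cleared for workers and drivers alike.
  local_task_manager.ClearWorkerBacklog(worker_id);
}

int main() {
  LocalTaskManager local_task_manager;
  ClusterTaskManager cluster_task_manager;
  // A driver disconnect now also clears its backlog.
  DisconnectClient(/*is_worker=*/false, /*is_driver=*/true, "driver-1",
                   local_task_manager, cluster_task_manager);
  return 0;
}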
clarng authored Dec 20, 2022
1 parent b963198 commit 688e8a1
Showing 3 changed files with 108 additions and 6 deletions.
38 changes: 37 additions & 1 deletion python/ray/_private/test_utils.py
@@ -18,6 +18,8 @@
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from typing import Any, Callable, Dict, List, Optional
import uuid

import requests
from ray._raylet import Config

import grpc
@@ -39,7 +41,13 @@
from ray._private.internal_api import memory_summary
from ray._private.tls_utils import generate_self_signed_tls_certs
from ray._raylet import GcsClientOptions, GlobalStateAccessor
from ray.core.generated import gcs_pb2, node_manager_pb2, node_manager_pb2_grpc
from ray.core.generated import (
gcs_pb2,
node_manager_pb2,
node_manager_pb2_grpc,
gcs_service_pb2,
gcs_service_pb2_grpc,
)
from ray.scripts.scripts import main as ray_main
from ray.util.queue import Empty, Queue, _QueueActor
from ray.experimental.state.state_manager import StateDataSourceClient
@@ -1608,6 +1616,34 @@ def get_node_stats(raylet, num_retry=5, timeout=2):
return reply


# Gets resource usage assuming gcs is local.
def get_resource_usage(gcs_address, timeout=10):
if not gcs_address:
gcs_address = ray.worker._global_node.gcs_address

gcs_channel = ray._private.utils.init_grpc_channel(
gcs_address, ray_constants.GLOBAL_GRPC_OPTIONS, asynchronous=False
)

gcs_node_resources_stub = gcs_service_pb2_grpc.NodeResourceInfoGcsServiceStub(
gcs_channel
)

request = gcs_service_pb2.GetAllResourceUsageRequest()
response = gcs_node_resources_stub.GetAllResourceUsage(request, timeout=timeout)
resources_batch_data = response.resource_usage_data

return resources_batch_data


# Gets the load metrics report assuming gcs is local.
def get_load_metrics_report(webui_url):
webui_url = format_web_url(webui_url)
response = requests.get(f"{webui_url}/api/cluster_status")
response.raise_for_status()
return response.json()["data"]["clusterStatus"]["loadMetricsReport"]


# Send a RPC to the raylet to have it self-destruct its process.
def kill_raylet(raylet, graceful=False):
raylet_address = f'{raylet["NodeManagerAddress"]}:{raylet["NodeManagerPort"]}'
72 changes: 69 additions & 3 deletions python/ray/tests/test_node_manager.py
@@ -1,5 +1,13 @@
import ray
from ray._private.test_utils import run_string_as_driver
from ray._private.test_utils import (
get_load_metrics_report,
run_string_as_driver,
run_string_as_driver_nonblocking,
wait_for_condition,
get_resource_usage,
)
import pytest
import os


# This tests the queue transitions for infeasible tasks. This has been an issue
@@ -46,9 +54,67 @@ def f():
ray.get([f._remote(args=[], kwargs={}, resources={str(i): 1}) for i in range(3)])


@pytest.mark.parametrize(
"call_ray_start",
["""ray start --head"""],
indirect=True,
)
def test_kill_driver_clears_backlog(call_ray_start):
driver = """
import ray
@ray.remote
def f():
import time
time.sleep(300)
refs = [f.remote() for _ in range(10000)]
ray.get(refs)
"""
proc = run_string_as_driver_nonblocking(driver)
ctx = ray.init(address=call_ray_start)

def get_backlog_and_pending():
resources_batch = get_resource_usage(
gcs_address=ctx.address_info["gcs_address"]
)
backlog = (
resources_batch.resource_load_by_shape.resource_demands[0].backlog_size
if resources_batch.resource_load_by_shape.resource_demands
else 0
)

pending = 0
demands = get_load_metrics_report(webui_url=ctx.address_info["webui_url"])[
"resourceDemand"
]
for demand in demands:
resource_dict, amount = demand
if "CPU" in resource_dict:
pending = amount

return pending, backlog

def check_backlog(expect_backlog) -> bool:
pending, backlog = get_backlog_and_pending()
if expect_backlog:
return pending > 0 and backlog > 0
else:
return pending == 0 and backlog == 0

wait_for_condition(
check_backlog, timeout=10, retry_interval_ms=1000, expect_backlog=True
)

os.kill(proc.pid, 9)

wait_for_condition(
check_backlog, timeout=10, retry_interval_ms=1000, expect_backlog=False
)


if __name__ == "__main__":
import pytest
import os
import sys

if os.environ.get("PARALLEL_CI"):
4 changes: 2 additions & 2 deletions src/ray/raylet/node_manager.cc
@@ -1584,8 +1584,6 @@ void NodeManager::DisconnectClient(const std::shared_ptr<ClientConnection> &clie
// Return the resources that were being used by this worker.
local_task_manager_->ReleaseWorkerResources(worker);

local_task_manager_->ClearWorkerBacklog(worker->WorkerId());

// Since some resources may have been released, we can try to dispatch more tasks.
cluster_task_manager_->ScheduleAndDispatchTasks();
} else if (is_driver) {
@@ -1608,6 +1606,8 @@ void NodeManager::DisconnectClient(const std::shared_ptr<ClientConnection> &clie
}
}

local_task_manager_->ClearWorkerBacklog(worker->WorkerId());

client->Close();

// TODO(rkn): Tell the object manager that this client has disconnected so
