Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fix the file descriptor leaking when worker or node died in GCS #33311

Merged
merged 16 commits into from
Mar 16, 2023
2 changes: 1 addition & 1 deletion dashboard/modules/reporter/reporter_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def run(self, server):
self.service_discovery.daemon = True
self.service_discovery.start()
gcs_addr = self._dashboard_head.gcs_address
subscriber = GcsAioResourceUsageSubscriber(gcs_addr)
subscriber = GcsAioResourceUsageSubscriber(address=gcs_addr)
await subscriber.subscribe()
cluster_metadata = await self._dashboard_head.gcs_aio_client.internal_kv_get(
CLUSTER_METADATA_KEY,
Expand Down
41 changes: 28 additions & 13 deletions python/ray/_private/gcs_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ def _create_node_resource_usage_request(key: str, json: str):


class _SubscriberBase:
def __init__(self):
def __init__(self, worker_id: bytes = None):
self._worker_id = worker_id
# self._subscriber_id needs to match the binary format of a random
# SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes.
self._subscriber_id = bytes(bytearray(random.getrandbits(8) for _ in range(28)))
Expand All @@ -84,7 +85,7 @@ def last_batch_size(self):
def _subscribe_request(self, channel):
cmd = pubsub_pb2.Command(channel_type=channel, subscribe_message={})
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
subscriber_id=self._subscriber_id, commands=[cmd]
subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[cmd]
)
return req

Expand All @@ -95,7 +96,7 @@ def _poll_request(self):

def _unsubscribe_request(self, channels):
req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
subscriber_id=self._subscriber_id, commands=[]
subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[]
)
for channel in channels:
req.commands.append(
Expand Down Expand Up @@ -202,10 +203,11 @@ class _SyncSubscriber(_SubscriberBase):
def __init__(
self,
pubsub_channel_type,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__()
super().__init__(worker_id)

if address:
assert channel is None, "address and channel cannot both be specified"
Expand Down Expand Up @@ -308,10 +310,11 @@ class GcsErrorSubscriber(_SyncSubscriber):

def __init__(
self,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.RAY_ERROR_INFO_CHANNEL, address, channel)
super().__init__(pubsub_pb2.RAY_ERROR_INFO_CHANNEL, worker_id, address, channel)

def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error messages.
Expand Down Expand Up @@ -342,10 +345,11 @@ class GcsLogSubscriber(_SyncSubscriber):

def __init__(
self,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, address, channel)
super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, worker_id, address, channel)

def poll(self, timeout=None) -> Optional[dict]:
"""Polls for new log messages.
Expand Down Expand Up @@ -376,10 +380,13 @@ class GcsFunctionKeySubscriber(_SyncSubscriber):

def __init__(
self,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.RAY_PYTHON_FUNCTION_CHANNEL, address, channel)
super().__init__(
pubsub_pb2.RAY_PYTHON_FUNCTION_CHANNEL, worker_id, address, channel
)

def poll(self, timeout=None) -> Optional[bytes]:
"""Polls for new function key messages.
Expand Down Expand Up @@ -411,10 +418,11 @@ class GcsActorSubscriber(_SyncSubscriber):

def __init__(
self,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, address, channel)
super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, worker_id, address, channel)

def poll(self, timeout=None) -> List[Tuple[bytes, str]]:
"""Polls for new actor messages.
Expand Down Expand Up @@ -475,10 +483,11 @@ class _AioSubscriber(_SubscriberBase):
def __init__(
self,
pubsub_channel_type,
worker_id: bytes = None,
address: str = None,
channel: aiogrpc.Channel = None,
):
super().__init__()
super().__init__(worker_id)

if address:
assert channel is None, "address and channel cannot both be specified"
Expand Down Expand Up @@ -554,10 +563,11 @@ async def close(self) -> None:
class GcsAioErrorSubscriber(_AioSubscriber):
def __init__(
self,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.RAY_ERROR_INFO_CHANNEL, address, channel)
super().__init__(pubsub_pb2.RAY_ERROR_INFO_CHANNEL, worker_id, address, channel)

async def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
"""Polls for new error message.
Expand All @@ -573,10 +583,11 @@ async def poll(self, timeout=None) -> Tuple[bytes, ErrorTableData]:
class GcsAioLogSubscriber(_AioSubscriber):
def __init__(
self,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, address, channel)
super().__init__(pubsub_pb2.RAY_LOG_CHANNEL, worker_id, address, channel)

async def poll(self, timeout=None) -> dict:
"""Polls for new log message.
Expand All @@ -592,10 +603,13 @@ async def poll(self, timeout=None) -> dict:
class GcsAioResourceUsageSubscriber(_AioSubscriber):
def __init__(
self,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL, address, channel)
super().__init__(
pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL, worker_id, address, channel
)

async def poll(self, timeout=None) -> Tuple[bytes, str]:
"""Polls for new resource usage message.
Expand All @@ -610,10 +624,11 @@ async def poll(self, timeout=None) -> Tuple[bytes, str]:
class GcsAioActorSubscriber(_AioSubscriber):
def __init__(
self,
worker_id: bytes = None,
address: str = None,
channel: grpc.Channel = None,
):
super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, address, channel)
super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, worker_id, address, channel)

@property
def queue_size(self):
Expand Down
25 changes: 16 additions & 9 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,12 +1973,6 @@ def connect(
ray._raylet.GcsClientOptions.from_gcs_address(node.gcs_address)
)
worker.gcs_publisher = GcsPublisher(address=worker.gcs_client.address)
worker.gcs_error_subscriber = GcsErrorSubscriber(address=worker.gcs_client.address)
worker.gcs_log_subscriber = GcsLogSubscriber(address=worker.gcs_client.address)
worker.gcs_function_key_subscriber = GcsFunctionKeySubscriber(
address=worker.gcs_client.address
)

# Initialize some fields.
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
# We should not specify the job_id if it's `WORKER_MODE`.
Expand Down Expand Up @@ -2127,6 +2121,16 @@ def connect(

# Notify raylet that the core worker is ready.
worker.core_worker.notify_raylet()
worker_id = worker.worker_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this same as core worker's worker_id? I have some memory it was actually different...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it comes from here: https://sourcegraph.com/github.com/ray-project/ray@master/-/blob/python/ray/_raylet.pyx?L1609:9

It's from core worker, so I think it should be the same?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. You are right. Maybe I was confused with the node id or sth...

worker.gcs_error_subscriber = GcsErrorSubscriber(
worker_id=worker_id, address=worker.gcs_client.address
)
worker.gcs_log_subscriber = GcsLogSubscriber(
worker_id=worker_id, address=worker.gcs_client.address
)
worker.gcs_function_key_subscriber = GcsFunctionKeySubscriber(
worker_id=worker_id, address=worker.gcs_client.address
)

if driver_object_store_memory is not None:
logger.warning(
Expand Down Expand Up @@ -2212,9 +2216,12 @@ def disconnect(exiting_interpreter=False):
# should be handled cleanly in the worker object's destructor and not
# in this disconnect method.
worker.threads_stopped.set()
worker.gcs_function_key_subscriber.close()
worker.gcs_error_subscriber.close()
worker.gcs_log_subscriber.close()
if hasattr(worker, "gcs_function_key_subscriber"):
worker.gcs_function_key_subscriber.close()
if hasattr(worker, "gcs_error_subscriber"):
worker.gcs_error_subscriber.close()
if hasattr(worker, "gcs_log_subscriber"):
worker.gcs_log_subscriber.close()
if hasattr(worker, "import_thread"):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):
Expand Down
12 changes: 9 additions & 3 deletions python/ray/serve/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,11 +693,13 @@ def test_run_config_port1(ray_start_stop):
config_file_name = os.path.join(
os.path.dirname(__file__), "test_config_files", "basic_graph.yaml"
)
subprocess.Popen(["serve", "run", config_file_name])
p = subprocess.Popen(["serve", "run", config_file_name])
wait_for_condition(
lambda: requests.post("http://localhost:8000/").text == "wonderful world",
timeout=15,
)
p.send_signal(signal.SIGINT)
p.wait()


@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
Expand All @@ -706,11 +708,13 @@ def test_run_config_port2(ray_start_stop):
config_file_name = os.path.join(
os.path.dirname(__file__), "test_config_files", "basic_graph_http.yaml"
)
subprocess.Popen(["serve", "run", config_file_name])
p = subprocess.Popen(["serve", "run", config_file_name])
wait_for_condition(
lambda: requests.post("http://localhost:8005/").text == "wonderful world",
timeout=15,
)
p.send_signal(signal.SIGINT)
p.wait()


@pytest.mark.skipif(sys.platform == "win32", reason="File path incorrect on Windows.")
Expand All @@ -719,11 +723,13 @@ def test_run_config_port3(ray_start_stop):
config_file_name = os.path.join(
os.path.dirname(__file__), "test_config_files", "basic_graph_http.yaml"
)
subprocess.Popen(["serve", "run", "--port=8010", config_file_name])
p = subprocess.Popen(["serve", "run", "--port=8010", config_file_name])
wait_for_condition(
lambda: requests.post("http://localhost:8010/").text == "wonderful world",
timeout=15,
)
p.send_signal(signal.SIGINT)
p.wait()


@serve.deployment
Expand Down
51 changes: 51 additions & 0 deletions python/ray/tests/test_advanced_9.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,57 @@ def ready(self):
run_string_as_driver(script.format(address=call_ray_start_2, val=2))


@pytest.mark.skipif(sys.platform != "linux", reason="Only works on linux.")
def test_gcs_connection_no_leak(ray_start_cluster):
cluster = ray_start_cluster
head_node = cluster.add_node()

gcs_server_process = head_node.all_processes["gcs_server"][0].process
gcs_server_pid = gcs_server_process.pid

ray.init(cluster.address)

def get_gcs_num_of_connections():
import psutil

p = psutil.Process(gcs_server_pid)
print(">>", p.num_fds())
return p.num_fds()

# Wait for everything to be ready.
import time

time.sleep(10)

curr_fds = get_gcs_num_of_connections()

@ray.remote
class A:
def ready(self):
print("HELLO")
return "WORLD"

num_of_actors = 10
a = [A.remote() for _ in range(num_of_actors)]
print(ray.get([t.ready.remote() for t in a]))

# Kill the actor
del a

# Make sure the # of fds opened by the GCS dropped.
wait_for_condition(lambda: get_gcs_num_of_connections() == curr_fds)

n = cluster.add_node(wait=True)

# Make sure the # of fds opened by the GCS increased.
wait_for_condition(lambda: get_gcs_num_of_connections() > curr_fds)

cluster.remove_node(n)

# Make sure the # of fds opened by the GCS dropped.
wait_for_condition(lambda: get_gcs_num_of_connections() == curr_fds)


@pytest.mark.parametrize(
"call_ray_start",
["ray start --head --num-cpus=2"],
Expand Down
2 changes: 1 addition & 1 deletion src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
<< rpc_address_.port() << ", worker ID " << worker_context_.GetWorkerID()
<< ", raylet " << local_raylet_id;

gcs_client_ = std::make_shared<gcs::GcsClient>(options_.gcs_options);
gcs_client_ = std::make_shared<gcs::GcsClient>(options_.gcs_options, GetWorkerID());

RAY_CHECK_OK(gcs_client_->Connect(io_service_));
RegisterToGcs();
Expand Down
3 changes: 2 additions & 1 deletion src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ void GcsSubscriberClient::PubsubCommandBatch(

} // namespace

GcsClient::GcsClient(const GcsClientOptions &options) : options_(options) {}
GcsClient::GcsClient(const GcsClientOptions &options, UniqueID gcs_client_id)
: options_(options), gcs_client_id_(gcs_client_id) {}

Status GcsClient::Connect(instrumented_io_context &io_service) {
// Connect to gcs service.
Expand Down
6 changes: 5 additions & 1 deletion src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,11 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this<GcsClient> {
/// Constructor of GcsClient.
///
/// \param options Options for client.
explicit GcsClient(const GcsClientOptions &options);
/// \param gcs_client_id The unique ID for the owner of this object.
/// This potentially will be used to tell GCS who is client connecting
/// to GCS.
explicit GcsClient(const GcsClientOptions &options,
UniqueID gcs_client_id = UniqueID::FromRandom());

virtual ~GcsClient() { Disconnect(); };

Expand Down
2 changes: 2 additions & 0 deletions src/ray/gcs/gcs_server/gcs_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,7 @@ void GcsServer::InstallEventListeners() {
gcs_actor_manager_->OnNodeDead(node_id, node_ip_address);
raylet_client_pool_->Disconnect(node_id);
gcs_healthcheck_manager_->RemoveNode(node_id);
pubsub_handler_->RemoveSubscriberFrom(node_id.Binary());

if (!RayConfig::instance().use_ray_syncer()) {
gcs_ray_syncer_->RemoveNode(*node);
Expand All @@ -667,6 +668,7 @@ void GcsServer::InstallEventListeners() {
worker_failure_data->exit_detail(),
creation_task_exception);
gcs_placement_group_scheduler_->HandleWaitingRemovedBundles();
pubsub_handler_->RemoveSubscriberFrom(worker_id.Binary());
});

// Install job event listeners.
Expand Down
Loading