Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Fix GCS FD usage increase regression. #35624

Merged
merged 23 commits into from
May 24, 2023
Merged
3 changes: 3 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ cc_grpc_library(
# gcs rpc server and client.
cc_library(
name = "gcs_service_rpc",
srcs = [
"src/ray/rpc/gcs_server/gcs_rpc_client.cc",
],
hdrs = [
"src/ray/rpc/gcs_server/gcs_rpc_client.h",
"src/ray/rpc/gcs_server/gcs_rpc_server.h",
Expand Down
43 changes: 39 additions & 4 deletions python/ray/tests/test_advanced_9.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ray.experimental.internal_kv import _internal_kv_list
from ray.tests.conftest import call_ray_start
import subprocess
import psutil


@pytest.fixture
Expand Down Expand Up @@ -269,11 +270,9 @@ def test_gcs_connection_no_leak(ray_start_cluster):
ray.init(cluster.address)

def get_gcs_num_of_connections():
import psutil

p = psutil.Process(gcs_server_pid)
print(">>", p.num_fds())
return p.num_fds()
print(">>", len(p.connections()))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use connections instead of fds for better measurement of the sockets in GCS.

return len(p.connections())

# Wait for everything to be ready.
import time
Expand Down Expand Up @@ -438,6 +437,42 @@ def f():
assert ray.get(f.remote())


def test_gcs_fd_usage(shutdown_only):
ray.init(
_system_config={
"prestart_worker_first_driver": False,
"enable_worker_prestart": False,
},
)
gcs_process = ray._private.worker._global_node.all_processes["gcs_server"][0]
gcs_process = psutil.Process(gcs_process.process.pid)
print("GCS connections", len(gcs_process.connections()))

@ray.remote(runtime_env={"env_vars": {"Hello": "World"}})
class A:
def f(self):
import os

return os.environ.get("Hello")

# In case there are still some pre-start workers, consume all of them
fishbone marked this conversation as resolved.
Show resolved Hide resolved
aa = [A.remote() for _ in range(32)]
for a in aa:
assert ray.get(a.f.remote()) == "World"
base_fd_num = len(gcs_process.connections())
print("GCS connections", base_fd_num)

bb = [A.remote() for _ in range(4)]
for b in bb:
assert ray.get(b.f.remote()) == "World"
new_fd_num = len(gcs_process.connections())
print("GCS connections", new_fd_num)
# each worker has two connections:
fishbone marked this conversation as resolved.
Show resolved Hide resolved
# GCS -> CoreWorker
# CoreWorker -> GCS
assert (new_fd_num - base_fd_num) == len(bb) * 2


if __name__ == "__main__":
import pytest
import os
Expand Down
12 changes: 8 additions & 4 deletions python/ray/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,10 +416,14 @@ def test_ray_start_head_block_and_signals(
# Run
head_proc.start()

# Give it some time to start various subprocesses and `ray stop`
# A smaller interval seems to cause occasional failure as the head process
# was stopped too early before spawning all the subprocesses.
time.sleep(5)
# Wait until the system is ready
while True:
try:
ray.init()
ray.shutdown()
break
except Exception:
continue

# Terminate some of the children process
children = psutil.Process(head_proc.pid).children()
Expand Down
2 changes: 1 addition & 1 deletion src/ray/common/ray_config_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ RAY_CONFIG(int64_t, max_direct_call_object_size, 100 * 1024)
// The max gRPC message size (the gRPC internal default is 4MB). We use a higher
// limit in Ray to avoid crashing with many small inlined task arguments.
// Keep in sync with GCS_STORAGE_MAX_SIZE in packaging.py.
RAY_CONFIG(int64_t, max_grpc_message_size, 500 * 1024 * 1024)
RAY_CONFIG(int64_t, max_grpc_message_size, 512 * 1024 * 1024)

// Retry timeout for trying to create a gRPC server. Only applies if the number
// of retries is non zero.
Expand Down
4 changes: 2 additions & 2 deletions src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ std::pair<std::string, int> GcsClient::GetGcsServerAddress() const {
PythonGcsClient::PythonGcsClient(const GcsClientOptions &options) : options_(options) {}

Status PythonGcsClient::Connect() {
auto arguments = PythonGrpcChannelArguments();
channel_ = rpc::BuildChannel(options_.gcs_address_, options_.gcs_port_, arguments);
channel_ =
rpc::GcsRpcClient::GetDefaultChannel(options_.gcs_address_, options_.gcs_port_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we planning to cherry pick this btw? There's a bit of concern we change this settings. It looks like after this all python clients' timeout will be from 60 -> 30 seconds. Should we increase the default grpc_client_keepalive_time_ms to 60 seconds?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I feel if core worker's gcs client got time out, it's also considered bad, and it won't progress. Given this, to make things alive, we need both to be alive. If I understand it correctly. So it should be OK I think.

kv_stub_ = rpc::InternalKVGcsService::NewStub(channel_);
runtime_env_stub_ = rpc::RuntimeEnvGcsService::NewStub(channel_);
node_info_stub_ = rpc::NodeInfoGcsService::NewStub(channel_);
Expand Down
4 changes: 2 additions & 2 deletions src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ class GcsServerTest : public ::testing::Test {

// Create gcs rpc client
client_call_manager_.reset(new rpc::ClientCallManager(io_service_));
client_.reset(
new rpc::GcsRpcClient("0.0.0.0", gcs_server_->GetPort(), *client_call_manager_));
client_.reset(new rpc::GcsRpcClient(
"127.0.0.1", gcs_server_->GetPort(), *client_call_manager_));
}

void TearDown() override {
Expand Down
12 changes: 2 additions & 10 deletions src/ray/gcs/pubsub/gcs_pub_sub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "ray/gcs/pubsub/gcs_pub_sub.h"

#include "absl/strings/str_cat.h"
#include "ray/rpc/gcs_server/gcs_rpc_client.h"
#include "ray/rpc/grpc_client.h"

namespace ray {
Expand Down Expand Up @@ -213,14 +214,6 @@ Status GcsSubscriber::SubscribeAllWorkerFailures(
return Status::OK();
}

grpc::ChannelArguments PythonGrpcChannelArguments() {
grpc::ChannelArguments arguments;
arguments.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, 512 * 1024 * 1024);
arguments.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 60 * 1000);
arguments.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000);
return arguments;
}

PythonGcsPublisher::PythonGcsPublisher(const std::string &gcs_address) {
std::vector<std::string> address = absl::StrSplit(gcs_address, ':');
RAY_LOG(DEBUG) << "Connect to gcs server via address: " << gcs_address;
Expand All @@ -230,8 +223,7 @@ PythonGcsPublisher::PythonGcsPublisher(const std::string &gcs_address) {
}

Status PythonGcsPublisher::Connect() {
auto arguments = PythonGrpcChannelArguments();
channel_ = rpc::BuildChannel(gcs_address_, gcs_port_, arguments);
channel_ = rpc::GcsRpcClient::GetDefaultChannel(gcs_address_, gcs_port_);
pubsub_stub_ = rpc::InternalPubSubGcsService::NewStub(channel_);
return Status::OK();
}
Expand Down
4 changes: 0 additions & 4 deletions src/ray/gcs/pubsub/gcs_pub_sub.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,5 @@ class RAY_EXPORT PythonGcsPublisher {
int gcs_port_;
};

/// Construct the arguments for synchronous gRPC clients
/// (the ones wrapped in Python)
grpc::ChannelArguments PythonGrpcChannelArguments();

} // namespace gcs
} // namespace ray
72 changes: 72 additions & 0 deletions src/ray/rpc/gcs_server/gcs_rpc_client.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright 2023 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ray/rpc/gcs_server/gcs_rpc_client.h"

namespace ray {
namespace rpc {
grpc::ChannelArguments GetGcsRpcClientArguments() {
grpc::ChannelArguments arguments = CreateDefaultChannelArguments();
arguments.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS,
::RayConfig::instance().gcs_grpc_max_reconnect_backoff_ms());
arguments.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS,
::RayConfig::instance().gcs_grpc_min_reconnect_backoff_ms());
arguments.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS,
::RayConfig::instance().gcs_grpc_initial_reconnect_backoff_ms());
return arguments;
}

std::shared_ptr<grpc::Channel> GcsRpcClient::GetDefaultChannel(const std::string &address,
int port) {
static std::mutex mu_;
static std::shared_ptr<grpc::Channel> channel_;
static std::string address_;
static int port_ = 0;

// Don't reuse channel if proxy or tls is set
// TODO: Reuse the channel even it's tls.
// Right now, if we do this, python/ray/serve/tests/test_grpc.py
// will fail.
if (::RayConfig::instance().grpc_enable_http_proxy() ||
::RayConfig::instance().USE_TLS()) {
return BuildChannel(address, port, GetGcsRpcClientArguments());
}

std::lock_guard<std::mutex> guard(mu_);
if (channel_ == nullptr || (address_ != address || port_ != port)) {
address_ = address;
fishbone marked this conversation as resolved.
Show resolved Hide resolved
port_ = port;

// This condition shouldn't happen in most cases. It could only happen when
// ray driver wanted to talk with different GCS.
// - This mostly happens in testing, where the test main process is the driver.
// It calls ray.init and then ray.shutdown and later ray.init with a different
// GCS address.
// - Potentially it can also happen in the user's driver where there are two
// ray clusters and the user ray.init and ray.shutdown and then tries to
// connect to a different GCS.
if (channel_ != nullptr) {
RAY_LOG(WARNING) << "Generate a new GCS channel: " << address << ":" << port
<< ". Potentially it will increase GCS socket numbers."
<< " This could only happen in testing or in the same driver "
<< " it tries to connect to different GCS clusters.";
}
channel_ = BuildChannel(address, port, GetGcsRpcClientArguments());
}

return channel_;
}

} // namespace rpc
} // namespace ray
15 changes: 5 additions & 10 deletions src/ray/rpc/gcs_server/gcs_rpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ class Executor {

/// Client used for communicating with gcs server.
class GcsRpcClient {
public:
static std::shared_ptr<grpc::Channel> GetDefaultChannel(const std::string &address,
int port);

public:
/// Constructor. GcsRpcClient is not thread safe.
///
Expand All @@ -190,16 +194,7 @@ class GcsRpcClient {
gcs_port_(port),
io_context_(&client_call_manager.GetMainService()),
timer_(std::make_unique<boost::asio::deadline_timer>(*io_context_)) {
grpc::ChannelArguments arguments = CreateDefaultChannelArguments();
arguments.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS,
::RayConfig::instance().gcs_grpc_max_reconnect_backoff_ms());
arguments.SetInt(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS,
::RayConfig::instance().gcs_grpc_min_reconnect_backoff_ms());
arguments.SetInt(GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS,
::RayConfig::instance().gcs_grpc_initial_reconnect_backoff_ms());

channel_ = BuildChannel(address, port, arguments);

channel_ = GetDefaultChannel(address, port);
// If not the reconnection will continue to work.
auto deadline =
std::chrono::system_clock::now() +
Expand Down