diff --git a/BUILD.bazel b/BUILD.bazel index 0ce06ebc8116..37a7b000e875 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -546,6 +546,7 @@ cc_library( ":scheduler", ":worker_rpc", "@boost//:bimap", + "@com_github_grpc_grpc//src/proto/grpc/health/v1:health_proto", "@com_google_absl//absl/container:btree", ], ) @@ -1810,6 +1811,20 @@ cc_library( ], ) +cc_test( + name = "gcs_health_check_manager_test", + size = "small", + srcs = [ + "src/ray/gcs/gcs_server/test/gcs_health_check_manager_test.cc", + ], + copts = COPTS, + tags = ["team:core"], + deps = [ + ":gcs_server_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "gcs_node_manager_test", size = "small", diff --git a/dashboard/modules/healthz/tests/test_healthz.py b/dashboard/modules/healthz/tests/test_healthz.py index 49574ad3f657..727e1ac8a7bd 100644 --- a/dashboard/modules/healthz/tests/test_healthz.py +++ b/dashboard/modules/healthz/tests/test_healthz.py @@ -7,7 +7,9 @@ from ray._private.test_utils import find_free_port, wait_for_condition -def test_healthz_head(ray_start_cluster): +@pytest.mark.parametrize("pull_based", [True, False]) +def test_healthz_head(pull_based, monkeypatch, ray_start_cluster): + monkeypatch.setenv("RAY_pull_based_healthcheck", "true" if pull_based else "false") dashboard_port = find_free_port() h = ray_start_cluster.add_node(dashboard_port=dashboard_port) uri = f"http://localhost:{dashboard_port}/api/gcs_healthz" @@ -20,7 +22,9 @@ def test_healthz_head(ray_start_cluster): assert "Read timed out" in str(e) -def test_healthz_agent_1(ray_start_cluster): +@pytest.mark.parametrize("pull_based", [True, False]) +def test_healthz_agent_1(pull_based, monkeypatch, ray_start_cluster): + monkeypatch.setenv("RAY_pull_based_healthcheck", "true" if pull_based else "false") agent_port = find_free_port() h = ray_start_cluster.add_node(dashboard_agent_listen_port=agent_port) uri = f"http://localhost:{agent_port}/api/local_raylet_healthz" @@ -32,9 +36,17 @@ def test_healthz_agent_1(ray_start_cluster): assert requests.get(uri).status_code == 200 +@pytest.mark.parametrize("pull_based", [True, False]) @pytest.mark.skipif(sys.platform == "win32", reason="SIGSTOP only on posix") -def test_healthz_agent_2(monkeypatch, ray_start_cluster): - monkeypatch.setenv("RAY_num_heartbeats_timeout", "3") +def test_healthz_agent_2(pull_based, monkeypatch, ray_start_cluster): + monkeypatch.setenv("RAY_pull_based_healthcheck", "true" if pull_based else "false") + if pull_based: + monkeypatch.setenv("RAY_health_check_failure_threshold", "3") + monkeypatch.setenv("RAY_health_check_timeout_ms", "100") + monkeypatch.setenv("RAY_health_check_period_ms", "1000") + monkeypatch.setenv("RAY_health_check_initial_delay_ms", "0") + else: + monkeypatch.setenv("RAY_num_heartbeats_timeout", "3") agent_port = find_free_port() h = ray_start_cluster.add_node(dashboard_agent_listen_port=agent_port) diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index a3250f29894d..bf75ac0a07fc 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -81,4 +81,14 @@ cdef extern from "ray/common/ray_config.h" nogil: c_string REDIS_SERVER_NAME() const + c_bool pull_based_healthcheck() const + + int64_t health_check_initial_delay_ms() const + + int64_t health_check_period_ms() const + + int64_t health_check_timeout_ms() const + + int64_t health_check_failure_threshold() const + uint64_t memory_monitor_interval_ms() const diff --git a/python/ray/includes/ray_config.pxi 
b/python/ray/includes/ray_config.pxi index a6a69701bdfc..d68703192c79 100644 --- a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -137,6 +137,26 @@ cdef class Config: def REDIS_SERVER_NAME(): return RayConfig.instance().REDIS_SERVER_NAME() + @staticmethod + def pull_based_healthcheck(): + return RayConfig.instance().pull_based_healthcheck() + + @staticmethod + def health_check_initial_delay_ms(): + return RayConfig.instance().health_check_initial_delay_ms() + + @staticmethod + def health_check_period_ms(): + return RayConfig.instance().health_check_period_ms() + + @staticmethod + def health_check_timeout_ms(): + return RayConfig.instance().health_check_timeout_ms() + + @staticmethod + def health_check_failure_threshold(): + return RayConfig.instance().health_check_failure_threshold() + @staticmethod def memory_monitor_interval_ms(): return (RayConfig.instance() diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 92ea5260cbd1..a2acc327bcd9 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -562,12 +562,24 @@ async def f(self): { "num_cpus": 0, "_system_config": { + "pull_based_healthcheck": False, "raylet_death_check_interval_milliseconds": 10 * 1000, "num_heartbeats_timeout": 10, "raylet_heartbeat_period_milliseconds": 100, "timeout_ms_task_wait_for_death_info": 100, }, - } + }, + { + "num_cpus": 0, + "_system_config": { + "pull_based_healthcheck": True, + "raylet_death_check_interval_milliseconds": 10 * 1000, + "health_check_initial_delay_ms": 0, + "health_check_failure_threshold": 10, + "health_check_period_ms": 100, + "timeout_ms_task_wait_for_death_info": 100, + }, + }, ], indirect=True, ) diff --git a/python/ray/tests/test_metrics_agent.py b/python/ray/tests/test_metrics_agent.py index 56f3fddabd9d..9963e372dcfd 100644 --- a/python/ray/tests/test_metrics_agent.py +++ b/python/ray/tests/test_metrics_agent.py @@ -51,7 +51,6 @@ "ray_object_directory_lookups", "ray_object_directory_added_locations", "ray_object_directory_removed_locations", - "ray_heartbeat_report_ms_sum", "ray_process_startup_time_ms_sum", "ray_internal_num_processes_started", "ray_internal_num_spilled_tasks", @@ -148,6 +147,11 @@ "ray_component_uss_mb", ] +if ray._raylet.Config.pull_based_healthcheck(): + _METRICS.append("ray_health_check_rpc_latency_ms_sum") +else: + _METRICS.append("ray_heartbeat_report_ms_sum") + @pytest.fixture def _setup_cluster_for_test(request, ray_start_cluster): diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index 8cd25b92a777..d65b7937f68f 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -42,8 +42,17 @@ def test_shutdown(): "ray_start_cluster_head", [ generate_system_config_map( - num_heartbeats_timeout=3, object_timeout_milliseconds=12345 - ) + num_heartbeats_timeout=3, + object_timeout_milliseconds=12345, + pull_based_healthcheck=False, + ), + generate_system_config_map( + health_check_initial_delay_ms=0, + health_check_period_ms=1000, + health_check_failure_threshold=3, + object_timeout_milliseconds=12345, + pull_based_healthcheck=True, + ), ], indirect=True, ) @@ -62,7 +71,12 @@ def test_system_config(ray_start_cluster_head): @ray.remote def f(): assert ray._config.object_timeout_milliseconds() == 12345 - assert ray._config.num_heartbeats_timeout() == 3 + if ray._config.pull_based_healthcheck(): + assert ray._config.health_check_initial_delay_ms() == 0 + assert 
ray._config.health_check_failure_threshold() == 3 + assert ray._config.health_check_period_ms() == 1000 + else: + assert ray._config.num_heartbeats_timeout() == 3 ray.get([f.remote() for _ in range(5)]) diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 4e4ff70ff47c..58e0e6b90329 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -699,3 +699,11 @@ RAY_CONFIG(std::string, REDIS_SERVER_NAME, "") // The delay is a random number between the interval. If method equals '*', // it will apply to all methods. RAY_CONFIG(std::string, testing_asio_delay_us, "") + +/// A feature flag to enable pull based health check. +/// TODO: Turn it on by default +RAY_CONFIG(bool, pull_based_healthcheck, false) +RAY_CONFIG(int64_t, health_check_initial_delay_ms, 5000) +RAY_CONFIG(int64_t, health_check_period_ms, 3000) +RAY_CONFIG(int64_t, health_check_timeout_ms, 10000) +RAY_CONFIG(int64_t, health_check_failure_threshold, 5) diff --git a/src/ray/gcs/gcs_client/test/gcs_client_reconnection_test.cc b/src/ray/gcs/gcs_client/test/gcs_client_reconnection_test.cc index ec3f479eb192..5b70732392e2 100644 --- a/src/ray/gcs/gcs_client/test/gcs_client_reconnection_test.cc +++ b/src/ray/gcs/gcs_client/test/gcs_client_reconnection_test.cc @@ -68,14 +68,14 @@ class GcsClientReconnectionTest : public ::testing::Test { auto channel = grpc::CreateChannel(absl::StrCat("127.0.0.1:", config_.grpc_server_port), grpc::InsecureChannelCredentials()); - std::unique_ptr stub = - rpc::NodeInfoGcsService::NewStub(std::move(channel)); + auto stub = grpc::health::v1::Health::NewStub(channel); grpc::ClientContext context; context.set_deadline(std::chrono::system_clock::now() + 1s); - const rpc::CheckAliveRequest request; - rpc::CheckAliveReply reply; - auto status = stub->CheckAlive(&context, request, &reply); - if (!status.ok()) { + ::grpc::health::v1::HealthCheckRequest request; + ::grpc::health::v1::HealthCheckResponse reply; + auto status = stub->Check(&context, request, &reply); + if (!status.ok() || + reply.status() != ::grpc::health::v1::HealthCheckResponse::SERVING) { RAY_LOG(WARNING) << "Unable to reach GCS: " << status.error_code() << " " << status.error_message(); return false; diff --git a/src/ray/gcs/gcs_server/gcs_health_check_manager.cc b/src/ray/gcs/gcs_server/gcs_health_check_manager.cc new file mode 100644 index 000000000000..831f345ef81a --- /dev/null +++ b/src/ray/gcs/gcs_server/gcs_health_check_manager.cc @@ -0,0 +1,144 @@ +// Copyright 2022 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "ray/gcs/gcs_server/gcs_health_check_manager.h" + +#include "ray/stats/metric.h" +DEFINE_stats(health_check_rpc_latency_ms, + "Latency of rpc request for health check.", + (), + ({1, 10, 100, 1000, 10000}, ), + ray::stats::HISTOGRAM); + +namespace ray { +namespace gcs { + +GcsHealthCheckManager::GcsHealthCheckManager( + instrumented_io_context &io_service, + std::function on_node_death_callback, + int64_t initial_delay_ms, + int64_t timeout_ms, + int64_t period_ms, + int64_t failure_threshold) + : io_service_(io_service), + on_node_death_callback_(on_node_death_callback), + initial_delay_ms_(initial_delay_ms), + timeout_ms_(timeout_ms), + period_ms_(period_ms), + failure_threshold_(failure_threshold) { + RAY_CHECK(on_node_death_callback != nullptr); + RAY_CHECK(initial_delay_ms >= 0); + RAY_CHECK(timeout_ms >= 0); + RAY_CHECK(period_ms >= 0); + RAY_CHECK(failure_threshold >= 0); +} + +GcsHealthCheckManager::~GcsHealthCheckManager() {} + +void GcsHealthCheckManager::RemoveNode(const NodeID &node_id) { + io_service_.dispatch( + [this, node_id]() { + auto iter = health_check_contexts_.find(node_id); + if (iter == health_check_contexts_.end()) { + return; + } + health_check_contexts_.erase(iter); + }, + "GcsHealthCheckManager::RemoveNode"); +} + +void GcsHealthCheckManager::FailNode(const NodeID &node_id) { + RAY_LOG(WARNING) << "Node " << node_id << " is dead because the health check failed."; + on_node_death_callback_(node_id); + health_check_contexts_.erase(node_id); +} + +std::vector GcsHealthCheckManager::GetAllNodes() const { + std::vector nodes; + for (const auto &[node_id, _] : health_check_contexts_) { + nodes.emplace_back(node_id); + } + return nodes; +} + +void GcsHealthCheckManager::HealthCheckContext::StartHealthCheck() { + using ::grpc::health::v1::HealthCheckResponse; + + context_ = std::make_shared(); + + auto deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(manager_->timeout_ms_); + context_->set_deadline(deadline); + stub_->async()->Check( + context_.get(), + &request_, + &response_, + [this, stopped = this->stopped_, context = this->context_, now = absl::Now()]( + ::grpc::Status status) { + // This callback is done in gRPC's thread pool. + STATS_health_check_rpc_latency_ms.Record( + absl::ToInt64Milliseconds(absl::Now() - now)); + if (status.error_code() == ::grpc::StatusCode::CANCELLED) { + return; + } + manager_->io_service_.post( + [this, stopped, status]() { + // Stopped has to be read in the same thread where it's updated. + if (*stopped) { + return; + } + RAY_LOG(DEBUG) << "Health check status: " << int(response_.status()); + + if (status.ok() && response_.status() == HealthCheckResponse::SERVING) { + // Health check passed + health_check_remaining_ = manager_->failure_threshold_; + } else { + --health_check_remaining_; + RAY_LOG(WARNING) << "Health check failed for node " << node_id_ + << ", remaining checks " << health_check_remaining_; + } + + if (health_check_remaining_ == 0) { + manager_->io_service_.post([this]() { manager_->FailNode(node_id_); }, + ""); + } else { + // Do another health check. + timer_.expires_from_now( + boost::posix_time::milliseconds(manager_->period_ms_)); + timer_.async_wait([this, stopped](auto ec) { + // We need to check stopped here as well since cancel + // won't impact the queued tasks. 
+ if (ec != boost::asio::error::operation_aborted && !*stopped) { + StartHealthCheck(); + } + }); + } + }, + "HealthCheck"); + }); +} + +void GcsHealthCheckManager::AddNode(const NodeID &node_id, + std::shared_ptr channel) { + io_service_.dispatch( + [this, channel, node_id]() { + RAY_CHECK(health_check_contexts_.count(node_id) == 0); + auto context = std::make_unique(this, channel, node_id); + health_check_contexts_.emplace(std::make_pair(node_id, std::move(context))); + }, + "GcsHealthCheckManager::AddNode"); +} + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_health_check_manager.h b/src/ray/gcs/gcs_server/gcs_health_check_manager.h new file mode 100644 index 000000000000..b504ce079343 --- /dev/null +++ b/src/ray/gcs/gcs_server/gcs_health_check_manager.h @@ -0,0 +1,161 @@ +// Copyright 2022 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "absl/container/flat_hash_map.h" +#include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/id.h" +#include "ray/common/ray_config.h" +#include "src/proto/grpc/health/v1/health.grpc.pb.h" + +class GcsHealthCheckManagerTest; + +namespace ray { +namespace gcs { + +/// GcsHealthCheckManager is used to track the healthiness of the nodes in the ray +/// cluster. The health check is done in a pull-based way, which means this module sends +/// health check requests to the raylets to see whether a raylet is healthy or not. If a +/// raylet fails the health check a certain number of consecutive times, the module +/// considers the raylet dead. +/// When a node is marked dead, the callback passed in the constructor is called and the +/// node is removed from GcsHealthCheckManager. The node can be added to this class again +/// later. Although the same node id is not supposed to be reused in a ray cluster, this +/// is not enforced by this class. +/// TODO (iycheng): Move the GcsHealthCheckManager to ray/common. +class GcsHealthCheckManager { + public: + /// Constructor of GcsHealthCheckManager. + /// + /// \param io_service The thread where all operations in this class should run. + /// \param on_node_death_callback The callback function called when a node is marked + /// as dead. + /// \param initial_delay_ms The delay for the first health check. + /// \param timeout_ms The timeout of a single health check request before it counts + /// as a failure. + /// \param period_ms The interval between two health checks for the same node. + /// \param failure_threshold The number of consecutive failed checks before a node is + /// marked as dead. + GcsHealthCheckManager( + instrumented_io_context &io_service, + std::function on_node_death_callback, + int64_t initial_delay_ms = RayConfig::instance().health_check_initial_delay_ms(), + int64_t timeout_ms = RayConfig::instance().health_check_timeout_ms(), + int64_t period_ms = RayConfig::instance().health_check_period_ms(), + int64_t failure_threshold = RayConfig::instance().health_check_failure_threshold()); + + ~GcsHealthCheckManager(); + + /// Start tracking the healthiness of a node. + /// + /// \param node_id The id of the node.
+ /// \param channel The gRPC channel to the node. + void AddNode(const NodeID &node_id, std::shared_ptr channel); + + /// Stop tracking the healthiness of a node. + /// + /// \param node_id The id of the node to stop tracking. + void RemoveNode(const NodeID &node_id); + + /// Return all the nodes monitored. + /// + /// \return A list of node ids which are being monitored by this class. + std::vector GetAllNodes() const; + + private: + /// Fail a node when its health check has failed. This stops the health checking and + /// calls on_node_death_callback. + /// + /// \param node_id The id of the node. + void FailNode(const NodeID &node_id); + + using Timer = boost::asio::deadline_timer; + + /// The context for the health check. It currently uses a unary call and + /// can be updated to a streaming call for efficiency. + class HealthCheckContext { + public: + HealthCheckContext(GcsHealthCheckManager *manager, + std::shared_ptr channel, + NodeID node_id) + : manager_(manager), + node_id_(node_id), + stopped_(std::make_shared(false)), + timer_(manager->io_service_), + health_check_remaining_(manager->failure_threshold_) { + stub_ = grpc::health::v1::Health::NewStub(channel); + timer_.expires_from_now( + boost::posix_time::milliseconds(manager_->initial_delay_ms_)); + timer_.async_wait([this](auto ec) { + if (ec != boost::asio::error::operation_aborted) { + StartHealthCheck(); + } + }); + } + + ~HealthCheckContext() { + timer_.cancel(); + if (context_ != nullptr) { + context_->TryCancel(); + } + *stopped_ = true; + } + + private: + void StartHealthCheck(); + + GcsHealthCheckManager *manager_; + + NodeID node_id_; + + // Whether the health check has stopped. + std::shared_ptr stopped_; + + /// gRPC related fields + std::unique_ptr<::grpc::health::v1::Health::Stub> stub_; + + // The context is used in the gRPC callback which runs in another + // thread, so we need it to be a shared_ptr. + std::shared_ptr context_; + ::grpc::health::v1::HealthCheckRequest request_; + ::grpc::health::v1::HealthCheckResponse response_; + + /// The timer used to do an async wait before the next try. + Timer timer_; + + /// The number of checks remaining. If it reaches 0, the node will be marked as dead. + int64_t health_check_remaining_; + }; + + /// The main service. All methods need to run on this thread. + instrumented_io_context &io_service_; + + /// Callback invoked when a node has failed. + std::function on_node_death_callback_; + + /// The health check context for each node. + absl::flat_hash_map> health_check_contexts_; + + /// The delay for the first health check request. + const int64_t initial_delay_ms_; + /// Timeout for each health check request. + const int64_t timeout_ms_; + /// Interval between two health checks. + const int64_t period_ms_; + /// The number of failures before the node is considered dead. + const int64_t failure_threshold_; +}; + +} // namespace gcs +} // namespace ray diff --git a/src/ray/gcs/gcs_server/gcs_heartbeat_manager.h b/src/ray/gcs/gcs_server/gcs_heartbeat_manager.h index c788e0de636f..2124ba669068 100644 --- a/src/ray/gcs/gcs_server/gcs_heartbeat_manager.h +++ b/src/ray/gcs/gcs_server/gcs_heartbeat_manager.h @@ -1,4 +1,3 @@ - // Copyright 2017 The Ray Authors.
// // Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 10acb3e0fc98..4d3b3582cff7 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -178,7 +178,12 @@ void GcsServer::DoStart(const GcsInitData &gcs_init_data) { // be run. Otherwise the node failure detector will mistake // some living nodes as dead as the timer inside node failure // detector is already run. - gcs_heartbeat_manager_->Start(); + if (gcs_heartbeat_manager_) { + gcs_heartbeat_manager_->Start(); + } + RAY_CHECK(int(gcs_heartbeat_manager_ != nullptr) + + int(gcs_healthcheck_manager_ != nullptr) == + 1); RecordMetrics(); @@ -212,7 +217,9 @@ void GcsServer::Stop() { // GcsHeartbeatManager is still checking nodes' heartbeat timeout. Since RPC Server // won't handle heartbeat calls anymore, some nodes will be marked as dead during this // time, causing many nodes die after GCS's failure. - gcs_heartbeat_manager_->Stop(); + if (gcs_heartbeat_manager_) { + gcs_heartbeat_manager_->Stop(); + } if (RayConfig::instance().use_ray_syncer()) { ray_syncer_io_context_.stop(); ray_syncer_thread_->join(); @@ -245,19 +252,35 @@ void GcsServer::InitGcsNodeManager(const GcsInitData &gcs_init_data) { void GcsServer::InitGcsHeartbeatManager(const GcsInitData &gcs_init_data) { RAY_CHECK(gcs_node_manager_); - gcs_heartbeat_manager_ = std::make_shared( - heartbeat_manager_io_service_, /*on_node_death_callback=*/ - [this](const NodeID &node_id) { - main_service_.post( - [this, node_id] { return gcs_node_manager_->OnNodeFailure(node_id); }, - "GcsServer.NodeDeathCallback"); - }); - // Initialize by gcs tables data. - gcs_heartbeat_manager_->Initialize(gcs_init_data); - // Register service. - heartbeat_info_service_.reset(new rpc::HeartbeatInfoGrpcService( - heartbeat_manager_io_service_, *gcs_heartbeat_manager_)); - rpc_server_.RegisterService(*heartbeat_info_service_); + auto node_death_callback = [this](const NodeID &node_id) { + main_service_.post( + [this, node_id] { return gcs_node_manager_->OnNodeFailure(node_id); }, + "GcsServer.NodeDeathCallback"); + }; + + if (RayConfig::instance().pull_based_healthcheck()) { + gcs_healthcheck_manager_ = + std::make_unique(main_service_, node_death_callback); + for (const auto &item : gcs_init_data.Nodes()) { + if (item.second.state() == rpc::GcsNodeInfo::ALIVE) { + rpc::Address remote_address; + remote_address.set_raylet_id(item.second.node_id()); + remote_address.set_ip_address(item.second.node_manager_address()); + remote_address.set_port(item.second.node_manager_port()); + auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(remote_address); + gcs_healthcheck_manager_->AddNode(item.first, raylet_client->GetChannel()); + } + } + } else { + gcs_heartbeat_manager_ = std::make_shared( + heartbeat_manager_io_service_, /*on_node_death_callback=*/node_death_callback); + // Initialize by gcs tables data. + gcs_heartbeat_manager_->Initialize(gcs_init_data); + // Register service. 
+ heartbeat_info_service_.reset(new rpc::HeartbeatInfoGrpcService( + heartbeat_manager_io_service_, *gcs_heartbeat_manager_)); + rpc_server_.RegisterService(*heartbeat_info_service_); + } } void GcsServer::InitGcsResourceManager(const GcsInitData &gcs_init_data) { @@ -606,15 +629,26 @@ void GcsServer::InstallEventListeners() { gcs_resource_manager_->OnNodeAdd(*node); gcs_placement_group_manager_->OnNodeAdd(node_id); gcs_actor_manager_->SchedulePendingActors(); - gcs_heartbeat_manager_->AddNode(*node); + if (gcs_heartbeat_manager_) { + gcs_heartbeat_manager_->AddNode(*node); + } + + rpc::Address address; + address.set_raylet_id(node->node_id()); + address.set_ip_address(node->node_manager_address()); + address.set_port(node->node_manager_port()); + + auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(address); + + if (gcs_healthcheck_manager_) { + RAY_CHECK(raylet_client != nullptr); + auto channel = raylet_client->GetChannel(); + RAY_CHECK(channel != nullptr); + gcs_healthcheck_manager_->AddNode(node_id, channel); + } cluster_task_manager_->ScheduleAndDispatchTasks(); - if (RayConfig::instance().use_ray_syncer()) { - rpc::Address address; - address.set_raylet_id(node->node_id()); - address.set_ip_address(node->node_manager_address()); - address.set_port(node->node_manager_port()); - auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(address); + if (RayConfig::instance().use_ray_syncer()) { ray_syncer_->Connect(raylet_client->GetChannel()); } else { gcs_ray_syncer_->AddNode(*node); @@ -630,7 +664,14 @@ void GcsServer::InstallEventListeners() { gcs_placement_group_manager_->OnNodeDead(node_id); gcs_actor_manager_->OnNodeDead(node_id, node_ip_address); raylet_client_pool_->Disconnect(node_id); - gcs_heartbeat_manager_->RemoveNode(node_id); + if (gcs_heartbeat_manager_) { + gcs_heartbeat_manager_->RemoveNode(node_id); + } + + if (gcs_healthcheck_manager_) { + gcs_healthcheck_manager_->RemoveNode(node_id); + } + if (RayConfig::instance().use_ray_syncer()) { ray_syncer_->Disconnect(node_id.Binary()); } else { diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index b99ee6895b76..a84354fc3b4c 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -18,6 +18,7 @@ #include "ray/common/ray_syncer/ray_syncer.h" #include "ray/common/runtime_env_manager.h" #include "ray/gcs/gcs_server/gcs_function_manager.h" +#include "ray/gcs/gcs_server/gcs_health_check_manager.h" #include "ray/gcs/gcs_server/gcs_heartbeat_manager.h" #include "ray/gcs/gcs_server/gcs_init_data.h" #include "ray/gcs/gcs_server/gcs_kv_manager.h" @@ -201,6 +202,8 @@ class GcsServer { std::shared_ptr cluster_task_manager_; /// The gcs node manager. std::shared_ptr gcs_node_manager_; + /// The health check manager. + std::shared_ptr gcs_healthcheck_manager_; /// The heartbeat manager. std::shared_ptr gcs_heartbeat_manager_; /// The gcs redis failure detector. diff --git a/src/ray/gcs/gcs_server/test/gcs_health_check_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_health_check_manager_test.cc new file mode 100644 index 000000000000..d28e263008da --- /dev/null +++ b/src/ray/gcs/gcs_server/test/gcs_health_check_manager_test.cc @@ -0,0 +1,227 @@ +// Copyright 2020-2021 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include +#include + +using namespace boost; +#include + +#include +#include + +#include "gtest/gtest.h" +#include "ray/gcs/gcs_server/gcs_health_check_manager.h" + +using namespace ray; +using namespace std::literals::chrono_literals; + +class GcsHealthCheckManagerTest : public ::testing::Test { + protected: + GcsHealthCheckManagerTest() {} + void SetUp() override { + grpc::EnableDefaultHealthCheckService(true); + + health_check = std::make_unique( + io_service, + [this](const NodeID &id) { dead_nodes.insert(id); }, + initial_delay_ms, + timeout_ms, + period_ms, + failure_threshold); + port = 10000; + } + + void TearDown() override { + io_service.poll(); + io_service.stop(); + + // Stop the servers. + for (auto [_, server] : servers) { + server->Shutdown(); + } + + // Allow gRPC to clean up. + boost::this_thread::sleep_for(boost::chrono::seconds(2)); + } + + NodeID AddServer() { + std::promise port_promise; + auto node_id = NodeID::FromRandom(); + + auto server = std::make_shared(node_id.Hex(), port, true); + + auto channel = grpc::CreateChannel("localhost:" + std::to_string(port), + grpc::InsecureChannelCredentials()); + server->Run(); + servers.emplace(node_id, server); + health_check->AddNode(node_id, channel); + ++port; + return node_id; + } + + void StopServing(const NodeID &id) { + auto iter = servers.find(id); + RAY_CHECK(iter != servers.end()); + iter->second->GetServer().GetHealthCheckService()->SetServingStatus(false); + } + + void StartServing(const NodeID &id) { + auto iter = servers.find(id); + RAY_CHECK(iter != servers.end()); + iter->second->GetServer().GetHealthCheckService()->SetServingStatus(true); + } + + void DeleteServer(const NodeID &id) { + auto iter = servers.find(id); + if (iter != servers.end()) { + iter->second->Shutdown(); + servers.erase(iter); + } + } + + void Run(size_t n = 1) { + // If n == 0, it means we just run it and return. + if (n == 0) { + io_service.run(); + io_service.restart(); + return; + } + + while (n) { + auto i = io_service.run_one(); + n -= i; + io_service.restart(); + } + } + + int port; + instrumented_io_context io_service; + std::unique_ptr health_check; + std::unordered_map> servers; + std::unordered_set dead_nodes; + const int64_t initial_delay_ms = 1000; + const int64_t timeout_ms = 1000; + const int64_t period_ms = 1000; + const int64_t failure_threshold = 5; +}; + +TEST_F(GcsHealthCheckManagerTest, TestBasic) { + auto node_id = AddServer(); + Run(0); // Initial run + ASSERT_TRUE(dead_nodes.empty()); + + // Run the first health check + Run(); + ASSERT_TRUE(dead_nodes.empty()); + + Run(2); // One for starting RPC and one for the RPC callback. + ASSERT_TRUE(dead_nodes.empty()); + StopServing(node_id); + + for (auto i = 0; i < failure_threshold; ++i) { + Run(2); // One for starting RPC and one for the RPC callback. + } + + Run(); // For failure callback.
+ + ASSERT_EQ(1, dead_nodes.size()); + ASSERT_TRUE(dead_nodes.count(node_id)); +} + +TEST_F(GcsHealthCheckManagerTest, StoppedAndResume) { + auto node_id = AddServer(); + Run(0); // Initial run + ASSERT_TRUE(dead_nodes.empty()); + + // Run the first health check + Run(); + ASSERT_TRUE(dead_nodes.empty()); + + Run(2); // One for starting RPC and one for the RPC callback. + ASSERT_TRUE(dead_nodes.empty()); + StopServing(node_id); + + for (auto i = 0; i < failure_threshold; ++i) { + Run(2); // One for starting RPC and one for the RPC callback. + if (i == failure_threshold / 2) { + StartServing(node_id); + } + } + + Run(); // For failure callback. + + ASSERT_EQ(0, dead_nodes.size()); +} + +TEST_F(GcsHealthCheckManagerTest, Crashed) { + auto node_id = AddServer(); + Run(0); // Initial run + ASSERT_TRUE(dead_nodes.empty()); + + // Run the first health check + Run(); + ASSERT_TRUE(dead_nodes.empty()); + + Run(2); // One for starting RPC and one for the RPC callback. + ASSERT_TRUE(dead_nodes.empty()); + + // Check it again + Run(2); // One for starting RPC and one for the RPC callback. + ASSERT_TRUE(dead_nodes.empty()); + + DeleteServer(node_id); + + for (auto i = 0; i < failure_threshold; ++i) { + Run(2); // One for starting RPC and one for the RPC callback. + } + + Run(); // For failure callback. + + ASSERT_EQ(1, dead_nodes.size()); + ASSERT_TRUE(dead_nodes.count(node_id)); +} + +TEST_F(GcsHealthCheckManagerTest, NodeRemoved) { + auto node_id = AddServer(); + Run(0); // Initial run + ASSERT_TRUE(dead_nodes.empty()); + + // Run the first health check + Run(); + ASSERT_TRUE(dead_nodes.empty()); + + Run(2); // One for starting RPC and one for the RPC callback. + ASSERT_TRUE(dead_nodes.empty()); + health_check->RemoveNode(node_id); + + // Make sure it's not monitored any more + for (auto i = 0; i < failure_threshold; ++i) { + io_service.poll(); + } + + ASSERT_EQ(0, dead_nodes.size()); + ASSERT_EQ(0, health_check->GetAllNodes().size()); +} + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc index 5ea69fa41ee3..61cb5c8a06e4 100644 --- a/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_server_rpc_test.cc @@ -14,6 +14,7 @@ #include "gtest/gtest.h" #include "ray/common/asio/instrumented_io_context.h" +#include "ray/common/ray_config.h" #include "ray/common/test_util.h" #include "ray/gcs/gcs_server/gcs_server.h" #include "ray/gcs/test/gcs_test_util.h" @@ -319,10 +320,12 @@ TEST_F(GcsServerTest, TestNodeInfo) { ASSERT_TRUE(node_info_list[0].state() == rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_ALIVE); - // Report heartbeat - rpc::ReportHeartbeatRequest report_heartbeat_request; - report_heartbeat_request.mutable_heartbeat()->set_node_id(gcs_node_info->node_id()); - ASSERT_TRUE(ReportHeartbeat(report_heartbeat_request)); + if (!RayConfig::instance().pull_based_healthcheck()) { + // Report heartbeat + rpc::ReportHeartbeatRequest report_heartbeat_request; + report_heartbeat_request.mutable_heartbeat()->set_node_id(gcs_node_info->node_id()); + ASSERT_TRUE(ReportHeartbeat(report_heartbeat_request)); + } // Unregister node info rpc::DrainNodeRequest unregister_node_info_request; @@ -336,6 +339,9 @@ TEST_F(GcsServerTest, TestNodeInfo) { } TEST_F(GcsServerTest, TestHeartbeatWithNoRegistering) { + if (RayConfig::instance().pull_based_healthcheck()) { + GTEST_SKIP(); + } // Create gcs node info 
auto gcs_node_info = Mocker::GenNodeInfo(); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 7dbc5fee34be..24ed5e3f2bf2 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -170,6 +170,7 @@ void HeartbeatSender::Heartbeat() { RAY_CHECK_OK( gcs_client_->Nodes().AsyncReportHeartbeat(heartbeat_data, [](Status status) { if (status.IsDisconnected()) { + RAY_EVENT(FATAL, "RAYLET_MARKED_DEAD") << "This node has been marked as dead."; RAY_LOG(FATAL) << "This node has been marked as dead."; } })); @@ -462,7 +463,9 @@ NodeManager::NodeManager(instrumented_io_context &io_service, ray::Status NodeManager::RegisterGcs() { // Start sending heartbeats here to ensure this happens after the raylet is registered. - heartbeat_sender_.reset(new HeartbeatSender(self_node_id_, gcs_client_)); + if (!RayConfig::instance().pull_based_healthcheck()) { + heartbeat_sender_.reset(new HeartbeatSender(self_node_id_, gcs_client_)); + } auto on_node_change = [this](const NodeID &node_id, const GcsNodeInfo &data) { if (data.state() == GcsNodeInfo::ALIVE) { NodeAdded(data); @@ -1026,6 +1029,14 @@ void NodeManager::NodeRemoved(const NodeID &node_id) { if (node_id == self_node_id_) { if (!is_node_drained_) { + // TODO(iycheng): Don't duplicate the log here once we enable events by default. + RAY_EVENT(FATAL, "RAYLET_MARKED_DEAD") + << "[Timeout] Exiting because this node manager has mistakenly been marked as " + "dead by the " + << "GCS: GCS didn't receive heartbeats from this node for " + << RayConfig::instance().num_heartbeats_timeout() * + RayConfig::instance().raylet_heartbeat_period_milliseconds() + << " ms. This is likely because the machine or raylet has become overloaded."; RAY_LOG(FATAL) << "[Timeout] Exiting because this node manager has mistakenly been marked as " "dead by the " diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 67eb82aa7eb7..256e9aabe62f 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -118,6 +118,8 @@ class GrpcServer { void RegisterService(GrpcService &service); void RegisterService(grpc::Service &service); + grpc::Server &GetServer() { return *server_; } + protected: /// This function runs in a background thread. It keeps polling events from the /// `ServerCompletionQueue`, and dispatches the event to the `ServiceHandler` instances
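A note on how the new knobs compose. Reading the retry loop in gcs_health_check_manager.cc as a back-of-envelope bound (not a documented guarantee): each failed round costs at most `health_check_timeout_ms` waiting for the probe plus `health_check_period_ms` before the next probe, so the worst-case time to declare a node dead is roughly `health_check_initial_delay_ms + failure_threshold * timeout + (failure_threshold - 1) * period`. For the values used in `test_healthz_agent_2` (delay 0, timeout 100 ms, period 1000 ms, threshold 3) that is about 0 + 3*100 + 2*1000 = 2.3 s, in the same ballpark as the heartbeat-based setting it replaces (`RAY_num_heartbeats_timeout=3`). With the defaults from ray_config_def.h (delay 5 s, timeout 10 s, period 3 s, threshold 5) it is roughly 5 + 5*10 + 4*3 = 67 s when every probe times out, and closer to 5 + 4*3 = 17 s when probes fail fast (e.g. connection refused).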
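The reconnection test now probes the GCS through the standard `grpc.health.v1` service instead of the `NodeInfoGcsService.CheckAlive` RPC, and the same service is what `GcsHealthCheckManager` calls on each raylet. A minimal, self-contained sketch of that kind of probe (the function name and address handling are illustrative, not part of this patch):

```cpp
#include <chrono>
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

#include "src/proto/grpc/health/v1/health.grpc.pb.h"

// Ask a server exposing the standard grpc.health.v1 service whether it is
// SERVING, with a hard deadline.
bool ProbeOnce(const std::string &address, std::chrono::milliseconds timeout) {
  auto channel =
      grpc::CreateChannel(address, grpc::InsecureChannelCredentials());
  auto stub = grpc::health::v1::Health::NewStub(channel);

  grpc::ClientContext context;
  context.set_deadline(std::chrono::system_clock::now() + timeout);

  grpc::health::v1::HealthCheckRequest request;  // empty service name = whole server
  grpc::health::v1::HealthCheckResponse response;
  grpc::Status status = stub->Check(&context, request, &response);

  // Both an RPC error and a non-SERVING status count as "unhealthy".
  return status.ok() &&
         response.status() == grpc::health::v1::HealthCheckResponse::SERVING;
}
```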
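The trickiest part of `HealthCheckContext::StartHealthCheck` is threading: the gRPC completion callback fires on a gRPC-owned thread, so the result is posted back onto the manager's io context, and the shared `stopped_` flag, which is written only in the destructor and read only on that io context, decides whether the result may still touch the context. Below is a standalone sketch of that pattern in plain Boost.Asio; all names are hypothetical and the "RPC" is simulated with a detached thread:

```cpp
#include <boost/asio.hpp>
#include <chrono>
#include <iostream>
#include <memory>
#include <thread>

class ProbeLoop {
 public:
  explicit ProbeLoop(boost::asio::io_context &io)
      : io_(io), stopped_(std::make_shared<bool>(false)) {}

  // Like HealthCheckContext's destructor, this must run on the io_context thread.
  ~ProbeLoop() { *stopped_ = true; }

  void StartProbe() {
    // Simulate an async RPC whose completion callback runs on another thread.
    std::thread([&io = io_, self = this, stopped = stopped_]() {
      bool rpc_ok = true;  // stand-in for the real RPC result
      // Hop back to the owning io_context before touching any state.
      boost::asio::post(io, [self, stopped, rpc_ok]() {
        if (*stopped) return;  // the owner is already gone; drop the result
        self->HandleResult(rpc_ok);
      });
    }).detach();
  }

 private:
  void HandleResult(bool ok) { std::cout << "probe ok=" << ok << "\n"; }

  boost::asio::io_context &io_;  // assumed to outlive ProbeLoop
  std::shared_ptr<bool> stopped_;
};

int main() {
  boost::asio::io_context io;
  auto guard = boost::asio::make_work_guard(io);  // keep the context alive
  ProbeLoop loop(io);
  loop.StartProbe();
  // In real code the io_context runs forever; here we wait briefly and poll.
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  io.poll();
}
```

The flag is a `shared_ptr<bool>` rather than a captured `this` alone so that a completion arriving after the context has been destroyed can still be safely dropped: the bool outlives the owner, and it is only ever touched on the io-context thread.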
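For orientation, a hedged usage sketch of the new manager, mirroring how `GcsServer::InitGcsHeartbeatManager` wires it onto `main_service_` when `pull_based_healthcheck` is on and what `InstallEventListeners` does for nodes that join later. The address, node id, and function name are illustrative; error handling is omitted:

```cpp
#include <grpcpp/grpcpp.h>

#include "ray/common/asio/instrumented_io_context.h"
#include "ray/gcs/gcs_server/gcs_health_check_manager.h"
#include "ray/util/logging.h"

void StartCheckingOneRaylet(instrumented_io_context &io_service) {
  // The knob arguments default to the health_check_* values in RayConfig.
  ray::gcs::GcsHealthCheckManager manager(
      io_service,
      /*on_node_death_callback=*/[](const ray::NodeID &node_id) {
        // In the GCS this posts GcsNodeManager::OnNodeFailure(node_id).
        RAY_LOG(WARNING) << "Node " << node_id << " failed its health checks.";
      });

  // Reuse (or create) a gRPC channel to the raylet and start probing it.
  auto channel = grpc::CreateChannel("10.0.0.5:12345",
                                     grpc::InsecureChannelCredentials());
  auto node_id = ray::NodeID::FromRandom();
  manager.AddNode(node_id, channel);

  // The io context drives the periodic checks; in the GCS this is main_service_.
  io_service.run();
}
```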
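On the server side, nothing Ray-specific is needed to answer these probes: gRPC's built-in health service handles them once it is enabled before the server is built, which is what the new unit test does for its fake raylets and what its `StopServing`/`StartServing` helpers toggle. A minimal standalone sketch (the port is illustrative):

```cpp
#include <memory>

#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>

int main() {
  // Must be called before BuildAndStart() for the default health service
  // to be registered on the server.
  grpc::EnableDefaultHealthCheckService(true);

  grpc::ServerBuilder builder;
  builder.AddListeningPort("0.0.0.0:50051", grpc::InsecureServerCredentials());
  std::unique_ptr<grpc::Server> server = builder.BuildAndStart();

  // Flip the status to make health checks fail (e.g. while draining),
  // and back again to recover.
  server->GetHealthCheckService()->SetServingStatus(false);
  server->GetHealthCheckService()->SetServingStatus(true);

  server->Wait();
}
```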
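Finally, two raylet-side consequences of the flag. With `pull_based_healthcheck` on, the raylet no longer starts a `HeartbeatSender`, so test_metrics_agent.py expects `ray_health_check_rpc_latency_ms_sum` instead of `ray_heartbeat_report_ms_sum`. And the small `GrpcServer::GetServer()` accessor exists so the test (or, potentially, a draining raylet) can reach the underlying `grpc::Server` and flip the status the GCS observes. A sketch of that use, assuming `grpc::EnableDefaultHealthCheckService(true)` ran before the server was built; otherwise `GetHealthCheckService()` returns null:

```cpp
#include <grpcpp/health_check_service_interface.h>

#include "ray/rpc/grpc_server.h"

void MarkNotServing(ray::rpc::GrpcServer &server) {
  if (auto *health = server.GetServer().GetHealthCheckService()) {
    health->SetServingStatus(false);  // subsequent checks report NOT_SERVING
  }
}
```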