diff --git a/BUILD.bazel b/BUILD.bazel index 7f0884f348aa..87e99d3397a8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -2631,6 +2631,7 @@ cc_test( copts = COPTS, tags = ["team:core"], deps = [ + ":grpc_common_lib", ":ray_common", ":ray_mock", "@com_google_googletest//:gtest", diff --git a/release/nightly_tests/many_nodes_tests/actor_test.py b/release/nightly_tests/many_nodes_tests/actor_test.py index 95c7475ae481..bc521fca09bb 100644 --- a/release/nightly_tests/many_nodes_tests/actor_test.py +++ b/release/nightly_tests/many_nodes_tests/actor_test.py @@ -16,11 +16,6 @@ def foo(self): return actors -def test_actor_ready(actors): - remaining = [actor.foo.remote() for actor in actors] - ray.get(remaining) - - def parse_script_args(): parser = argparse.ArgumentParser() parser.add_argument("--cpus-per-actor", type=float, default=0.2) @@ -43,7 +38,15 @@ def main(): sleep(10) return actor_ready_start = perf_counter() - test_actor_ready(actors) + total_actors = len(actors) + objs = [actor.foo.remote() for actor in actors] + + while len(objs) != 0: + objs_ready, objs = ray.wait(objs, timeout=10) + print( + f"Status: {total_actors - len(objs)}/{total_actors}, " + f"{perf_counter() - actor_ready_start}" + ) actor_ready_end = perf_counter() actor_ready_time = actor_ready_end - actor_ready_start diff --git a/src/mock/ray/common/ray_syncer/ray_syncer.h b/src/mock/ray/common/ray_syncer/ray_syncer.h index 0f768dab64c3..2ef430420697 100644 --- a/src/mock/ray/common/ray_syncer/ray_syncer.h +++ b/src/mock/ray/common/ray_syncer/ray_syncer.h @@ -43,10 +43,24 @@ class MockReceiverInterface : public ReceiverInterface { namespace ray { namespace syncer { -class MockNodeSyncConnection : public NodeSyncConnection { +class MockRaySyncerBidiReactor : public RaySyncerBidiReactor { public: - using NodeSyncConnection::NodeSyncConnection; - MOCK_METHOD(void, DoSend, (), (override)); + using RaySyncerBidiReactor::RaySyncerBidiReactor; + + MOCK_METHOD(void, Disconnect, (), (override)); + + MOCK_METHOD(bool, + PushToSendingQueue, + (std::shared_ptr), + (override)); +}; + +template +class MockRaySyncerBidiReactorBase : public RaySyncerBidiReactorBase { + public: + using RaySyncerBidiReactorBase::RaySyncerBidiReactorBase; + + MOCK_METHOD(void, Disconnect, (), (override)); }; } // namespace syncer diff --git a/src/ray/common/id.cc b/src/ray/common/id.cc index a7c51ccce55a..444770769568 100644 --- a/src/ray/common/id.cc +++ b/src/ray/common/id.cc @@ -337,4 +337,7 @@ ID_OSTREAM_OPERATOR(ActorID); ID_OSTREAM_OPERATOR(TaskID); ID_OSTREAM_OPERATOR(ObjectID); ID_OSTREAM_OPERATOR(PlacementGroupID); + +const NodeID kGCSNodeID = NodeID::FromBinary(std::string(kUniqueIDSize, 0)); + } // namespace ray diff --git a/src/ray/common/id.h b/src/ray/common/id.h index efe6c8ed9774..a6c753a1de35 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -561,3 +561,7 @@ DEFINE_UNIQUE_ID(PlacementGroupID); #undef DEFINE_UNIQUE_ID } // namespace std + +namespace ray { +extern const NodeID kGCSNodeID; +} diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index fd01f2bb0471..eb93e48f520d 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -686,6 +686,10 @@ RAY_CONFIG(int64_t, grpc_client_keepalive_time_ms, 300000) /// grpc keepalive timeout for client. RAY_CONFIG(int64_t, grpc_client_keepalive_timeout_ms, 120000) +/// grpc streaming buffer size +/// Set it to 512kb +RAY_CONFIG(int64_t, grpc_stream_buffer_size, 512 * 1024); + /// Whether to use log reporter in event framework RAY_CONFIG(bool, event_log_reporter_enabled, false) diff --git a/src/ray/common/ray_syncer/ray_syncer-inl.h b/src/ray/common/ray_syncer/ray_syncer-inl.h index 506ab14b6d9f..7726bc3da654 100644 --- a/src/ray/common/ray_syncer/ray_syncer-inl.h +++ b/src/ray/common/ray_syncer/ray_syncer-inl.h @@ -79,12 +79,48 @@ class NodeState { cluster_view_; }; -class NodeSyncConnection { +/// This is the base class for the bidi-streaming call and defined the method +/// needed. A bidi-stream for ray syncer needs to support pushing message and +/// disconnect from the remote node. +/// For the implementation, in the constructor, it needs to connect to the remote +/// node and it needs to implement the communication between the two nodes. +/// +/// Please refer to https://github.com/grpc/proposal/blob/master/L67-cpp-callback-api.md +/// for the callback API +/// +// clang-format off +/// For the server side: +/// grpc end (error or request) +/// +---------------------------------------------------------------+ +/// | v +/// +------------+ +-------------+ canceled by client +----------+ +--------+ +--------+ +/// | StartRead | <--> | OnReadDone | -----------------------------> | OnCancel | --> | Finish | --> | OnDone | +/// +------------+ +-------------+ +----------+ +--------+ +--------+ +/// canceled by client ^ ^ +/// +----------------------------------------------+ | +/// | | +/// +------------+ +-------------+ grpc end (error or request) | +/// | StartWrite | <--> | OnWriteDone | --------------------------------------------------+ +/// +------------+ +-------------+ +/// +/// +/// For the client side: +/// +------------+ +-------------+ +------------+ gRPC error or disconnected +--------+ +/// | StartCall | ---> | StartRead | <---> | OnReadDone | ----------------------------> | OnDone | +/// +------------+ +-------------+ +------------+ +--------+ +/// | ^ +/// | | +/// v | +/// +------------+ +-------------+ gRPC error or disconnected | +/// | StartWrite | <--> | OnWriteDone | -------------------------------------------------------+ +/// +------------+ +-------------+ +// clang-format on +class RaySyncerBidiReactor { public: - NodeSyncConnection( - instrumented_io_context &io_context, - std::string remote_node_id, - std::function)> message_processor); + RaySyncerBidiReactor(const std::string &remote_node_id) + : remote_node_id_(remote_node_id) {} + + virtual ~RaySyncerBidiReactor(){}; /// Push a message to the sending queue to be sent later. Some message /// might be dropped if the module think the target node has already got the @@ -94,38 +130,192 @@ class NodeSyncConnection { /// \param message The message to be sent. /// /// \return true if push to queue successfully. - bool PushToSendingQueue(std::shared_ptr message); - - /// Send the message queued. - virtual void DoSend() = 0; - - virtual ~NodeSyncConnection() {} + virtual bool PushToSendingQueue(std::shared_ptr message) = 0; /// Return the remote node id of this connection. const std::string &GetRemoteNodeID() const { return remote_node_id_; } + /// Disconnect will terminate the communication between local and remote node. + /// It also needs to do proper cleanup. + virtual void Disconnect() = 0; + + private: + std::string remote_node_id_; +}; + +/// This class implements the communication between two nodes except the initialization +/// and cleanup. +/// It keeps track of the message received and sent between two nodes and uses that to +/// deduplicate the messages. It also supports the batching for performance purposes. +template +class RaySyncerBidiReactorBase : public RaySyncerBidiReactor, public T { + public: + /// Constructor of RaySyncerBidiReactor. + /// + /// \param io_context The io context for the callback. + /// \param remote_node_id The node id connects to. + /// \param message_processor The callback for the message received. + /// \param cleanup_cb When the connection terminates, it'll be called to cleanup + /// the environment. + RaySyncerBidiReactorBase( + instrumented_io_context &io_context, + const std::string &remote_node_id, + std::function)> message_processor) + : RaySyncerBidiReactor(remote_node_id), + io_context_(io_context), + message_processor_(std::move(message_processor)) {} + + bool PushToSendingQueue(std::shared_ptr message) override { + // Try to filter out the messages the target node already has. + // Usually it'll be the case when the message is generated from the + // target node or it's sent from the target node. + // No need to resend the message sent from a node back. + if (message->node_id() == GetRemoteNodeID()) { + // Skip the message when it's about the node of this connection. + return false; + } + + auto &node_versions = GetNodeComponentVersions(message->node_id()); + if (node_versions[message->message_type()] < message->version()) { + node_versions[message->message_type()] = message->version(); + sending_buffer_[std::make_pair(message->node_id(), message->message_type())] = + std::move(message); + StartSend(); + return true; + } + return false; + } + + virtual ~RaySyncerBidiReactorBase() {} + + void StartPull() { + receiving_message_ = std::make_shared(); + RAY_LOG(DEBUG) << "Start reading: " << NodeID::FromBinary(GetRemoteNodeID()); + StartRead(receiving_message_.get()); + } + + protected: + /// The io context + instrumented_io_context &io_context_; + + private: /// Handle the udpates sent from the remote node. /// /// \param messages The message received. - void ReceiveUpdate(RaySyncMessages messages); + void ReceiveUpdate(std::shared_ptr message) { + auto &node_versions = GetNodeComponentVersions(message->node_id()); + RAY_LOG(DEBUG) << "Receive update: " + << " message_type=" << message->message_type() + << ", message_version=" << message->version() + << ", local_message_version=" + << node_versions[message->message_type()]; + if (node_versions[message->message_type()] < message->version()) { + node_versions[message->message_type()] = message->version(); + message_processor_(message); + } else { + RAY_LOG_EVERY_N(WARNING, 100) + << "Drop message received from " << NodeID::FromBinary(message->node_id()) + << " because the message version " << message->version() + << " is older than the local version " << node_versions[message->message_type()] + << ". Message type: " << message->message_type(); + } + } + + void SendNext() { + sending_ = false; + StartSend(); + } + + void StartSend() { + if (sending_) { + return; + } + + if (sending_buffer_.size() != 0) { + auto iter = sending_buffer_.begin(); + auto msg = std::move(iter->second); + sending_buffer_.erase(iter); + Send(std::move(msg), sending_buffer_.empty()); + sending_ = true; + } + } + + /// Sending a message to the remote node + /// + /// \param message The message to be sent + /// \param flush Whether to flush the sending queue in gRPC. + void Send(std::shared_ptr message, bool flush) { + sending_message_ = std::move(message); + grpc::WriteOptions opts; + if (flush) { + opts.clear_buffer_hint(); + } else { + opts.set_buffer_hint(); + } + RAY_LOG(DEBUG) << "[BidiReactor] Sending message to " + << NodeID::FromBinary(GetRemoteNodeID()) << " about node " + << NodeID::FromBinary(sending_message_->node_id()); + StartWrite(sending_message_.get(), opts); + } + + // Please refer to grpc callback api for the following four methods: + // https://github.com/grpc/proposal/blob/master/L67-cpp-callback-api.md + using T::StartRead; + using T::StartWrite; + + void OnWriteDone(bool ok) override { + if (ok) { + io_context_.dispatch([this]() { SendNext(); }, ""); + } else { + // No need to resent the message since if ok=false, it's the end + // of gRPC call and client will reconnect in case of a failure. + // In gRPC, OnDone will be called after. + RAY_LOG_EVERY_N(ERROR, 100) + << "Failed to send the message to: " << NodeID::FromBinary(GetRemoteNodeID()); + } + } + + void OnReadDone(bool ok) override { + if (ok) { + io_context_.dispatch( + [this, msg = std::move(receiving_message_)]() mutable { + RAY_CHECK(!msg->node_id().empty()); + ReceiveUpdate(std::move(msg)); + StartPull(); + }, + ""); + } else { + // No need to resent the message since if ok=false, it's the end + // of gRPC call and client will reconnect in case of a failure. + // In gRPC, OnDone will be called after. + RAY_LOG_EVERY_N(ERROR, 100) + << "Failed to read the message from: " << NodeID::FromBinary(GetRemoteNodeID()); + } + } + + /// grpc requests for sending and receiving + std::shared_ptr sending_message_; + std::shared_ptr receiving_message_; - protected: // For testing - FRIEND_TEST(RaySyncerTest, NodeSyncConnection); + FRIEND_TEST(RaySyncerTest, RaySyncerBidiReactorBase); friend struct SyncerServerTest; std::array &GetNodeComponentVersions( - const std::string &node_id); - - /// The io context - instrumented_io_context &io_context_; - - /// The remote node id. - std::string remote_node_id_; + const std::string &node_id) { + auto iter = node_versions_.find(node_id); + if (iter == node_versions_.end()) { + iter = node_versions_.emplace(node_id, std::array()) + .first; + iter->second.fill(-1); + } + return iter->second; + } /// Handler of a message update. - std::function)> message_processor_; + const std::function)> message_processor_; + private: /// Buffering all the updates. Sending will be done in an async way. absl::flat_hash_map, std::shared_ptr> @@ -136,58 +326,63 @@ class NodeSyncConnection { /// We'll filter the received or sent messages when the message is stale. absl::flat_hash_map> node_versions_; + + bool sending_ = false; }; -/// SyncConnection for gRPC server side. It has customized logic for sending. -class ServerSyncConnection : public NodeSyncConnection { +/// Reactor for gRPC server side. It defines the server's specific behavior for a +/// streaming call. +class RayServerBidiReactor : public RaySyncerBidiReactorBase { public: - ServerSyncConnection( + RayServerBidiReactor( + grpc::CallbackServerContext *server_context, instrumented_io_context &io_context, - const std::string &remote_node_id, - std::function)> message_processor); + const std::string &local_node_id, + std::function)> message_processor, + std::function cleanup_cb); - ~ServerSyncConnection() override; + ~RayServerBidiReactor() override = default; - void HandleLongPollingRequest(grpc::ServerUnaryReactor *reactor, - RaySyncMessages *response); + void Disconnect() override; - protected: - /// Send the message from the pending queue to the target node. - /// It'll send nothing unless there is a long-polling request. - /// TODO (iycheng): Unify the sending algorithm when we migrate to gRPC streaming - void DoSend() override; - - /// These two fields are RPC related. When the server got long-polling requests, - /// these two fields will be set so that it can be used to send message. - /// After the message being sent, these two fields will be set to be empty again. - /// When the periodical timer wake up, it'll check whether these two fields are set - /// and it'll only send data when these are set. - std::vector responses_; - std::vector unary_reactors_; + private: + void OnCancel() override; + void OnDone() override; + + /// Cleanup callback when the call ends. + const std::function cleanup_cb_; + + /// grpc callback context + grpc::CallbackServerContext *server_context_; }; -/// SyncConnection for gRPC client side. It has customized logic for sending. -class ClientSyncConnection : public NodeSyncConnection { +/// Reactor for gRPC client side. It defines the client's specific behavior for a +/// streaming call. +class RayClientBidiReactor : public RaySyncerBidiReactorBase { public: - ClientSyncConnection( + RayClientBidiReactor( + const std::string &remote_node_id, + const std::string &local_node_id, instrumented_io_context &io_context, - const std::string &node_id, - std::function)> message_processor, - std::shared_ptr channel); + std::function)> message_processor, + std::function cleanup_cb, + std::unique_ptr stub); - protected: - /// Send the message from the pending queue to the target node. - /// It'll use gRPC to send the message directly. - void DoSend() override; + ~RayClientBidiReactor() override = default; - /// Start to send long-polling request to remote nodes. - void StartLongPolling(); + void Disconnect() override; - /// Stub for this connection. - std::unique_ptr stub_; + private: + /// Callback from gRPC + void OnDone(const grpc::Status &status) override; - /// Dummy request for long-polling. - DummyRequest dummy_; + /// Cleanup callback when the call ends. + const std::function cleanup_cb_; + + /// grpc callback context + grpc::ClientContext client_context_; + + std::unique_ptr stub_; }; } // namespace syncer diff --git a/src/ray/common/ray_syncer/ray_syncer.cc b/src/ray/common/ray_syncer/ray_syncer.cc index 86afc1625146..7dd8420e78a9 100644 --- a/src/ray/common/ray_syncer/ray_syncer.cc +++ b/src/ray/common/ray_syncer/ray_syncer.cc @@ -55,9 +55,10 @@ std::optional NodeState::CreateSyncMessage(MessageType message_t bool NodeState::ConsumeSyncMessage(std::shared_ptr message) { auto ¤t = cluster_view_[message->node_id()][message->message_type()]; - RAY_LOG(DEBUG) << "ConsumeSyncMessage: " << (current ? current->version() : -1) - << " message_version: " << message->version() - << ", message_from: " << NodeID::FromBinary(message->node_id()); + RAY_LOG(DEBUG) << "ConsumeSyncMessage: local_version=" + << (current ? current->version() : -1) + << " message_version=" << message->version() + << ", message_from=" << NodeID::FromBinary(message->node_id()); // Check whether newer version of this message has been received. if (current && current->version() >= message->version()) { return false; @@ -66,180 +67,85 @@ bool NodeState::ConsumeSyncMessage(std::shared_ptr message current = message; auto receiver = receivers_[message->message_type()]; if (receiver != nullptr) { + RAY_LOG(DEBUG) << "Consume message from: " << NodeID::FromBinary(message->node_id()); receiver->ConsumeSyncMessage(message); } return true; } -NodeSyncConnection::NodeSyncConnection( - instrumented_io_context &io_context, - std::string remote_node_id, - std::function)> message_processor) - : io_context_(io_context), - remote_node_id_(std::move(remote_node_id)), - message_processor_(std::move(message_processor)) {} - -void NodeSyncConnection::ReceiveUpdate(RaySyncMessages messages) { - for (auto &message : *messages.mutable_sync_messages()) { - auto &node_versions = GetNodeComponentVersions(message.node_id()); - RAY_LOG(DEBUG) << "Receive update: " - << " message_type=" << message.message_type() - << ", message_version=" << message.version() - << ", local_message_version=" << node_versions[message.message_type()]; - if (node_versions[message.message_type()] < message.version()) { - node_versions[message.message_type()] = message.version(); - message_processor_(std::make_shared(std::move(message))); - } - } -} +namespace { -bool NodeSyncConnection::PushToSendingQueue( - std::shared_ptr message) { - // Try to filter out the messages the target node already has. - // Usually it'll be the case when the message is generated from the - // target node or it's sent from the target node. - if (message->node_id() == GetRemoteNodeID()) { - // Skip the message when it's about the node of this connection. - return false; - } - - auto &node_versions = GetNodeComponentVersions(message->node_id()); - if (node_versions[message->message_type()] < message->version()) { - node_versions[message->message_type()] = message->version(); - sending_buffer_[std::make_pair(message->node_id(), message->message_type())] = - message; - return true; - } - return false; +std::string GetNodeIDFromServerContext(grpc::CallbackServerContext *server_context) { + const auto &metadata = server_context->client_metadata(); + auto iter = metadata.find("node_id"); + RAY_CHECK(iter != metadata.end()); + return NodeID::FromHex(std::string(iter->second.begin(), iter->second.end())).Binary(); } -std::array &NodeSyncConnection::GetNodeComponentVersions( - const std::string &node_id) { - auto iter = node_versions_.find(node_id); - if (iter == node_versions_.end()) { - iter = - node_versions_.emplace(node_id, std::array()).first; - iter->second.fill(-1); - } - return iter->second; -} +} // namespace -ClientSyncConnection::ClientSyncConnection( +RayServerBidiReactor::RayServerBidiReactor( + grpc::CallbackServerContext *server_context, instrumented_io_context &io_context, - const std::string &node_id, - std::function)> message_processor, - std::shared_ptr channel) - : NodeSyncConnection(io_context, node_id, std::move(message_processor)), - stub_(ray::rpc::syncer::RaySyncer::NewStub(channel)) { - for (int64_t i = 0; i < RayConfig::instance().ray_syncer_polling_buffer(); ++i) { - StartLongPolling(); - } + const std::string &local_node_id, + std::function)> message_processor, + std::function cleanup_cb) + : RaySyncerBidiReactorBase( + io_context, + GetNodeIDFromServerContext(server_context), + std::move(message_processor)), + cleanup_cb_(std::move(cleanup_cb)), + server_context_(server_context) { + // Send the local node id to the remote + server_context_->AddInitialMetadata("node_id", NodeID::FromBinary(local_node_id).Hex()); + StartSendInitialMetadata(); + + // Start pulling from remote + StartPull(); } -void ClientSyncConnection::StartLongPolling() { - // This will be a long-polling request. The node will only reply if - // 1. there is a new version of message - // 2. and it has passed X ms since last update. - auto client_context = std::make_shared(); - auto in_message = std::make_shared(); - stub_->async()->LongPolling( - client_context.get(), - &dummy_, - in_message.get(), - [this, client_context, in_message](grpc::Status status) mutable { - if (status.ok()) { - io_context_.dispatch( - [this, messages = std::move(*in_message)]() mutable { - ReceiveUpdate(std::move(messages)); - }, - "LongPollingCallback"); - // Start the next polling. - StartLongPolling(); - } - }); +void RayServerBidiReactor::Disconnect() { + io_context_.dispatch([this]() { Finish(grpc::Status::OK); }, ""); } -void ClientSyncConnection::DoSend() { - if (sending_buffer_.empty()) { - return; - } +void RayServerBidiReactor::OnCancel() { Disconnect(); } - auto client_context = std::make_shared(); - auto arena = std::make_shared(); - auto request = google::protobuf::Arena::CreateMessage(arena.get()); - auto response = google::protobuf::Arena::CreateMessage(arena.get()); - - std::vector> holder; - - size_t message_bytes = 0; - auto iter = sending_buffer_.begin(); - while (message_bytes < RayConfig::instance().max_sync_message_batch_bytes() && - iter != sending_buffer_.end()) { - message_bytes += iter->second->sync_message().size(); - // TODO (iycheng): Use arena allocator for optimization - request->mutable_sync_messages()->UnsafeArenaAddAllocated( - const_cast(iter->second.get())); - holder.push_back(iter->second); - sending_buffer_.erase(iter++); - } - if (request->sync_messages_size() != 0) { - stub_->async()->Update( - client_context.get(), - request, - response, - [arena, client_context, holder = std::move(holder)](grpc::Status status) { - if (!status.ok()) { - RAY_LOG(ERROR) << "Sending request failed because of " - << status.error_message(); - } - }); - } +void RayServerBidiReactor::OnDone() { + io_context_.dispatch( + [this]() { + cleanup_cb_(GetRemoteNodeID(), false); + delete this; + }, + ""); } -ServerSyncConnection::ServerSyncConnection( - instrumented_io_context &io_context, +RayClientBidiReactor::RayClientBidiReactor( const std::string &remote_node_id, - std::function)> message_processor) - : NodeSyncConnection(io_context, remote_node_id, std::move(message_processor)) {} - -ServerSyncConnection::~ServerSyncConnection() { - // If there is a pending request, we need to cancel it. Otherwise, rpc will - // hang there forever. - while (!unary_reactors_.empty()) { - unary_reactors_.back()->Finish(grpc::Status::CANCELLED); - unary_reactors_.pop_back(); - } + const std::string &local_node_id, + instrumented_io_context &io_context, + std::function)> message_processor, + std::function cleanup_cb, + std::unique_ptr stub) + : RaySyncerBidiReactorBase( + io_context, remote_node_id, std::move(message_processor)), + cleanup_cb_(std::move(cleanup_cb)), + stub_(std::move(stub)) { + client_context_.AddMetadata("node_id", NodeID::FromBinary(local_node_id).Hex()); + stub_->async()->StartSync(&client_context_, this); + StartPull(); } -void ServerSyncConnection::HandleLongPollingRequest(grpc::ServerUnaryReactor *reactor, - RaySyncMessages *response) { - unary_reactors_.push_back(reactor); - responses_.push_back(response); +void RayClientBidiReactor::OnDone(const grpc::Status &status) { + io_context_.dispatch( + [this, status]() { + cleanup_cb_(GetRemoteNodeID(), !status.ok()); + delete this; + }, + ""); } -void ServerSyncConnection::DoSend() { - // There is no receive request - if (unary_reactors_.empty() || sending_buffer_.empty()) { - return; - } - - RAY_CHECK(!responses_.empty()); - - size_t message_bytes = 0; - auto iter = sending_buffer_.begin(); - while (message_bytes < RayConfig::instance().max_sync_message_batch_bytes() && - iter != sending_buffer_.end()) { - message_bytes += iter->second->sync_message().size(); - // TODO (iycheng): Use arena allocator for optimization - responses_.back()->add_sync_messages()->CopyFrom(*iter->second); - sending_buffer_.erase(iter++); - } - - if (responses_.back()->sync_messages_size() != 0) { - unary_reactors_.back()->Finish(grpc::Status::OK); - responses_.pop_back(); - unary_reactors_.pop_back(); - } +void RayClientBidiReactor::Disconnect() { + io_context_.dispatch([this]() { StartWritesDone(); }, ""); } RaySyncer::RaySyncer(instrumented_io_context &io_context, @@ -249,109 +155,133 @@ RaySyncer::RaySyncer(instrumented_io_context &io_context, node_state_(std::make_unique()), timer_(io_context) { stopped_ = std::make_shared(false); - timer_.RunFnPeriodically( - [this]() { - for (auto &[_, sync_connection] : sync_connections_) { - sync_connection->DoSend(); - } - }, - RayConfig::instance().raylet_report_resources_period_milliseconds()); } RaySyncer::~RaySyncer() { *stopped_ = true; - for (auto &call : inflight_requests_) { - auto f = call->promise.get_future(); - if (!f.valid()) { - call->context.TryCancel(); - } - f.get(); - } + io_context_.dispatch( + [reactors = sync_reactors_]() { + for (auto [_, reactor] : reactors) { + reactor->Disconnect(); + } + }, + ""); } -void RaySyncer::Connect(std::shared_ptr channel) { - auto call = std::make_unique(); - - auto stub = ray::rpc::syncer::RaySyncer::NewStub(channel); - call->request.set_node_id(local_node_id_); - - stub->async()->StartSync( - &call->context, - &call->request, - &call->response, - [this, channel, call = call.get(), stopped = this->stopped_](grpc::Status status) { - call->promise.set_value(); - if (*stopped) { - return; - } - if (status.ok()) { - io_context_.dispatch( - [this, channel, node_id = call->response.node_id()]() { - auto connection = std::make_unique( - io_context_, - node_id, - [this](auto msg) { BroadcastMessage(msg); }, - channel); - Connect(std::move(connection)); - }, - "StartSyncCallback"); +std::vector RaySyncer::GetAllConnectedNodeIDs() const { + std::promise> promise; + io_context_.dispatch( + [&]() { + std::vector nodes; + for (auto [node_id, _] : sync_reactors_) { + nodes.push_back(node_id); } - }); - inflight_requests_.emplace(std::move(call)); + promise.set_value(std::move(nodes)); + }, + ""); + return promise.get_future().get(); +} + +void RaySyncer::Connect(const std::string &node_id, + std::shared_ptr channel) { + io_context_.dispatch( + [=]() { + auto stub = ray::rpc::syncer::RaySyncer::NewStub(channel); + auto reactor = new RayClientBidiReactor( + /* remote_node_id */ node_id, + /* local_node_id */ GetLocalNodeID(), + /* io_context */ io_context_, + /* message_processor */ [this](auto msg) { BroadcastRaySyncMessage(msg); }, + /* cleanup_cb */ + [this, channel](const std::string &node_id, bool restart) { + sync_reactors_.erase(node_id); + if (restart) { + RAY_LOG(INFO) << "Connection is broken. Reconnect to node: " + << NodeID::FromBinary(node_id); + Connect(node_id, channel); + } + }, + /* stub */ std::move(stub)); + Connect(reactor); + reactor->StartCall(); + }, + ""); } -void RaySyncer::Connect(std::unique_ptr connection) { - // Somehow connection=std::move(connection) won't be compiled here. - // Potentially it might have a leak here if the function is not executed. +void RaySyncer::Connect(RaySyncerBidiReactor *reactor) { io_context_.dispatch( - [this, connection = connection.release()]() mutable { - RAY_CHECK(connection != nullptr); - RAY_CHECK(sync_connections_[connection->GetRemoteNodeID()] == nullptr); - auto &conn = *connection; - sync_connections_[connection->GetRemoteNodeID()].reset(connection); + [this, reactor]() { + RAY_CHECK(sync_reactors_.find(reactor->GetRemoteNodeID()) == + sync_reactors_.end()); + sync_reactors_[reactor->GetRemoteNodeID()] = reactor; + // Send the view for new connections. for (const auto &[_, messages] : node_state_->GetClusterView()) { - for (auto &message : messages) { + for (const auto &message : messages) { if (!message) { continue; } - conn.PushToSendingQueue(message); + RAY_LOG(DEBUG) << "Push init view from: " + << NodeID::FromBinary(GetLocalNodeID()) << " to " + << NodeID::FromBinary(reactor->GetRemoteNodeID()) << " about " + << NodeID::FromBinary(message->node_id()); + reactor->PushToSendingQueue(message); } } }, - "RaySyncer::Connect"); + "RaySyncerConnect"); } void RaySyncer::Disconnect(const std::string &node_id) { + std::promise promise; io_context_.dispatch( - [this, node_id]() { - auto iter = sync_connections_.find(node_id); - if (iter != sync_connections_.end()) { - sync_connections_.erase(iter); + [&]() { + auto iter = sync_reactors_.find(node_id); + if (iter == sync_reactors_.end()) { + promise.set_value(nullptr); + return; + } + + auto reactor = iter->second; + if (iter != sync_reactors_.end()) { + sync_reactors_.erase(iter); } + promise.set_value(reactor); }, "RaySyncerDisconnect"); + auto reactor = promise.get_future().get(); + if (reactor != nullptr) { + reactor->Disconnect(); + } } -bool RaySyncer::Register(MessageType message_type, +void RaySyncer::Register(MessageType message_type, const ReporterInterface *reporter, ReceiverInterface *receiver, int64_t pull_from_reporter_interval_ms) { - if (!node_state_->SetComponent(message_type, reporter, receiver)) { - return false; - } + io_context_.dispatch( + [this, message_type, reporter, receiver, pull_from_reporter_interval_ms]() mutable { + if (!node_state_->SetComponent(message_type, reporter, receiver)) { + return; + } - // Set job to pull from reporter periodically - if (reporter != nullptr && pull_from_reporter_interval_ms > 0) { - timer_.RunFnPeriodically( - [this, message_type]() { OnDemandBroadcasting(message_type); }, - pull_from_reporter_interval_ms); - } + // Set job to pull from reporter periodically + if (reporter != nullptr && pull_from_reporter_interval_ms > 0) { + timer_.RunFnPeriodically( + [this, stopped = stopped_, message_type]() { + if (*stopped) { + return; + } + OnDemandBroadcasting(message_type); + }, + pull_from_reporter_interval_ms); + } - RAY_LOG(DEBUG) << "Registered components: " - << "message_type:" << message_type << ", reporter:" << reporter - << ", receiver:" << receiver - << ", pull_from_reporter_interval_ms:" << pull_from_reporter_interval_ms; - return true; + RAY_LOG(DEBUG) << "Registered components: " + << "message_type:" << message_type << ", reporter:" << reporter + << ", receiver:" << receiver << ", pull_from_reporter_interval_ms:" + << pull_from_reporter_interval_ms; + }, + "RaySyncerRegister"); } bool RaySyncer::OnDemandBroadcasting(MessageType message_type) { @@ -372,90 +302,38 @@ void RaySyncer::BroadcastMessage(std::shared_ptr message) io_context_.dispatch( [this, message] { // The message is stale. Just skip this one. + RAY_LOG(DEBUG) << "Receive message from: " + << NodeID::FromBinary(message->node_id()) << " to " + << NodeID::FromBinary(GetLocalNodeID()); if (!node_state_->ConsumeSyncMessage(message)) { return; } - for (auto &connection : sync_connections_) { - connection.second->PushToSendingQueue(message); + for (auto &reactor : sync_reactors_) { + reactor.second->PushToSendingQueue(message); } }, "RaySyncer.BroadcastMessage"); } -grpc::ServerUnaryReactor *RaySyncerService::StartSync( - grpc::CallbackServerContext *context, - const StartSyncRequest *request, - StartSyncResponse *response) { - auto *reactor = context->DefaultReactor(); - // Make sure server only have one client - if (!remote_node_id_.empty()) { - RAY_LOG(WARNING) << "Get a new sync request from " - << NodeID::FromBinary(request->node_id()) << ". " - << "Now disconnect from " << NodeID::FromBinary(remote_node_id_); - syncer_.Disconnect(remote_node_id_); - } - remote_node_id_ = request->node_id(); - RAY_LOG(DEBUG) << "Get connect from: " << NodeID::FromBinary(remote_node_id_); - syncer_.GetIOContext().dispatch( - [this, response, reactor, context]() { - if (context->IsCancelled()) { - reactor->Finish(grpc::Status::CANCELLED); - return; - } - - syncer_.Connect(std::make_unique( - syncer_.GetIOContext(), remote_node_id_, [this](auto msg) { - syncer_.BroadcastMessage(msg); - })); - response->set_node_id(syncer_.GetLocalNodeID()); - reactor->Finish(grpc::Status::OK); - }, - "RaySyncer::StartSync"); - return reactor; -} - -grpc::ServerUnaryReactor *RaySyncerService::Update(grpc::CallbackServerContext *context, - const RaySyncMessages *request, - DummyResponse *) { - auto *reactor = context->DefaultReactor(); - // Make sure request is allocated from heap so that it can be moved safely. - RAY_CHECK(request->GetArena() == nullptr); - syncer_.GetIOContext().dispatch( - [this, request = std::move(*const_cast(request))]() mutable { - auto *sync_connection = dynamic_cast( - syncer_.GetSyncConnection(remote_node_id_)); - if (sync_connection != nullptr) { - sync_connection->ReceiveUpdate(std::move(request)); - } else { - RAY_LOG(FATAL) << "Fail to get the sync context"; - } - }, - "SyncerUpdate"); - reactor->Finish(grpc::Status::OK); - return reactor; -} - -grpc::ServerUnaryReactor *RaySyncerService::LongPolling( - grpc::CallbackServerContext *context, - const DummyRequest *, - RaySyncMessages *response) { - auto *reactor = context->DefaultReactor(); - syncer_.GetIOContext().dispatch( - [this, reactor, response]() mutable { - auto *sync_connection = dynamic_cast( - syncer_.GetSyncConnection(remote_node_id_)); - if (sync_connection != nullptr) { - sync_connection->HandleLongPollingRequest(reactor, response); - } else { - RAY_LOG(ERROR) << "Fail to setup long-polling"; - reactor->Finish(grpc::Status::CANCELLED); - } - }, - "SyncLongPolling"); +ServerBidiReactor *RaySyncerService::StartSync(grpc::CallbackServerContext *context) { + auto reactor = new RayServerBidiReactor( + context, + syncer_.GetIOContext(), + syncer_.GetLocalNodeID(), + [this](auto msg) mutable { syncer_.BroadcastMessage(msg); }, + [this](const std::string &node_id, bool reconnect) mutable { + // No need to reconnect for server side. + RAY_CHECK(!reconnect); + syncer_.sync_reactors_.erase(node_id); + }); + RAY_LOG(DEBUG) << "Get connection from " + << NodeID::FromBinary(reactor->GetRemoteNodeID()) << " to " + << NodeID::FromBinary(syncer_.GetLocalNodeID()); + syncer_.Connect(reactor); return reactor; } -RaySyncerService::~RaySyncerService() { syncer_.Disconnect(remote_node_id_); } +RaySyncerService::~RaySyncerService() {} } // namespace syncer } // namespace ray diff --git a/src/ray/common/ray_syncer/ray_syncer.h b/src/ray/common/ray_syncer/ray_syncer.h index 07c636b6cc29..abe835c60e0a 100644 --- a/src/ray/common/ray_syncer/ray_syncer.h +++ b/src/ray/common/ray_syncer/ray_syncer.h @@ -27,13 +27,11 @@ namespace ray { namespace syncer { -using ray::rpc::syncer::DummyRequest; -using ray::rpc::syncer::DummyResponse; using ray::rpc::syncer::MessageType; using ray::rpc::syncer::RaySyncMessage; -using ray::rpc::syncer::RaySyncMessages; -using ray::rpc::syncer::StartSyncRequest; -using ray::rpc::syncer::StartSyncResponse; + +using ServerBidiReactor = grpc::ServerBidiReactor; +using ClientBidiReactor = grpc::ClientBidiReactor; static constexpr size_t kComponentArraySize = static_cast(ray::rpc::syncer::MessageType_ARRAYSIZE); @@ -73,7 +71,7 @@ struct ReceiverInterface { // Forward declaration of internal structures class NodeState; -class NodeSyncConnection; +class RaySyncerBidiReactor; /// RaySyncer is an embedding service for component synchronization. /// All operations in this class needs to be finished GetIOContext() @@ -81,9 +79,9 @@ class NodeSyncConnection; /// RaySyncer is the control plane to make sure all connections eventually /// have the latest view of the cluster components registered. /// RaySyncer has two components: -/// 1. NodeSyncConnection: keeps track of the sending and receiving information +/// 1. RaySyncerBidiReactor: keeps track of the sending and receiving information /// and make sure not sending the information the remote node knows. -/// 2. NodeState: keeps track of the local status, similar to NodeSyncConnection, +/// 2. NodeState: keeps track of the local status, similar to RaySyncerBidiReactor, // but it's for local node. class RaySyncer { public: @@ -98,15 +96,9 @@ class RaySyncer { /// TODO (iycheng): Introduce grpc channel pool and use node_id /// for the connection. /// - /// \param connection The connection to the remote node. - void Connect(std::unique_ptr connection); - - /// Connect to a node. - /// TODO (iycheng): Introduce grpc channel pool and use node_id - /// for the connection. - /// - /// \param connection The connection to the remote node. - void Connect(std::shared_ptr channel); + /// \param node_id The id of the node connect to. + /// \param channel The gRPC channel. + void Connect(const std::string &node_id, std::shared_ptr channel); void Disconnect(const std::string &node_id); @@ -121,7 +113,7 @@ class RaySyncer { /// \param pull_from_reporter_interval_ms The frequence to pull a message. 0 means /// never pull a message in syncer. /// from reporter and push it to sending queue. - bool Register(MessageType message_type, + void Register(MessageType message_type, const ReporterInterface *reporter, ReceiverInterface *receiver, int64_t pull_from_reporter_interval_ms = 100); @@ -143,24 +135,16 @@ class RaySyncer { /// \param message The message to be broadcasted. void BroadcastRaySyncMessage(std::shared_ptr message); + std::vector GetAllConnectedNodeIDs() const; + private: + void Connect(RaySyncerBidiReactor *connection); + + std::shared_ptr stopped_; + /// Get the io_context used by RaySyncer. instrumented_io_context &GetIOContext() { return io_context_; } - /// Get the SyncConnection of a node. - /// - /// \param node_id The node id to lookup. - /// - /// \return nullptr if it doesn't exist, otherwise, the connection associated with the - /// node. - NodeSyncConnection *GetSyncConnection(const std::string &node_id) const { - auto iter = sync_connections_.find(node_id); - if (iter == sync_connections_.end()) { - return nullptr; - } - return iter->second.get(); - } - /// Function to broadcast the messages to other nodes. /// A message will be sent to a node if that node doesn't have this message. /// The message can be generated by local reporter or received by the other node. @@ -175,29 +159,14 @@ class RaySyncer { const std::string local_node_id_; /// Manage connections. Here the key is the NodeID in binary form. - absl::flat_hash_map> sync_connections_; - - /// Upward connections. These are connections initialized not by the local node. - absl::flat_hash_set upward_connections_; + absl::flat_hash_map sync_reactors_; /// The local node state std::unique_ptr node_state_; - /// Context of a rpc call. - struct StartSyncCall { - StartSyncRequest request; - StartSyncResponse response; - grpc::ClientContext context; - std::promise promise; - }; - - absl::flat_hash_set> inflight_requests_; - /// Timer is used to do broadcasting. ray::PeriodicalRunner timer_; - std::shared_ptr stopped_; - friend class RaySyncerService; /// Test purpose friend struct SyncerServerTest; @@ -209,9 +178,6 @@ class RaySyncer { FRIEND_TEST(SyncerTest, Reconnect); }; -class ClientSyncConnection; -class ServerSyncConnection; - /// RaySyncerService is a service to take care of resource synchronization /// related operations. /// Right now only raylet needs to setup this service. But in the future, @@ -223,25 +189,10 @@ class RaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { ~RaySyncerService(); - grpc::ServerUnaryReactor *StartSync(grpc::CallbackServerContext *context, - const StartSyncRequest *request, - StartSyncResponse *response) override; - - grpc::ServerUnaryReactor *Update(grpc::CallbackServerContext *context, - const RaySyncMessages *request, - DummyResponse *) override; - - grpc::ServerUnaryReactor *LongPolling(grpc::CallbackServerContext *context, - const DummyRequest *, - RaySyncMessages *response) override; + grpc::ServerBidiReactor *StartSync( + grpc::CallbackServerContext *context) override; private: - // This will be created after connection is established. - // Ideally this should be owned by RaySyncer, but since we are doing - // long-polling right now, we have to put it here so that when - // long-polling request comes, we can set it up. - std::string remote_node_id_; - // The ray syncer this RPC wrappers of. RaySyncer &syncer_; }; diff --git a/src/ray/common/test/ray_syncer_test.cc b/src/ray/common/test/ray_syncer_test.cc index 8ca793737778..ff0652d22be3 100644 --- a/src/ray/common/test/ray_syncer_test.cc +++ b/src/ray/common/test/ray_syncer_test.cc @@ -27,6 +27,7 @@ #include #include "ray/common/ray_syncer/ray_syncer.h" +#include "ray/rpc/grpc_server.h" #include "mock/ray/common/ray_syncer/ray_syncer.h" // clang-format on @@ -44,6 +45,9 @@ namespace syncer { constexpr size_t kTestComponents = 1; +using work_guard_type = + boost::asio::executor_work_guard; + RaySyncMessage MakeMessage(MessageType cid, int64_t version, const NodeID &id) { auto msg = RaySyncMessage(); msg.set_version(version); @@ -55,6 +59,7 @@ RaySyncMessage MakeMessage(MessageType cid, int64_t version, const NodeID &id) { class RaySyncerTest : public ::testing::Test { protected: void SetUp() override { + work_guard_ = std::make_unique(io_context_.get_executor()); local_versions_.fill(0); for (size_t cid = 0; cid < reporters_.size(); ++cid) { receivers_[cid] = std::make_unique(); @@ -74,10 +79,7 @@ class RaySyncerTest : public ::testing::Test { ON_CALL(*reporter, CreateSyncMessage(_, _)) .WillByDefault(WithArg<0>(Invoke(take_snapshot))); } - thread_ = std::make_unique([this]() { - boost::asio::io_context::work work(io_context_); - io_context_.run(); - }); + thread_ = std::make_unique([this]() { io_context_.run(); }); local_id_ = NodeID::FromRandom(); syncer_ = std::make_unique(io_context_, local_id_.Binary()); } @@ -95,7 +97,7 @@ class RaySyncerTest : public ::testing::Test { } void TearDown() override { - io_context_.stop(); + work_guard_->reset(); thread_->join(); } @@ -106,6 +108,7 @@ class RaySyncerTest : public ::testing::Test { nullptr}; instrumented_io_context io_context_; + std::unique_ptr work_guard_; std::unique_ptr thread_; std::unique_ptr syncer_; @@ -145,40 +148,58 @@ TEST_F(RaySyncerTest, NodeStateConsume) { ASSERT_FALSE(node_status->ConsumeSyncMessage(std::make_shared(msg))); } -TEST_F(RaySyncerTest, NodeSyncConnection) { +struct MockReactor { + void StartRead(RaySyncMessage *) { ++read_cnt; } + + void StartWrite(const RaySyncMessage *, + grpc::WriteOptions opts = grpc::WriteOptions()) { + ++write_cnt; + } + + virtual void OnWriteDone(bool ok) {} + virtual void OnReadDone(bool ok) {} + + size_t read_cnt = 0; + size_t write_cnt = 0; +}; + +TEST_F(RaySyncerTest, RaySyncerBidiReactorBase) { auto node_id = NodeID::FromRandom(); - MockNodeSyncConnection sync_connection( + MockRaySyncerBidiReactorBase sync_reactor( io_context_, node_id.Binary(), - [](std::shared_ptr) {}); + [](std::shared_ptr) {}); auto from_node_id = NodeID::FromRandom(); auto msg = MakeMessage(MessageType::RESOURCE_VIEW, 0, from_node_id); + auto msg_ptr1 = std::make_shared(msg); + msg.set_version(2); + auto msg_ptr2 = std::make_shared(msg); + msg.set_version(3); + auto msg_ptr3 = std::make_shared(msg); // First push will succeed and the second one will be deduplicated. - ASSERT_TRUE(sync_connection.PushToSendingQueue(std::make_shared(msg))); - ASSERT_FALSE(sync_connection.PushToSendingQueue(std::make_shared(msg))); - ASSERT_EQ(1, sync_connection.sending_buffer_.size()); - ASSERT_EQ(0, sync_connection.sending_buffer_.begin()->second->version()); - ASSERT_EQ(1, sync_connection.node_versions_.size()); + ASSERT_TRUE(sync_reactor.PushToSendingQueue(msg_ptr1)); + ASSERT_FALSE(sync_reactor.PushToSendingQueue(msg_ptr1)); + ASSERT_EQ(0, sync_reactor.sending_buffer_.size()); + + ASSERT_TRUE(sync_reactor.PushToSendingQueue(msg_ptr2)); + ASSERT_EQ(1, sync_reactor.sending_buffer_.size()); + ASSERT_EQ(1, sync_reactor.node_versions_.size()); + ASSERT_EQ(2, sync_reactor.sending_buffer_.begin()->second->version()); ASSERT_EQ( - 0, - sync_connection.node_versions_[from_node_id.Binary()][MessageType::RESOURCE_VIEW]); + 2, sync_reactor.node_versions_[from_node_id.Binary()][MessageType::RESOURCE_VIEW]); - msg.set_version(2); - ASSERT_TRUE(sync_connection.PushToSendingQueue(std::make_shared(msg))); - ASSERT_FALSE(sync_connection.PushToSendingQueue(std::make_shared(msg))); - // The previous message is deleted. - ASSERT_EQ(1, sync_connection.sending_buffer_.size()); - ASSERT_EQ(1, sync_connection.node_versions_.size()); - ASSERT_EQ(2, sync_connection.sending_buffer_.begin()->second->version()); + ASSERT_TRUE(sync_reactor.PushToSendingQueue(msg_ptr3)); + ASSERT_EQ(1, sync_reactor.sending_buffer_.size()); + ASSERT_EQ(1, sync_reactor.node_versions_.size()); + ASSERT_EQ(3, sync_reactor.sending_buffer_.begin()->second->version()); ASSERT_EQ( - 2, - sync_connection.node_versions_[from_node_id.Binary()][MessageType::RESOURCE_VIEW]); + 3, sync_reactor.node_versions_[from_node_id.Binary()][MessageType::RESOURCE_VIEW]); } struct SyncerServerTest { - SyncerServerTest(std::string port) { + SyncerServerTest(std::string port) : work_guard(io_context.get_executor()) { this->server_port = port; // Setup io context auto node_id = NodeID::FromRandom(); @@ -187,6 +208,7 @@ struct SyncerServerTest { } // Setup syncer and grpc server syncer = std::make_unique(io_context, node_id.Binary()); + thread = std::make_unique([this] { io_context.run(); }); auto server_address = std::string("0.0.0.0:") + port; grpc::ServerBuilder builder; @@ -196,7 +218,10 @@ struct SyncerServerTest { server = builder.BuildAndStart(); for (size_t cid = 0; cid < reporters.size(); ++cid) { - auto snapshot_received = [this](std::shared_ptr message) { + auto snapshot_received = [this, + node_id](std::shared_ptr message) { + RAY_LOG(DEBUG) << "Message received: from " + << NodeID::FromBinary(message->node_id()) << " to " << node_id; auto iter = received_versions.find(message->node_id()); if (iter == received_versions.end()) { for (auto &v : received_versions[message->node_id()]) { @@ -208,6 +233,9 @@ struct SyncerServerTest { received_versions[message->node_id()][message->message_type()] = message->version(); message_consumed[message->node_id()]++; + RAY_LOG(DEBUG) << "Message consumed from " + << NodeID::FromBinary(message->node_id()) + << ", local_id=" << node_id; }; receivers[cid] = std::make_unique(); EXPECT_CALL(*receivers[cid], ConsumeSyncMessage(_)) @@ -232,10 +260,6 @@ struct SyncerServerTest { syncer->Register( static_cast(cid), reporter.get(), receivers[cid].get()); } - thread = std::make_unique([this] { - boost::asio::io_context::work work(io_context); - io_context.run(); - }); } void WaitSendingFlush() { @@ -244,13 +268,22 @@ struct SyncerServerTest { auto f = p.get_future(); io_context.post( [&p, this]() mutable { - for (const auto &[node_id, conn] : syncer->sync_connections_) { - if (!conn->sending_buffer_.empty()) { + for (const auto &[node_id, conn] : syncer->sync_reactors_) { + auto ptr = dynamic_cast(conn); + size_t remainings = 0; + if (ptr == nullptr) { + remainings = + dynamic_cast(conn)->sending_buffer_.size(); + } else { + remainings = ptr->sending_buffer_.size(); + } + + if (remainings != 0) { p.set_value(false); RAY_LOG(INFO) << NodeID::FromBinary(syncer->GetLocalNodeID()) << ": " << "Waiting for message on " << NodeID::FromBinary(node_id) << " to be sent." - << " Remainings " << conn->sending_buffer_.size(); + << " Remainings " << remainings; return; } } @@ -281,11 +314,19 @@ struct SyncerServerTest { return false; } - ~SyncerServerTest() { - service.reset(); - server.reset(); + void Stop() { + for (auto node_id : syncer->GetAllConnectedNodeIDs()) { + syncer->Disconnect(node_id); + } + + server->Shutdown(); + io_context.stop(); thread->join(); + + server.reset(); + service.reset(); + syncer.reset(); } @@ -314,7 +355,9 @@ struct SyncerServerTest { std::unique_ptr syncer; std::unique_ptr server; std::unique_ptr thread; + instrumented_io_context io_context; + work_guard_type work_guard; std::string server_port; std::array, kTestComponents> local_versions; std::array, kTestComponents> reporters = { @@ -367,38 +410,57 @@ using TClusterView = absl::flat_hash_map< std::string, std::array, kComponentArraySize>>; -TEST(SyncerTest, Test1To1) { - auto s1 = SyncerServerTest("19990"); +class SyncerTest : public ::testing::Test { + public: + SyncerServerTest &MakeServer(std::string port) { + servers.emplace_back(std::make_unique(port)); + return *servers.back(); + } + + protected: + void TearDown() override { + // Drain all grpc requests. + for (auto &s : servers) { + s->Stop(); + } + + std::this_thread::sleep_for(1s); + } + std::vector> servers; +}; + +TEST_F(SyncerTest, Test1To1) { + auto &s1 = MakeServer("19990"); - auto s2 = SyncerServerTest("19991"); + auto &s2 = MakeServer("19991"); // Make sure the setup is correct ASSERT_NE(nullptr, s1.receivers[MessageType::RESOURCE_VIEW]); ASSERT_NE(nullptr, s2.receivers[MessageType::RESOURCE_VIEW]); ASSERT_NE(nullptr, s1.reporters[MessageType::RESOURCE_VIEW]); ASSERT_NE(nullptr, s2.reporters[MessageType::RESOURCE_VIEW]); + RAY_LOG(DEBUG) << "s1: " << NodeID::FromBinary(s1.syncer->GetLocalNodeID()); + RAY_LOG(DEBUG) << "s2: " << NodeID::FromBinary(s2.syncer->GetLocalNodeID()); auto channel_to_s2 = MakeChannel("19991"); - s1.syncer->Connect(channel_to_s2); + s1.syncer->Connect(s2.syncer->GetLocalNodeID(), channel_to_s2); // Make sure s2 adds s1 ASSERT_TRUE(s2.WaitUntil( - [&s2]() { - return s2.syncer->sync_connections_.size() == 1 && s2.snapshot_taken == 1; - }, + [&s2]() { return s2.syncer->sync_reactors_.size() == 1 && s2.snapshot_taken == 1; }, 5)); // Make sure s1 adds s2 ASSERT_TRUE(s1.WaitUntil( - [&s1]() { - return s1.syncer->sync_connections_.size() == 1 && s1.snapshot_taken == 1; - }, + [&s1]() { return s1.syncer->sync_reactors_.size() == 1 && s1.snapshot_taken == 1; }, 5)); // s1 will only send 1 message to s2 because it only has one reporter ASSERT_TRUE(s2.WaitUntil( [&s2, node_id = s1.syncer->GetLocalNodeID()]() { + RAY_LOG(DEBUG) << NodeID::FromBinary(node_id) << " - " + << s2.GetNumConsumedMessages(node_id); return s2.GetNumConsumedMessages(node_id) == 1; }, 5)); @@ -406,6 +468,9 @@ TEST(SyncerTest, Test1To1) { // s2 will send 2 messages to s1 because it has two reporters. ASSERT_TRUE(s1.WaitUntil( [&s1, node_id = s2.syncer->GetLocalNodeID()]() { + RAY_LOG(DEBUG) << "Num of messages from " << NodeID::FromBinary(node_id) << " to " + << NodeID::FromBinary(s1.syncer->GetLocalNodeID()) << " is " + << s1.GetNumConsumedMessages(node_id); return s1.GetNumConsumedMessages(node_id) == 1; }, 5)); @@ -475,7 +540,7 @@ TEST(SyncerTest, Test1To1) { ASSERT_LE(s2.GetNumConsumedMessages(s1.syncer->GetLocalNodeID()), max_sends + 3); } -TEST(SyncerTest, Reconnect) { +TEST_F(SyncerTest, Reconnect) { // This test is to check reconnect works. // Firstly // s1 -> s3 @@ -483,59 +548,47 @@ TEST(SyncerTest, Reconnect) { // s2 -> s3 // And we need to ensure s3 is connecting to s2 - auto s1 = SyncerServerTest("19990"); - auto s2 = SyncerServerTest("19991"); - auto s3 = SyncerServerTest("19992"); + auto &s1 = MakeServer("19990"); + auto &s2 = MakeServer("19991"); + auto &s3 = MakeServer("19992"); - s1.syncer->Connect(MakeChannel("19992")); + s1.syncer->Connect(s3.syncer->GetLocalNodeID(), MakeChannel("19992")); // Make sure the setup is correct ASSERT_TRUE(s1.WaitUntil( - [&s1]() { - return s1.syncer->sync_connections_.size() == 1 && s1.snapshot_taken == 1; - }, + [&s1]() { return s1.syncer->sync_reactors_.size() == 1 && s1.snapshot_taken == 1; }, 5)); ASSERT_TRUE(s1.WaitUntil( - [&s3]() { - return s3.syncer->sync_connections_.size() == 1 && s3.snapshot_taken == 1; - }, + [&s3]() { return s3.syncer->sync_reactors_.size() == 1 && s3.snapshot_taken == 1; }, 5)); - s2.syncer->Connect(MakeChannel("19992")); + s2.syncer->Connect(s3.syncer->GetLocalNodeID(), MakeChannel("19992")); ASSERT_TRUE(s1.WaitUntil( - [&s2]() { - return s2.syncer->sync_connections_.size() == 1 && s2.snapshot_taken == 1; - }, + [&s2]() { return s2.syncer->sync_reactors_.size() == 1 && s2.snapshot_taken == 1; }, 5)); } -TEST(SyncerTest, Broadcast) { +TEST_F(SyncerTest, Broadcast) { // This test covers the broadcast feature of ray syncer. - auto s1 = SyncerServerTest("19990"); - auto s2 = SyncerServerTest("19991"); - auto s3 = SyncerServerTest("19992"); + auto &s1 = MakeServer("19990"); + auto &s2 = MakeServer("19991"); + auto &s3 = MakeServer("19992"); // We need to make sure s1 is sending data to s3 for s2 - s1.syncer->Connect(MakeChannel("19991")); - s1.syncer->Connect(MakeChannel("19992")); + s1.syncer->Connect(s2.syncer->GetLocalNodeID(), MakeChannel("19991")); + s1.syncer->Connect(s3.syncer->GetLocalNodeID(), MakeChannel("19992")); // Make sure the setup is correct ASSERT_TRUE(s1.WaitUntil( - [&s1]() { - return s1.syncer->sync_connections_.size() == 2 && s1.snapshot_taken == 1; - }, + [&s1]() { return s1.syncer->sync_reactors_.size() == 2 && s1.snapshot_taken == 1; }, 5)); ASSERT_TRUE(s1.WaitUntil( - [&s2]() { - return s2.syncer->sync_connections_.size() == 1 && s2.snapshot_taken == 1; - }, + [&s2]() { return s2.syncer->sync_reactors_.size() == 1 && s2.snapshot_taken == 1; }, 5)); ASSERT_TRUE(s1.WaitUntil( - [&s3]() { - return s3.syncer->sync_connections_.size() == 1 && s3.snapshot_taken == 1; - }, + [&s3]() { return s3.syncer->sync_reactors_.size() == 1 && s3.snapshot_taken == 1; }, 5)); // Change the resource in s2 and make sure s1 && s3 are correct @@ -554,7 +607,7 @@ TEST(SyncerTest, Broadcast) { 5)); } -bool CompareViews(const std::vector> &servers, +bool CompareViews(const std::vector &servers, const std::vector &views, const std::vector> &g) { // Check broadcasting is working @@ -595,7 +648,7 @@ bool CompareViews(const std::vector> &servers, } bool TestCorrectness(std::function get_cluster_view, - std::vector> &servers, + std::vector &servers, const std::vector> &g) { auto check = [&servers, get_cluster_view, &g]() { std::vector views; @@ -656,15 +709,16 @@ bool TestCorrectness(std::function get_cluster_ return check(); } -TEST(SyncerTest, Test1ToN) { +TEST_F(SyncerTest, Test1ToN) { size_t base_port = 18990; - std::vector> servers; + std::vector servers; for (int i = 0; i < 20; ++i) { - servers.push_back(std::make_unique(std::to_string(i + base_port))); + servers.push_back(&MakeServer(std::to_string(i + base_port))); } std::vector> g(servers.size()); for (size_t i = 1; i < servers.size(); ++i) { - servers[0]->syncer->Connect(MakeChannel(servers[i]->server_port)); + servers[0]->syncer->Connect(servers[i]->syncer->GetLocalNodeID(), + MakeChannel(servers[i]->server_port)); g[0].insert(i); } @@ -680,11 +734,11 @@ TEST(SyncerTest, Test1ToN) { ASSERT_TRUE(TestCorrectness(get_cluster_view, servers, g)); } -TEST(SyncerTest, TestMToN) { +TEST_F(SyncerTest, TestMToN) { size_t base_port = 18990; - std::vector> servers; + std::vector servers; for (int i = 0; i < 20; ++i) { - servers.push_back(std::make_unique(std::to_string(i + base_port))); + servers.push_back(&MakeServer(std::to_string(i + base_port))); } std::vector> g(servers.size()); // Try to construct a tree based structure @@ -693,7 +747,8 @@ TEST(SyncerTest, TestMToN) { while (i < servers.size()) { // try to connect to 2 servers per node. for (int k = 0; k < 2 && i < servers.size(); ++k, ++i) { - servers[curr]->syncer->Connect(MakeChannel(servers[i]->server_port)); + servers[curr]->syncer->Connect(servers[i]->syncer->GetLocalNodeID(), + MakeChannel(servers[i]->server_port)); g[curr].insert(i); } ++curr; @@ -710,5 +765,166 @@ TEST(SyncerTest, TestMToN) { ASSERT_TRUE(TestCorrectness(get_cluster_view, servers, g)); } +struct MockRaySyncerService : public ray::rpc::syncer::RaySyncer::CallbackService { + MockRaySyncerService( + instrumented_io_context &_io_context, + std::function)> _message_processor, + std::function _cleanup_cb) + : message_processor(_message_processor), + cleanup_cb(_cleanup_cb), + node_id(NodeID::FromRandom()), + io_context(_io_context) {} + grpc::ServerBidiReactor *StartSync( + grpc::CallbackServerContext *context) override { + reactor = new RayServerBidiReactor( + context, io_context, node_id.Binary(), message_processor, cleanup_cb); + return reactor; + } + + std::function)> message_processor; + std::function cleanup_cb; + NodeID node_id; + instrumented_io_context &io_context; + RayServerBidiReactor *reactor = nullptr; +}; + +class SyncerReactorTest : public ::testing::Test { + protected: + void SetUp() override { + rpc_service_ = std::make_unique( + io_context_, + [this](auto msg) { server_received_message.set_value(msg); }, + [this](auto &node, bool restart) { + server_cleanup.set_value(std::make_pair(node, restart)); + }); + grpc::ServerBuilder builder; + builder.AddListeningPort("0.0.0.0:18990", grpc::InsecureServerCredentials()); + builder.RegisterService(rpc_service_.get()); + server = builder.BuildAndStart(); + + client_node_id = NodeID::FromRandom(); + cli_channel = MakeChannel("18990"); + auto cli_stub = ray::rpc::syncer::RaySyncer::NewStub(cli_channel); + cli_reactor = std::make_unique( + rpc_service_->node_id.Binary(), + client_node_id.Binary(), + io_context_, + [this](auto msg) { client_received_message.set_value(msg); }, + [this](const std::string &n, bool r) { + client_cleanup.set_value(std::make_pair(n, r)); + }, + std::move(cli_stub)) + .release(); + cli_reactor->StartCall(); + + work_guard_ = std::make_unique(io_context_.get_executor()); + thread_ = std::make_unique([this]() { io_context_.run(); }); + + auto start = steady_clock::now(); + while (duration_cast(steady_clock::now() - start).count() <= 5) { + RAY_LOG(INFO) << "Waiting: " + << duration_cast(steady_clock::now() - start).count(); + if (rpc_service_->reactor != nullptr) { + break; + }; + std::this_thread::sleep_for(1s); + } + } + + void TearDown() override { + io_context_.stop(); + thread_->join(); + } + + std::pair GetReactors() { + return std::make_pair(rpc_service_->reactor, cli_reactor); + } + + std::pair GetNodeID() { + return std::make_pair(rpc_service_->node_id.Binary(), client_node_id.Binary()); + } + + void ResetPromise() { + server_received_message = std::promise>(); + client_received_message = std::promise>(); + server_cleanup = std::promise>(); + client_cleanup = std::promise>(); + } + + instrumented_io_context io_context_; + std::unique_ptr work_guard_; + std::unique_ptr thread_; + std::unique_ptr rpc_service_; + std::unique_ptr server; + std::promise> server_received_message; + std::promise> client_received_message; + std::promise> server_cleanup; + std::promise> client_cleanup; + + grpc::ClientContext cli_context; + RayClientBidiReactor *cli_reactor; + std::shared_ptr cli_channel; + NodeID client_node_id; +}; + +TEST_F(SyncerReactorTest, TestReactor) { + auto [s, c] = GetReactors(); + auto [node_s, node_c] = GetNodeID(); + ASSERT_TRUE(s != nullptr); + ASSERT_TRUE(c != nullptr); + + auto msg_s = std::make_shared(); + msg_s->set_version(1); + msg_s->set_node_id(node_s); + + s->PushToSendingQueue(msg_s); + + auto msg_c = std::make_shared(); + msg_c->set_version(2); + msg_c->set_node_id(node_c); + + c->PushToSendingQueue(msg_c); + // Make sure sending is working + auto server_received = server_received_message.get_future().get(); + auto client_received = client_received_message.get_future().get(); + ResetPromise(); + ASSERT_EQ(server_received->version(), 2); + ASSERT_EQ(server_received->node_id(), node_c); + ASSERT_EQ(client_received->version(), 1); + ASSERT_EQ(client_received->node_id(), node_s); + + s->Disconnect(); + auto c_cleanup = client_cleanup.get_future().get(); + ASSERT_EQ(node_s, c_cleanup.first); + ASSERT_EQ(false, c_cleanup.second); +} + +TEST_F(SyncerReactorTest, TestReactorFailure) { + auto [s, c] = GetReactors(); + auto [node_s, node_c] = GetNodeID(); + ASSERT_TRUE(s != nullptr); + ASSERT_TRUE(c != nullptr); + s->Finish(grpc::Status::CANCELLED); + auto c_cleanup = client_cleanup.get_future().get(); + ASSERT_EQ(node_s, c_cleanup.first); + ASSERT_EQ(true, c_cleanup.second); +} + } // namespace syncer } // namespace ray + +int main(int argc, char **argv) { + InitShutdownRAII ray_log_shutdown_raii(ray::RayLog::StartRayLog, + ray::RayLog::ShutDownRayLog, + argv[0], + ray::RayLogLevel::INFO, + /*log_dir=*/""); + ray::RayLog::InstallFailureSignalHandler(argv[0]); + ray::RayLog::InstallTerminateHandler(); + + ::testing::InitGoogleTest(&argc, argv); + auto ret = RUN_ALL_TESTS(); + // Sleep for gRPC to gracefully shutdown. + std::this_thread::sleep_for(2s); + return ret; +} diff --git a/src/ray/common/test/syncer_service_e2e_test.cc b/src/ray/common/test/syncer_service_e2e_test.cc index f8e3999439b7..0e8c37d249c3 100644 --- a/src/ray/common/test/syncer_service_e2e_test.cc +++ b/src/ray/common/test/syncer_service_e2e_test.cc @@ -123,7 +123,7 @@ int main(int argc, char *argv[]) { channel = grpc::CreateCustomChannel( "localhost:" + leader_port, grpc::InsecureChannelCredentials(), argument); - syncer.Connect(channel); + syncer.Connect(ray::NodeID::FromRandom().Binary(), channel); } boost::asio::io_context::work work(io_context); diff --git a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc index 0d08b337e8d5..7fa71a8791cf 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_scheduler.cc @@ -444,7 +444,6 @@ void GcsActorScheduler::CreateActorOnWorker(std::shared_ptr actor, RAY_LOG(INFO) << "Start creating actor " << actor->GetActorID() << " on worker " << worker->GetWorkerID() << " at node " << actor->GetNodeID() << ", job id = " << actor->GetActorID().JobId(); - std::unique_ptr request(new rpc::PushTaskRequest()); request->set_intended_worker_id(worker->GetWorkerID().Binary()); request->mutable_task_spec()->CopyFrom( diff --git a/src/ray/gcs/gcs_server/gcs_server.cc b/src/ray/gcs/gcs_server/gcs_server.cc index 706a483c3760..5d9f98db3b83 100644 --- a/src/ray/gcs/gcs_server/gcs_server.cc +++ b/src/ray/gcs/gcs_server/gcs_server.cc @@ -50,7 +50,6 @@ GcsServer::GcsServer(const ray::gcs::GcsServerConfig &config, RayConfig::instance().gcs_server_rpc_client_thread_num()), raylet_client_pool_( std::make_shared(client_call_manager_)), - local_node_id_(NodeID::FromRandom()), pubsub_periodical_runner_(pubsub_io_service_), periodical_runner_(main_service), is_started_(false), @@ -273,7 +272,7 @@ void GcsServer::InitGcsResourceManager(const GcsInitData &gcs_init_data) { gcs_resource_manager_ = std::make_shared( main_service_, cluster_resource_scheduler_->GetClusterResourceManager(), - local_node_id_, + kGCSNodeID, cluster_task_manager_); // Initialize by gcs tables data. @@ -320,7 +319,7 @@ void GcsServer::InitGcsResourceManager(const GcsInitData &gcs_init_data) { void GcsServer::InitClusterResourceScheduler() { cluster_resource_scheduler_ = std::make_shared( - scheduling::NodeID(local_node_id_.Binary()), + scheduling::NodeID(kGCSNodeID.Binary()), NodeResources(), /*is_node_available_fn=*/ [](auto) { return true; }, @@ -330,7 +329,7 @@ void GcsServer::InitClusterResourceScheduler() { void GcsServer::InitClusterTaskManager() { RAY_CHECK(cluster_resource_scheduler_); cluster_task_manager_ = std::make_shared( - local_node_id_, + kGCSNodeID, cluster_resource_scheduler_, /*get_node_info=*/ [this](const NodeID &node_id) { @@ -472,8 +471,8 @@ std::string GcsServer::StorageType() const { void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) { if (RayConfig::instance().use_ray_syncer()) { - ray_syncer_ = std::make_unique(ray_syncer_io_context_, - local_node_id_.Binary()); + ray_syncer_ = + std::make_unique(ray_syncer_io_context_, kGCSNodeID.Binary()); ray_syncer_->Register( syncer::MessageType::RESOURCE_VIEW, nullptr, gcs_resource_manager_.get()); ray_syncer_->Register( @@ -482,19 +481,8 @@ void GcsServer::InitRaySyncer(const GcsInitData &gcs_init_data) { boost::asio::io_service::work work(ray_syncer_io_context_); ray_syncer_io_context_.run(); }); - - for (const auto &pair : gcs_init_data.Nodes()) { - if (pair.second.state() == - rpc::GcsNodeInfo_GcsNodeState::GcsNodeInfo_GcsNodeState_ALIVE) { - rpc::Address address; - address.set_raylet_id(pair.second.node_id()); - address.set_ip_address(pair.second.node_manager_address()); - address.set_port(pair.second.node_manager_port()); - - auto raylet_client = raylet_client_pool_->GetOrConnectByAddress(address); - ray_syncer_->Connect(raylet_client->GetChannel()); - } - } + ray_syncer_service_ = std::make_unique(*ray_syncer_); + rpc_server_.RegisterService(*ray_syncer_service_); } else { /* The current synchronization flow is: @@ -622,9 +610,7 @@ void GcsServer::InstallEventListeners() { } cluster_task_manager_->ScheduleAndDispatchTasks(); - if (RayConfig::instance().use_ray_syncer()) { - ray_syncer_->Connect(raylet_client->GetChannel()); - } else { + if (!RayConfig::instance().use_ray_syncer()) { gcs_ray_syncer_->AddNode(*node); } }); @@ -640,9 +626,7 @@ void GcsServer::InstallEventListeners() { raylet_client_pool_->Disconnect(node_id); gcs_healthcheck_manager_->RemoveNode(node_id); - if (RayConfig::instance().use_ray_syncer()) { - ray_syncer_->Disconnect(node_id.Binary()); - } else { + if (!RayConfig::instance().use_ray_syncer()) { gcs_ray_syncer_->RemoveNode(*node); } }); @@ -776,14 +760,14 @@ void GcsServer::TryGlobalGC() { if (RayConfig::instance().use_ray_syncer()) { auto msg = std::make_shared(); msg->set_version(absl::GetCurrentTimeNanos()); - msg->set_node_id(local_node_id_.Binary()); + msg->set_node_id(kGCSNodeID.Binary()); msg->set_message_type(syncer::MessageType::COMMANDS); std::string serialized_msg; RAY_CHECK(resources_data.SerializeToString(&serialized_msg)); msg->set_sync_message(std::move(serialized_msg)); ray_syncer_->BroadcastRaySyncMessage(std::move(msg)); } else { - resources_data.set_node_id(local_node_id_.Binary()); + resources_data.set_node_id(kGCSNodeID.Binary()); gcs_ray_syncer_->Update(resources_data); } diff --git a/src/ray/gcs/gcs_server/gcs_server.h b/src/ray/gcs/gcs_server/gcs_server.h index f326d954bfea..f6397452aceb 100644 --- a/src/ray/gcs/gcs_server/gcs_server.h +++ b/src/ray/gcs/gcs_server/gcs_server.h @@ -188,9 +188,6 @@ class GcsServer { std::shared_ptr raylet_client_pool_; /// The gcs resource manager. std::shared_ptr gcs_resource_manager_; - /// The gcs server's node id, for the creation of `cluster_resource_scheduler_` and - /// `cluster_task_manager_`. - NodeID local_node_id_; /// The cluster resource scheduler. std::shared_ptr cluster_resource_scheduler_; /// The cluster task manager. @@ -226,6 +223,7 @@ class GcsServer { /// Ray Syncer realted fields. std::unique_ptr ray_syncer_; + std::unique_ptr ray_syncer_service_; std::unique_ptr ray_syncer_thread_; instrumented_io_context ray_syncer_io_context_; diff --git a/src/ray/protobuf/ray_syncer.proto b/src/ray/protobuf/ray_syncer.proto index f171665ef6f0..b24a0ecb323b 100644 --- a/src/ray/protobuf/ray_syncer.proto +++ b/src/ray/protobuf/ray_syncer.proto @@ -32,37 +32,6 @@ message RaySyncMessage { bytes node_id = 4; } -message RaySyncMessages { - // The bached messages. - repeated RaySyncMessage sync_messages = 1; -} - -message StartSyncRequest { - bytes node_id = 1; -} - -message StartSyncResponse { - bytes node_id = 1; -} - -message DummyRequest {} -message DummyResponse {} - service RaySyncer { - // Ideally these should be a streaming API like this - // rpc StartSync(stream RaySyncMessages) returns (stream RaySyncMessages); - // But to make sure it's the same as the current protocol, we still use - // unary rpc. - // TODO (iycheng): Using grpc streaming for the protocol. - - // This is the first message that should be sent. It will initialize - // some structure between nodes. - rpc StartSync(StartSyncRequest) returns (StartSyncResponse); - - // These two RPCs are for messages reporting and broadcasting. - // Update is used by the client to send update request to the server. - rpc Update(RaySyncMessages) returns (DummyResponse); - - // LongPolling is used by the server to send request to the client. - rpc LongPolling(DummyRequest) returns (RaySyncMessages); + rpc StartSync(stream RaySyncMessage) returns (stream RaySyncMessage); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 5ca8c5b41589..3074dbb11f83 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -527,6 +527,8 @@ ray::Status NodeManager::RegisterGcs() { /* receiver */ this, /* pull_from_reporter_interval_ms */ 0); + auto gcs_channel = gcs_client_->GetGcsRpcClient().GetChannel(); + ray_syncer_.Connect(kGCSNodeID.Binary(), gcs_channel); periodical_runner_.RunFnPeriodically( [this] { auto triggered_by_global_gc = TryLocalGC(); diff --git a/src/ray/rpc/grpc_client.h b/src/ray/rpc/grpc_client.h index e5c5065030aa..8c67353e2cb1 100644 --- a/src/ray/rpc/grpc_client.h +++ b/src/ray/rpc/grpc_client.h @@ -57,7 +57,8 @@ inline std::shared_ptr BuildChannel( ::RayConfig::instance().grpc_enable_http_proxy() ? 1 : 0); arguments->SetMaxSendMessageSize(::RayConfig::instance().max_grpc_message_size()); arguments->SetMaxReceiveMessageSize(::RayConfig::instance().max_grpc_message_size()); - + arguments->SetInt(GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE, + ::RayConfig::instance().grpc_stream_buffer_size()); std::shared_ptr channel; if (::RayConfig::instance().USE_TLS()) { std::string server_cert_file = std::string(::RayConfig::instance().TLS_SERVER_CERT()); diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 962017aca9f7..d38e02d57151 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -83,7 +83,8 @@ void GrpcServer::Run() { builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, RayConfig::instance().grpc_keepalive_timeout_ms()); builder.AddChannelArgument(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 0); - + builder.AddChannelArgument(GRPC_ARG_HTTP2_WRITE_BUFFER_SIZE, + RayConfig::instance().grpc_stream_buffer_size()); // NOTE(rickyyx): This argument changes how frequent the gRPC server expects a keepalive // ping from the client. See https://github.com/grpc/grpc/blob/HEAD/doc/keepalive.md#faq // We set this to 1min because GCS gRPC client currently sends keepalive every 1min: