From 8b60b0838e63a441e73fd5b7858078ccc2c1b20f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 11 May 2023 12:12:52 -0700 Subject: [PATCH] Bring back "[Core] Port GcsPublisher to Cython" (#34393) (#35179) I spent quite a bit of time debugging the test failure in #34393 (see also #35108) It turns out the PR slightly made the _do_importing race condition (first time call in the import thread) more likely to happen. There is already a plan / PR to get rid of it (#30895) but it is currently waiting for having a replacement mechanism that @rkooo567 is working on. I synced with @scv119 and for the time being, we are planning to skip the offending test on Windows and once we got rid of the import thread, we can re-activate it. Signed-off-by: Marcus Zhang --- .buildkite/pipeline.build.yml | 6 ++ dashboard/agent.py | 6 +- dashboard/dashboard.py | 3 +- python/ray/_private/gcs_pubsub.py | 45 ---------- python/ray/_private/log_monitor.py | 8 +- python/ray/_private/utils.py | 29 +------ python/ray/_private/worker.py | 3 +- python/ray/_raylet.pyx | 63 ++++++++++++++ python/ray/autoscaler/_private/monitor.py | 3 +- python/ray/includes/common.pxd | 38 +++++++++ python/ray/includes/common.pxi | 1 + python/ray/tests/test_basic_5.py | 3 + python/ray/tests/test_failure.py | 3 +- python/ray/tests/test_gcs_fault_tolerance.py | 14 ++-- python/ray/tests/test_gcs_pubsub.py | 26 +++--- src/ray/gcs/gcs_client/gcs_client.cc | 5 +- src/ray/gcs/pubsub/gcs_pub_sub.cc | 87 ++++++++++++++++++++ src/ray/gcs/pubsub/gcs_pub_sub.h | 37 +++++++++ 18 files changed, 271 insertions(+), 109 deletions(-) diff --git a/.buildkite/pipeline.build.yml b/.buildkite/pipeline.build.yml index 81debe8a17bf..8bdd31723559 100644 --- a/.buildkite/pipeline.build.yml +++ b/.buildkite/pipeline.build.yml @@ -368,6 +368,9 @@ - DL=1 ./ci/env/install-dependencies.sh - bash ./ci/ci.sh prepare_docker - ./ci/env/env_info.sh + # This is needed or else the Ray Client tests run into a gRPC forking problem + # similar to https://github.com/grpc/grpc/issues/31885 + - pip install pip install grpcio==1.50.0 - bazel test --config=ci $(./ci/run/bazel_export_options) --test_tag_filters=client_tests,small_size_python_tests -- python/ray/tests/... @@ -418,6 +421,9 @@ - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT - DL=1 ./ci/env/install-dependencies.sh - ./ci/env/env_info.sh + # This is needed or else the Ray Client tests run into a gRPC forking problem + # similar to https://github.com/grpc/grpc/issues/31885 + - pip install pip install grpcio==1.50.0 - bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=client_tests,small_size_python_tests --test_env=TEST_EXTERNAL_REDIS=1 diff --git a/dashboard/agent.py b/dashboard/agent.py index 345099ff7c25..df57590ff0b6 100644 --- a/dashboard/agent.py +++ b/dashboard/agent.py @@ -15,7 +15,7 @@ import ray.dashboard.consts as dashboard_consts import ray.dashboard.utils as dashboard_utils from ray.dashboard.consts import _PARENT_DEATH_THREASHOLD -from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher +from ray._private.gcs_pubsub import GcsAioPublisher from ray._raylet import GcsClient from ray._private.gcs_utils import GcsAioClient from ray._private.ray_logging import setup_component_logger @@ -263,7 +263,9 @@ async def _check_parent(): ray._private.utils.publish_error_to_driver( ray_constants.RAYLET_DIED_ERROR, msg, - gcs_publisher=GcsPublisher(address=self.gcs_address), + gcs_publisher=ray._raylet.GcsPublisher( + address=self.gcs_address + ), ) else: logger.info(msg) diff --git a/dashboard/dashboard.py b/dashboard/dashboard.py index 4732e96d23ee..273fbc4c904d 100644 --- a/dashboard/dashboard.py +++ b/dashboard/dashboard.py @@ -13,7 +13,6 @@ import ray.dashboard.consts as dashboard_consts import ray.dashboard.head as dashboard_head import ray.dashboard.utils as dashboard_utils -from ray._private.gcs_pubsub import GcsPublisher from ray._private.ray_logging import setup_component_logger from typing import Optional, Set @@ -261,7 +260,7 @@ def sigterm_handler(): raise e # Something went wrong, so push an error to all drivers. - gcs_publisher = GcsPublisher(address=args.gcs_address) + gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address) ray._private.utils.publish_error_to_driver( ray_constants.DASHBOARD_DIED_ERROR, message, diff --git a/python/ray/_private/gcs_pubsub.py b/python/ray/_private/gcs_pubsub.py index c1d39e728b15..2168b9dfed9d 100644 --- a/python/ray/_private/gcs_pubsub.py +++ b/python/ray/_private/gcs_pubsub.py @@ -4,10 +4,8 @@ import random import threading from typing import Optional, Tuple, List -import time import grpc -from grpc._channel import _InactiveRpcError from ray._private.utils import get_or_create_event_loop try: @@ -160,49 +158,6 @@ def _pop_actors(queue, batch_size=100): return msgs -class GcsPublisher(_PublisherBase): - """Publisher to GCS.""" - - def __init__(self, address: str): - channel = gcs_utils.create_gcs_channel(address) - self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel) - - def publish_error( - self, key_id: bytes, error_info: ErrorTableData, num_retries=None - ) -> None: - """Publishes error info to GCS.""" - msg = pubsub_pb2.PubMessage( - channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL, - key_id=key_id, - error_info_message=error_info, - ) - req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg]) - self._gcs_publish(req, num_retries, timeout=1) - - def publish_logs(self, log_batch: dict) -> None: - """Publishes logs to GCS.""" - req = self._create_log_request(log_batch) - self._gcs_publish(req) - - def publish_function_key(self, key: bytes) -> None: - """Publishes function key to GCS.""" - req = self._create_function_key_request(key) - self._gcs_publish(req) - - def _gcs_publish(self, req, num_retries=None, timeout=None) -> None: - count = num_retries or MAX_GCS_PUBLISH_RETRIES - while count > 0: - try: - self._stub.GcsPublish(req, timeout=timeout) - return - except _InactiveRpcError: - pass - count -= 1 - if count > 0: - time.sleep(1) - raise TimeoutError(f"Failed to publish after retries: {req}") - - class _SyncSubscriber(_SubscriberBase): def __init__( self, diff --git a/python/ray/_private/log_monitor.py b/python/ray/_private/log_monitor.py index 7f06343625ae..444ac5b34bec 100644 --- a/python/ray/_private/log_monitor.py +++ b/python/ray/_private/log_monitor.py @@ -11,11 +11,9 @@ import traceback from typing import Callable, List, Set -import ray._private.gcs_pubsub as gcs_pubsub import ray._private.ray_constants as ray_constants import ray._private.services as services import ray._private.utils -from ray._private.gcs_pubsub import GcsPublisher from ray._private.ray_logging import setup_component_logger # Logger for this module. It should be configured at the entry point @@ -135,7 +133,7 @@ class LogMonitor: def __init__( self, logs_dir, - gcs_publisher: gcs_pubsub.GcsPublisher, + gcs_publisher, is_proc_alive_fn: Callable[[int], bool], max_files_open: int = ray_constants.LOG_MONITOR_MAX_OPEN_FILES, ): @@ -525,14 +523,14 @@ def is_proc_alive(pid): ) log_monitor = LogMonitor( - args.logs_dir, gcs_pubsub.GcsPublisher(address=args.gcs_address), is_proc_alive + args.logs_dir, ray._raylet.GcsPublisher(address=args.gcs_address), is_proc_alive ) try: log_monitor.run() except Exception as e: # Something went wrong, so push an error to all drivers. - gcs_publisher = GcsPublisher(address=args.gcs_address) + gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address) traceback_str = ray._private.utils.format_error_message(traceback.format_exc()) message = ( f"The log monitor on node {platform.node()} " diff --git a/python/ray/_private/utils.py b/python/ray/_private/utils.py index 6174890cd8ea..8d1793114ac9 100644 --- a/python/ray/_private/utils.py +++ b/python/ray/_private/utils.py @@ -44,7 +44,6 @@ import ray import ray._private.ray_constants as ray_constants from ray._private.tls_utils import load_certs_from_env -from ray.core.generated.gcs_pb2 import ErrorTableData from ray.core.generated.runtime_env_common_pb2 import ( RuntimeEnvInfo as ProtoRuntimeEnvInfo, ) @@ -182,27 +181,6 @@ def push_error_to_driver( worker.core_worker.push_error(job_id, error_type, message, time.time()) -def construct_error_message(job_id, error_type, message, timestamp): - """Construct an ErrorTableData object. - - Args: - job_id: The ID of the job that the error should go to. If this is - nil, then the error will go to all drivers. - error_type: The type of the error. - message: The error message. - timestamp: The time of the error. - - Returns: - The ErrorTableData object. - """ - data = ErrorTableData() - data.job_id = job_id.binary() - data.type = error_type - data.error_message = message - data.timestamp = timestamp - return data - - def publish_error_to_driver( error_type: str, message: str, @@ -228,11 +206,12 @@ def publish_error_to_driver( if job_id is None: job_id = ray.JobID.nil() assert isinstance(job_id, ray.JobID) - error_data = construct_error_message(job_id, error_type, message, time.time()) try: - gcs_publisher.publish_error(job_id.hex().encode(), error_data, num_retries) + gcs_publisher.publish_error( + job_id.hex().encode(), error_type, message, job_id, num_retries + ) except Exception: - logger.exception(f"Failed to publish error {error_data}") + logger.exception(f"Failed to publish error: {message} [type {error_type}]") def decode(byte_str: str, allow_none: bool = False, encode_type: str = "utf-8"): diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index a9b81d672fb3..234f79c3be5c 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -68,7 +68,6 @@ GcsErrorSubscriber, GcsFunctionKeySubscriber, GcsLogSubscriber, - GcsPublisher, ) from ray._private.inspect_util import is_cython from ray._private.ray_logging import ( @@ -2074,7 +2073,7 @@ def connect( ray._private.state.state._initialize_global_state( ray._raylet.GcsClientOptions.from_gcs_address(node.gcs_address) ) - worker.gcs_publisher = GcsPublisher(address=worker.gcs_client.address) + worker.gcs_publisher = ray._raylet.GcsPublisher(address=worker.gcs_client.address) # Initialize some fields. if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE): # We should not specify the job_id if it's `WORKER_MODE`. diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 5b135b35d419..3c929ade46ef 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -61,14 +61,17 @@ from ray.includes.common cimport ( CObjectReference, CRayObject, CRayStatus, + CErrorTableData, CGcsClientOptions, CGcsNodeInfo, CJobTableData, + CLogBatch, CTaskArg, CTaskArgByReference, CTaskArgByValue, CTaskType, CPlacementStrategy, + CPythonFunction, CSchedulingStrategy, CPlacementGroupSchedulingStrategy, CNodeAffinitySchedulingStrategy, @@ -1742,6 +1745,66 @@ cdef class GcsClient: } return result +cdef class GcsPublisher: + """Cython wrapper class of C++ `ray::gcs::PythonGcsPublisher`.""" + cdef: + shared_ptr[CPythonGcsPublisher] inner + + def __cinit__(self, address): + self.inner.reset(new CPythonGcsPublisher(address)) + check_status(self.inner.get().Connect()) + + def publish_error(self, key_id: bytes, error_type: str, message: str, + job_id=None, num_retries=None): + cdef: + CErrorTableData error_info + int64_t c_num_retries = num_retries if num_retries else -1 + c_string c_key_id = key_id + + if job_id is None: + job_id = ray.JobID.nil() + assert isinstance(job_id, ray.JobID) + error_info.set_job_id(job_id.binary()) + error_info.set_type(error_type) + error_info.set_error_message(message) + error_info.set_timestamp(time.time()) + + with nogil: + check_status( + self.inner.get().PublishError(c_key_id, error_info, c_num_retries)) + + def publish_logs(self, log_json: dict): + cdef: + CLogBatch log_batch + c_string c_job_id + + job_id = log_json.get("job") + log_batch.set_ip(log_json.get("ip") if log_json.get("ip") else b"") + log_batch.set_pid( + str(log_json.get("pid")).encode() if log_json.get("pid") else b"") + log_batch.set_job_id(job_id.encode() if job_id else b"") + log_batch.set_is_error(bool(log_json.get("is_err"))) + for line in log_json.get("lines", []): + log_batch.add_lines(line) + actor_name = log_json.get("actor_name") + log_batch.set_actor_name(actor_name.encode() if actor_name else b"") + task_name = log_json.get("task_name") + log_batch.set_task_name(task_name.encode() if task_name else b"") + + c_job_id = job_id.encode() if job_id else b"" + with nogil: + check_status(self.inner.get().PublishLogs(c_job_id, log_batch)) + + def publish_function_key(self, key: bytes): + cdef: + CPythonFunction python_function + + python_function.set_key(key) + + with nogil: + check_status(self.inner.get().PublishFunctionKey(python_function)) + + cdef class CoreWorker: def __cinit__(self, worker_type, store_socket, raylet_socket, diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py index 14faf14fa8e9..f15e109fc9d4 100644 --- a/python/ray/autoscaler/_private/monitor.py +++ b/python/ray/autoscaler/_private/monitor.py @@ -16,7 +16,6 @@ import ray._private.ray_constants as ray_constants import ray._private.utils from ray._private.event.event_logger import get_event_logger -from ray._private.gcs_pubsub import GcsPublisher from ray._private.ray_logging import setup_component_logger from ray._raylet import GcsClient from ray.autoscaler._private.autoscaler import StandardAutoscaler @@ -560,7 +559,7 @@ def _handle_failure(self, error): _internal_kv_put( ray_constants.DEBUG_AUTOSCALING_ERROR, message, overwrite=True ) - gcs_publisher = GcsPublisher(address=self.gcs_address) + gcs_publisher = ray._raylet.GcsPublisher(address=self.gcs_address) from ray._private.utils import publish_error_to_driver publish_error_to_driver( diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index e0f8b8ee9712..4250470f3013 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -346,6 +346,21 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil: unordered_map[c_string, double] PythonGetResourcesTotal( const CGcsNodeInfo& node_info) +cdef extern from "ray/gcs/pubsub/gcs_pub_sub.h" nogil: + + cdef cppclass CPythonGcsPublisher "ray::gcs::PythonGcsPublisher": + + CPythonGcsPublisher(const c_string& gcs_address) + + CRayStatus Connect() + + CRayStatus PublishError( + const c_string &key_id, const CErrorTableData &data, int64_t num_retries) + + CRayStatus PublishLogs(const c_string &key_id, const CLogBatch &data) + + CRayStatus PublishFunctionKey(const CPythonFunction& python_function) + cdef extern from "src/ray/protobuf/gcs.pb.h" nogil: cdef cppclass CJobConfig "ray::rpc::JobConfig": c_string ray_namespace() const @@ -372,6 +387,29 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil: c_bool is_dead() const CJobConfig config() const + cdef cppclass CPythonFunction "ray::rpc::PythonFunction": + void set_key(const c_string &key) + + cdef cppclass CErrorTableData "ray::rpc::ErrorTableData": + c_string job_id() const + c_string type() const + c_string error_message() const + double timestamp() const + + void set_job_id(const c_string &job_id) + void set_type(const c_string &type) + void set_error_message(const c_string &error_message) + void set_timestamp(double timestamp) + + cdef cppclass CLogBatch "ray::rpc::LogBatch": + void set_ip(const c_string &ip) + void set_pid(const c_string &pid) + void set_job_id(const c_string &job_id) + void set_is_error(c_bool is_error) + void add_lines(const c_string &line) + void set_actor_name(const c_string &actor_name) + void set_task_name(const c_string &task_name) + cdef extern from "ray/common/task/task_spec.h" nogil: cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup": diff --git a/python/ray/includes/common.pxi b/python/ray/includes/common.pxi index 89983ff8808c..d7c3c121bc69 100644 --- a/python/ray/includes/common.pxi +++ b/python/ray/includes/common.pxi @@ -6,6 +6,7 @@ from ray.includes.common cimport ( CObjectLocation, CGcsClientOptions, CPythonGcsClient, + CPythonGcsPublisher, ) diff --git a/python/ray/tests/test_basic_5.py b/python/ray/tests/test_basic_5.py index ffdeb6cf20b5..c3847ad8b1be 100644 --- a/python/ray/tests/test_basic_5.py +++ b/python/ray/tests/test_basic_5.py @@ -227,6 +227,9 @@ def sys_path(): subprocess.check_call(["python", "-m", "package.module2"]) +# This will be fixed on Windows once the import thread is removed, see +# https://github.com/ray-project/ray/pull/30895 +@pytest.mark.skipif(sys.platform == "win32", reason="Currently fails on Windows.") def test_worker_kv_calls(monkeypatch, shutdown_only): monkeypatch.setenv("TEST_RAY_COLLECT_KV_FREQUENCY", "1") ray.init() diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 71bb7a98dd9a..93f1c734ee0a 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -10,7 +10,6 @@ import ray._private.gcs_utils as gcs_utils import ray._private.ray_constants as ray_constants import ray._private.utils -from ray._private.gcs_pubsub import GcsPublisher from ray._private.test_utils import ( SignalActor, convert_actor_state, @@ -69,7 +68,7 @@ def interceptor(e): def test_publish_error_to_driver(ray_start_regular, error_pubsub): address_info = ray_start_regular - gcs_publisher = GcsPublisher(address=address_info["gcs_address"]) + gcs_publisher = ray._raylet.GcsPublisher(address=address_info["gcs_address"]) error_message = "Test error message" ray._private.utils.publish_error_to_driver( diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index fedd531d6cb8..72caad2f0f6e 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -18,10 +18,8 @@ run_string_as_driver, ) from ray._private.gcs_pubsub import ( - GcsPublisher, GcsErrorSubscriber, ) -from ray.core.generated.gcs_pb2 import ErrorTableData import psutil @@ -675,20 +673,20 @@ def test_publish_and_subscribe_error_info(ray_start_regular_with_external_redis) subscriber = GcsErrorSubscriber(address=gcs_server_addr) subscriber.subscribe() - publisher = GcsPublisher(address=gcs_server_addr) - err1 = ErrorTableData(error_message="test error message 1") - err2 = ErrorTableData(error_message="test error message 2") + publisher = ray._raylet.GcsPublisher(address=gcs_server_addr) print("sending error message 1") - publisher.publish_error(b"aaa_id", err1) + publisher.publish_error(b"aaa_id", "", "test error message 1") ray._private.worker._global_node.kill_gcs_server() ray._private.worker._global_node.start_gcs_server() print("sending error message 2") - publisher.publish_error(b"bbb_id", err2) + publisher.publish_error(b"bbb_id", "", "test error message 2") print("done") - assert subscriber.poll() == (b"bbb_id", err2) + (key_id, err) = subscriber.poll() + assert key_id == b"bbb_id" + assert err.error_message == "test error message 2" subscriber.close() diff --git a/python/ray/tests/test_gcs_pubsub.py b/python/ray/tests/test_gcs_pubsub.py index b9a4eddee7a4..71d4ae802f26 100644 --- a/python/ray/tests/test_gcs_pubsub.py +++ b/python/ray/tests/test_gcs_pubsub.py @@ -3,8 +3,8 @@ import threading import re +import ray from ray._private.gcs_pubsub import ( - GcsPublisher, GcsErrorSubscriber, GcsLogSubscriber, GcsFunctionKeySubscriber, @@ -24,14 +24,16 @@ def test_publish_and_subscribe_error_info(ray_start_regular): subscriber = GcsErrorSubscriber(address=gcs_server_addr) subscriber.subscribe() - publisher = GcsPublisher(address=gcs_server_addr) - err1 = ErrorTableData(error_message="test error message 1") - err2 = ErrorTableData(error_message="test error message 2") - publisher.publish_error(b"aaa_id", err1) - publisher.publish_error(b"bbb_id", err2) + publisher = ray._raylet.GcsPublisher(address=gcs_server_addr) + publisher.publish_error(b"aaa_id", "", "test error message 1") + publisher.publish_error(b"bbb_id", "", "test error message 2") - assert subscriber.poll() == (b"aaa_id", err1) - assert subscriber.poll() == (b"bbb_id", err2) + (key_id1, err1) = subscriber.poll() + assert key_id1 == b"aaa_id" + assert err1.error_message == "test error message 1" + (key_id2, err2) = subscriber.poll() + assert key_id2 == b"bbb_id" + assert err2.error_message == "test error message 2" subscriber.close() @@ -63,7 +65,7 @@ def test_publish_and_subscribe_logs(ray_start_regular): subscriber = GcsLogSubscriber(address=gcs_server_addr) subscriber.subscribe() - publisher = GcsPublisher(address=gcs_server_addr) + publisher = ray._raylet.GcsPublisher(address=gcs_server_addr) log_batch = { "ip": "127.0.0.1", "pid": 1234, @@ -114,7 +116,7 @@ def test_publish_and_subscribe_function_keys(ray_start_regular): subscriber = GcsFunctionKeySubscriber(address=gcs_server_addr) subscriber.subscribe() - publisher = GcsPublisher(address=gcs_server_addr) + publisher = ray._raylet.GcsPublisher(address=gcs_server_addr) publisher.publish_function_key(b"111") publisher.publish_function_key(b"222") @@ -196,9 +198,9 @@ def receive_logs(): t2 = threading.Thread(target=receive_logs) t2.start() - publisher = GcsPublisher(address=gcs_server_addr) + publisher = ray._raylet.GcsPublisher(address=gcs_server_addr) for i in range(0, num_messages): - publisher.publish_error(b"msg_id", ErrorTableData(error_message=f"error {i}")) + publisher.publish_error(b"msg_id", "", f"error {i}") publisher.publish_logs( { "ip": "127.0.0.1", diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc index fb721893d7ea..ae342b05eec0 100644 --- a/src/ray/gcs/gcs_client/gcs_client.cc +++ b/src/ray/gcs/gcs_client/gcs_client.cc @@ -146,10 +146,7 @@ std::pair GcsClient::GetGcsServerAddress() const { PythonGcsClient::PythonGcsClient(const GcsClientOptions &options) : options_(options) {} Status PythonGcsClient::Connect() { - grpc::ChannelArguments arguments; - arguments.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, 512 * 1024 * 1024); - arguments.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 60 * 1000); - arguments.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000); + auto arguments = PythonGrpcChannelArguments(); channel_ = rpc::BuildChannel(options_.gcs_address_, options_.gcs_port_, arguments); kv_stub_ = rpc::InternalKVGcsService::NewStub(channel_); runtime_env_stub_ = rpc::RuntimeEnvGcsService::NewStub(channel_); diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.cc b/src/ray/gcs/pubsub/gcs_pub_sub.cc index 32c0e9f41367..b03a9157da46 100644 --- a/src/ray/gcs/pubsub/gcs_pub_sub.cc +++ b/src/ray/gcs/pubsub/gcs_pub_sub.cc @@ -15,6 +15,7 @@ #include "ray/gcs/pubsub/gcs_pub_sub.h" #include "absl/strings/str_cat.h" +#include "ray/rpc/grpc_client.h" namespace ray { namespace gcs { @@ -212,5 +213,91 @@ Status GcsSubscriber::SubscribeAllWorkerFailures( return Status::OK(); } +grpc::ChannelArguments PythonGrpcChannelArguments() { + grpc::ChannelArguments arguments; + arguments.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, 512 * 1024 * 1024); + arguments.SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, 60 * 1000); + arguments.SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, 60 * 1000); + return arguments; +} + +PythonGcsPublisher::PythonGcsPublisher(const std::string &gcs_address) { + std::vector address = absl::StrSplit(gcs_address, ':'); + RAY_LOG(DEBUG) << "Connect to gcs server via address: " << gcs_address; + RAY_CHECK(address.size() == 2); + gcs_address_ = address[0]; + gcs_port_ = std::stoi(address[1]); +} + +Status PythonGcsPublisher::Connect() { + auto arguments = PythonGrpcChannelArguments(); + channel_ = rpc::BuildChannel(gcs_address_, gcs_port_, arguments); + pubsub_stub_ = rpc::InternalPubSubGcsService::NewStub(channel_); + return Status::OK(); +} + +constexpr int MAX_GCS_PUBLISH_RETRIES = 60; + +Status PythonGcsPublisher::DoPublishWithRetries(const rpc::GcsPublishRequest &request, + int64_t num_retries, + int64_t timeout_ms) { + int count = num_retries == -1 ? MAX_GCS_PUBLISH_RETRIES : num_retries; + rpc::GcsPublishReply reply; + grpc::Status status; + while (count > 0) { + grpc::ClientContext context; + if (timeout_ms != -1) { + context.set_deadline(std::chrono::system_clock::now() + + std::chrono::milliseconds(timeout_ms)); + } + status = pubsub_stub_->GcsPublish(&context, request, &reply); + if (status.error_code() == grpc::StatusCode::OK) { + if (reply.status().code() != static_cast(StatusCode::OK)) { + return Status::Invalid(reply.status().message()); + } + return Status::OK(); + } else if (status.error_code() == grpc::StatusCode::UNAVAILABLE || + status.error_code() == grpc::StatusCode::UNKNOWN) { + // This is the case in which we will retry + count -= 1; + std::this_thread::sleep_for(std::chrono::seconds(1)); + continue; + } else { + return Status::Invalid(status.error_message()); + } + } + return Status::TimedOut("Failed to publish after retries: " + status.error_message()); +} + +Status PythonGcsPublisher::PublishError(const std::string &key_id, + const rpc::ErrorTableData &error_info, + int64_t num_retries) { + rpc::GcsPublishRequest request; + auto *message = request.add_pub_messages(); + message->set_channel_type(rpc::RAY_ERROR_INFO_CHANNEL); + message->set_key_id(key_id); + message->mutable_error_info_message()->MergeFrom(error_info); + return DoPublishWithRetries(request, num_retries, 1000); +} + +Status PythonGcsPublisher::PublishLogs(const std::string &key_id, + const rpc::LogBatch &log_batch) { + rpc::GcsPublishRequest request; + auto *message = request.add_pub_messages(); + message->set_channel_type(rpc::RAY_LOG_CHANNEL); + message->set_key_id(key_id); + message->mutable_log_batch_message()->MergeFrom(log_batch); + return DoPublishWithRetries(request, -1, -1); +} + +Status PythonGcsPublisher::PublishFunctionKey( + const rpc::PythonFunction &python_function) { + rpc::GcsPublishRequest request; + auto *message = request.add_pub_messages(); + message->set_channel_type(rpc::RAY_PYTHON_FUNCTION_CHANNEL); + message->mutable_python_function_message()->MergeFrom(python_function); + return DoPublishWithRetries(request, -1, -1); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/pubsub/gcs_pub_sub.h b/src/ray/gcs/pubsub/gcs_pub_sub.h index ffd79a6adfab..db621938dc98 100644 --- a/src/ray/gcs/pubsub/gcs_pub_sub.h +++ b/src/ray/gcs/pubsub/gcs_pub_sub.h @@ -25,6 +25,7 @@ #include "ray/pubsub/publisher.h" #include "ray/pubsub/subscriber.h" #include "src/ray/protobuf/gcs.pb.h" +#include "src/ray/protobuf/gcs_service.grpc.pb.h" #include "src/ray/protobuf/gcs_service.pb.h" namespace ray { @@ -132,5 +133,41 @@ class GcsSubscriber { const std::unique_ptr subscriber_; }; +// This client is only supposed to be used from Cython / Python +class RAY_EXPORT PythonGcsPublisher { + public: + explicit PythonGcsPublisher(const std::string &gcs_address); + + /// Connect to the publisher service of the GCS. + /// This function must be called before calling other functions. + /// + /// \return Status + Status Connect(); + + /// Publish error information to GCS. + Status PublishError(const std::string &key_id, + const rpc::ErrorTableData &data, + int64_t num_retries); + + /// Publish logs to GCS. + Status PublishLogs(const std::string &key_id, const rpc::LogBatch &log_batch); + + /// Publish a function key to GCS. + Status PublishFunctionKey(const rpc::PythonFunction &python_function); + + private: + Status DoPublishWithRetries(const rpc::GcsPublishRequest &request, + int64_t num_retries, + int64_t timeout_ms); + std::unique_ptr pubsub_stub_; + std::shared_ptr channel_; + std::string gcs_address_; + int gcs_port_; +}; + +/// Construct the arguments for synchronous gRPC clients +/// (the ones wrapped in Python) +grpc::ChannelArguments PythonGrpcChannelArguments(); + } // namespace gcs } // namespace ray