Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bring back "[Core] Port GcsPublisher to Cython" (#34393) #35179

Merged
merged 4 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .buildkite/pipeline.build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@
- DL=1 ./ci/env/install-dependencies.sh
- bash ./ci/ci.sh prepare_docker
- ./ci/env/env_info.sh
# This is needed or else the Ray Client tests run into a gRPC forking problem
# similar to https://github.com/grpc/grpc/issues/31885
- pip install pip install grpcio==1.50.0
- bazel test --config=ci $(./ci/run/bazel_export_options)
--test_tag_filters=client_tests,small_size_python_tests
-- python/ray/tests/...
Expand Down Expand Up @@ -418,6 +421,9 @@
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- DL=1 ./ci/env/install-dependencies.sh
- ./ci/env/env_info.sh
# This is needed or else the Ray Client tests run into a gRPC forking problem
# similar to https://github.com/grpc/grpc/issues/31885
- pip install pip install grpcio==1.50.0
- bazel test --config=ci $(./scripts/bazel_export_options)
--test_tag_filters=client_tests,small_size_python_tests
--test_env=TEST_EXTERNAL_REDIS=1
Expand Down
6 changes: 4 additions & 2 deletions dashboard/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.consts import _PARENT_DEATH_THREASHOLD
from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher
from ray._private.gcs_pubsub import GcsAioPublisher
from ray._raylet import GcsClient
from ray._private.gcs_utils import GcsAioClient
from ray._private.ray_logging import setup_component_logger
Expand Down Expand Up @@ -263,7 +263,9 @@ async def _check_parent():
ray._private.utils.publish_error_to_driver(
ray_constants.RAYLET_DIED_ERROR,
msg,
gcs_publisher=GcsPublisher(address=self.gcs_address),
gcs_publisher=ray._raylet.GcsPublisher(
address=self.gcs_address
),
)
else:
logger.info(msg)
Expand Down
3 changes: 1 addition & 2 deletions dashboard/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.head as dashboard_head
import ray.dashboard.utils as dashboard_utils
from ray._private.gcs_pubsub import GcsPublisher
from ray._private.ray_logging import setup_component_logger
from typing import Optional, Set

Expand Down Expand Up @@ -261,7 +260,7 @@ def sigterm_handler():
raise e

# Something went wrong, so push an error to all drivers.
gcs_publisher = GcsPublisher(address=args.gcs_address)
gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address)
ray._private.utils.publish_error_to_driver(
ray_constants.DASHBOARD_DIED_ERROR,
message,
Expand Down
45 changes: 0 additions & 45 deletions python/ray/_private/gcs_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
import random
import threading
from typing import Optional, Tuple, List
import time

import grpc
from grpc._channel import _InactiveRpcError
from ray._private.utils import get_or_create_event_loop

try:
Expand Down Expand Up @@ -160,49 +158,6 @@ def _pop_actors(queue, batch_size=100):
return msgs


class GcsPublisher(_PublisherBase):
"""Publisher to GCS."""

def __init__(self, address: str):
channel = gcs_utils.create_gcs_channel(address)
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)

def publish_error(
self, key_id: bytes, error_info: ErrorTableData, num_retries=None
) -> None:
"""Publishes error info to GCS."""
msg = pubsub_pb2.PubMessage(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
key_id=key_id,
error_info_message=error_info,
)
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
self._gcs_publish(req, num_retries, timeout=1)

def publish_logs(self, log_batch: dict) -> None:
"""Publishes logs to GCS."""
req = self._create_log_request(log_batch)
self._gcs_publish(req)

def publish_function_key(self, key: bytes) -> None:
"""Publishes function key to GCS."""
req = self._create_function_key_request(key)
self._gcs_publish(req)

def _gcs_publish(self, req, num_retries=None, timeout=None) -> None:
count = num_retries or MAX_GCS_PUBLISH_RETRIES
while count > 0:
try:
self._stub.GcsPublish(req, timeout=timeout)
return
except _InactiveRpcError:
pass
count -= 1
if count > 0:
time.sleep(1)
raise TimeoutError(f"Failed to publish after retries: {req}")


class _SyncSubscriber(_SubscriberBase):
def __init__(
self,
Expand Down
8 changes: 3 additions & 5 deletions python/ray/_private/log_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
import traceback
from typing import Callable, List, Set

import ray._private.gcs_pubsub as gcs_pubsub
import ray._private.ray_constants as ray_constants
import ray._private.services as services
import ray._private.utils
from ray._private.gcs_pubsub import GcsPublisher
from ray._private.ray_logging import setup_component_logger

# Logger for this module. It should be configured at the entry point
Expand Down Expand Up @@ -135,7 +133,7 @@ class LogMonitor:
def __init__(
self,
logs_dir,
gcs_publisher: gcs_pubsub.GcsPublisher,
gcs_publisher,
is_proc_alive_fn: Callable[[int], bool],
max_files_open: int = ray_constants.LOG_MONITOR_MAX_OPEN_FILES,
):
Expand Down Expand Up @@ -525,14 +523,14 @@ def is_proc_alive(pid):
)

log_monitor = LogMonitor(
args.logs_dir, gcs_pubsub.GcsPublisher(address=args.gcs_address), is_proc_alive
args.logs_dir, ray._raylet.GcsPublisher(address=args.gcs_address), is_proc_alive
)

try:
log_monitor.run()
except Exception as e:
# Something went wrong, so push an error to all drivers.
gcs_publisher = GcsPublisher(address=args.gcs_address)
gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address)
traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
message = (
f"The log monitor on node {platform.node()} "
Expand Down
29 changes: 4 additions & 25 deletions python/ray/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import ray
import ray._private.ray_constants as ray_constants
from ray._private.tls_utils import load_certs_from_env
from ray.core.generated.gcs_pb2 import ErrorTableData
from ray.core.generated.runtime_env_common_pb2 import (
RuntimeEnvInfo as ProtoRuntimeEnvInfo,
)
Expand Down Expand Up @@ -182,27 +181,6 @@ def push_error_to_driver(
worker.core_worker.push_error(job_id, error_type, message, time.time())


def construct_error_message(job_id, error_type, message, timestamp):
"""Construct an ErrorTableData object.
Args:
job_id: The ID of the job that the error should go to. If this is
nil, then the error will go to all drivers.
error_type: The type of the error.
message: The error message.
timestamp: The time of the error.
Returns:
The ErrorTableData object.
"""
data = ErrorTableData()
data.job_id = job_id.binary()
data.type = error_type
data.error_message = message
data.timestamp = timestamp
return data


def publish_error_to_driver(
error_type: str,
message: str,
Expand All @@ -228,11 +206,12 @@ def publish_error_to_driver(
if job_id is None:
job_id = ray.JobID.nil()
assert isinstance(job_id, ray.JobID)
error_data = construct_error_message(job_id, error_type, message, time.time())
try:
gcs_publisher.publish_error(job_id.hex().encode(), error_data, num_retries)
gcs_publisher.publish_error(
job_id.hex().encode(), error_type, message, job_id, num_retries
)
except Exception:
logger.exception(f"Failed to publish error {error_data}")
logger.exception(f"Failed to publish error: {message} [type {error_type}]")


def decode(byte_str: str, allow_none: bool = False, encode_type: str = "utf-8"):
Expand Down
3 changes: 1 addition & 2 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
GcsErrorSubscriber,
GcsFunctionKeySubscriber,
GcsLogSubscriber,
GcsPublisher,
)
from ray._private.inspect_util import is_cython
from ray._private.ray_logging import (
Expand Down Expand Up @@ -2074,7 +2073,7 @@ def connect(
ray._private.state.state._initialize_global_state(
ray._raylet.GcsClientOptions.from_gcs_address(node.gcs_address)
)
worker.gcs_publisher = GcsPublisher(address=worker.gcs_client.address)
worker.gcs_publisher = ray._raylet.GcsPublisher(address=worker.gcs_client.address)
# Initialize some fields.
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
# We should not specify the job_id if it's `WORKER_MODE`.
Expand Down
63 changes: 63 additions & 0 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,17 @@ from ray.includes.common cimport (
CObjectReference,
CRayObject,
CRayStatus,
CErrorTableData,
CGcsClientOptions,
CGcsNodeInfo,
CJobTableData,
CLogBatch,
CTaskArg,
CTaskArgByReference,
CTaskArgByValue,
CTaskType,
CPlacementStrategy,
CPythonFunction,
CSchedulingStrategy,
CPlacementGroupSchedulingStrategy,
CNodeAffinitySchedulingStrategy,
Expand Down Expand Up @@ -1742,6 +1745,66 @@ cdef class GcsClient:
}
return result

cdef class GcsPublisher:
"""Cython wrapper class of C++ `ray::gcs::PythonGcsPublisher`."""
cdef:
shared_ptr[CPythonGcsPublisher] inner

def __cinit__(self, address):
self.inner.reset(new CPythonGcsPublisher(address))
check_status(self.inner.get().Connect())

def publish_error(self, key_id: bytes, error_type: str, message: str,
job_id=None, num_retries=None):
cdef:
CErrorTableData error_info
int64_t c_num_retries = num_retries if num_retries else -1
c_string c_key_id = key_id

if job_id is None:
job_id = ray.JobID.nil()
assert isinstance(job_id, ray.JobID)
error_info.set_job_id(job_id.binary())
error_info.set_type(error_type)
error_info.set_error_message(message)
error_info.set_timestamp(time.time())

with nogil:
check_status(
self.inner.get().PublishError(c_key_id, error_info, c_num_retries))

def publish_logs(self, log_json: dict):
cdef:
CLogBatch log_batch
c_string c_job_id

job_id = log_json.get("job")
log_batch.set_ip(log_json.get("ip") if log_json.get("ip") else b"")
log_batch.set_pid(
str(log_json.get("pid")).encode() if log_json.get("pid") else b"")
log_batch.set_job_id(job_id.encode() if job_id else b"")
log_batch.set_is_error(bool(log_json.get("is_err")))
for line in log_json.get("lines", []):
log_batch.add_lines(line)
actor_name = log_json.get("actor_name")
log_batch.set_actor_name(actor_name.encode() if actor_name else b"")
task_name = log_json.get("task_name")
log_batch.set_task_name(task_name.encode() if task_name else b"")

c_job_id = job_id.encode() if job_id else b""
with nogil:
check_status(self.inner.get().PublishLogs(c_job_id, log_batch))

def publish_function_key(self, key: bytes):
cdef:
CPythonFunction python_function

python_function.set_key(key)

with nogil:
check_status(self.inner.get().PublishFunctionKey(python_function))


cdef class CoreWorker:

def __cinit__(self, worker_type, store_socket, raylet_socket,
Expand Down
3 changes: 1 addition & 2 deletions python/ray/autoscaler/_private/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ray._private.ray_constants as ray_constants
import ray._private.utils
from ray._private.event.event_logger import get_event_logger
from ray._private.gcs_pubsub import GcsPublisher
from ray._private.ray_logging import setup_component_logger
from ray._raylet import GcsClient
from ray.autoscaler._private.autoscaler import StandardAutoscaler
Expand Down Expand Up @@ -560,7 +559,7 @@ def _handle_failure(self, error):
_internal_kv_put(
ray_constants.DEBUG_AUTOSCALING_ERROR, message, overwrite=True
)
gcs_publisher = GcsPublisher(address=self.gcs_address)
gcs_publisher = ray._raylet.GcsPublisher(address=self.gcs_address)
from ray._private.utils import publish_error_to_driver

publish_error_to_driver(
Expand Down
38 changes: 38 additions & 0 deletions python/ray/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,21 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil:
unordered_map[c_string, double] PythonGetResourcesTotal(
const CGcsNodeInfo& node_info)

cdef extern from "ray/gcs/pubsub/gcs_pub_sub.h" nogil:

cdef cppclass CPythonGcsPublisher "ray::gcs::PythonGcsPublisher":

CPythonGcsPublisher(const c_string& gcs_address)

CRayStatus Connect()

CRayStatus PublishError(
const c_string &key_id, const CErrorTableData &data, int64_t num_retries)

CRayStatus PublishLogs(const c_string &key_id, const CLogBatch &data)

CRayStatus PublishFunctionKey(const CPythonFunction& python_function)

cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
cdef cppclass CJobConfig "ray::rpc::JobConfig":
c_string ray_namespace() const
Expand All @@ -372,6 +387,29 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
c_bool is_dead() const
CJobConfig config() const

cdef cppclass CPythonFunction "ray::rpc::PythonFunction":
void set_key(const c_string &key)

cdef cppclass CErrorTableData "ray::rpc::ErrorTableData":
c_string job_id() const
c_string type() const
c_string error_message() const
double timestamp() const

void set_job_id(const c_string &job_id)
void set_type(const c_string &type)
void set_error_message(const c_string &error_message)
void set_timestamp(double timestamp)

cdef cppclass CLogBatch "ray::rpc::LogBatch":
void set_ip(const c_string &ip)
void set_pid(const c_string &pid)
void set_job_id(const c_string &job_id)
void set_is_error(c_bool is_error)
void add_lines(const c_string &line)
void set_actor_name(const c_string &actor_name)
void set_task_name(const c_string &task_name)


cdef extern from "ray/common/task/task_spec.h" nogil:
cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":
Expand Down
1 change: 1 addition & 0 deletions python/ray/includes/common.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from ray.includes.common cimport (
CObjectLocation,
CGcsClientOptions,
CPythonGcsClient,
CPythonGcsPublisher,
)


Expand Down
3 changes: 3 additions & 0 deletions python/ray/tests/test_basic_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ def sys_path():
subprocess.check_call(["python", "-m", "package.module2"])


# This will be fixed on Windows once the import thread is removed, see
# https://github.com/ray-project/ray/pull/30895
@pytest.mark.skipif(sys.platform == "win32", reason="Currently fails on Windows.")
def test_worker_kv_calls(monkeypatch, shutdown_only):
monkeypatch.setenv("TEST_RAY_COLLECT_KV_FREQUENCY", "1")
ray.init()
Expand Down
Loading