Skip to content

Commit

Permalink
Bring back "[Core] Port GcsPublisher to Cython" (#34393) (#35179)
Browse files Browse the repository at this point in the history
I spent quite a bit of time debugging the test failure in #34393 (see also #35108)

It turns out the PR slightly made the _do_importing race condition (first time call in the import thread) more likely to happen. There is already a plan / PR to get rid of it (#30895) but it is currently waiting for having a replacement mechanism that @rkooo567 is working on.

I synced with @scv119 and for the time being, we are planning to skip the offending test on Windows and once we got rid of the import thread, we can re-activate it.
  • Loading branch information
pcmoritz authored May 11, 2023
1 parent 2ddebc0 commit 81a4a5e
Show file tree
Hide file tree
Showing 18 changed files with 271 additions and 109 deletions.
6 changes: 6 additions & 0 deletions .buildkite/pipeline.build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,9 @@
- DL=1 ./ci/env/install-dependencies.sh
- bash ./ci/ci.sh prepare_docker
- ./ci/env/env_info.sh
# This is needed or else the Ray Client tests run into a gRPC forking problem
# similar to https://github.com/grpc/grpc/issues/31885
- pip install pip install grpcio==1.50.0
- bazel test --config=ci $(./ci/run/bazel_export_options)
--test_tag_filters=client_tests,small_size_python_tests
-- python/ray/tests/...
Expand Down Expand Up @@ -418,6 +421,9 @@
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- DL=1 ./ci/env/install-dependencies.sh
- ./ci/env/env_info.sh
# This is needed or else the Ray Client tests run into a gRPC forking problem
# similar to https://github.com/grpc/grpc/issues/31885
- pip install pip install grpcio==1.50.0
- bazel test --config=ci $(./scripts/bazel_export_options)
--test_tag_filters=client_tests,small_size_python_tests
--test_env=TEST_EXTERNAL_REDIS=1
Expand Down
6 changes: 4 additions & 2 deletions dashboard/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.consts import _PARENT_DEATH_THREASHOLD
from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher
from ray._private.gcs_pubsub import GcsAioPublisher
from ray._raylet import GcsClient
from ray._private.gcs_utils import GcsAioClient
from ray._private.ray_logging import setup_component_logger
Expand Down Expand Up @@ -263,7 +263,9 @@ async def _check_parent():
ray._private.utils.publish_error_to_driver(
ray_constants.RAYLET_DIED_ERROR,
msg,
gcs_publisher=GcsPublisher(address=self.gcs_address),
gcs_publisher=ray._raylet.GcsPublisher(
address=self.gcs_address
),
)
else:
logger.info(msg)
Expand Down
3 changes: 1 addition & 2 deletions dashboard/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import ray.dashboard.consts as dashboard_consts
import ray.dashboard.head as dashboard_head
import ray.dashboard.utils as dashboard_utils
from ray._private.gcs_pubsub import GcsPublisher
from ray._private.ray_logging import setup_component_logger
from typing import Optional, Set

Expand Down Expand Up @@ -261,7 +260,7 @@ def sigterm_handler():
raise e

# Something went wrong, so push an error to all drivers.
gcs_publisher = GcsPublisher(address=args.gcs_address)
gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address)
ray._private.utils.publish_error_to_driver(
ray_constants.DASHBOARD_DIED_ERROR,
message,
Expand Down
45 changes: 0 additions & 45 deletions python/ray/_private/gcs_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
import random
import threading
from typing import Optional, Tuple, List
import time

import grpc
from grpc._channel import _InactiveRpcError
from ray._private.utils import get_or_create_event_loop

try:
Expand Down Expand Up @@ -160,49 +158,6 @@ def _pop_actors(queue, batch_size=100):
return msgs


class GcsPublisher(_PublisherBase):
"""Publisher to GCS."""

def __init__(self, address: str):
channel = gcs_utils.create_gcs_channel(address)
self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)

def publish_error(
self, key_id: bytes, error_info: ErrorTableData, num_retries=None
) -> None:
"""Publishes error info to GCS."""
msg = pubsub_pb2.PubMessage(
channel_type=pubsub_pb2.RAY_ERROR_INFO_CHANNEL,
key_id=key_id,
error_info_message=error_info,
)
req = gcs_service_pb2.GcsPublishRequest(pub_messages=[msg])
self._gcs_publish(req, num_retries, timeout=1)

def publish_logs(self, log_batch: dict) -> None:
"""Publishes logs to GCS."""
req = self._create_log_request(log_batch)
self._gcs_publish(req)

def publish_function_key(self, key: bytes) -> None:
"""Publishes function key to GCS."""
req = self._create_function_key_request(key)
self._gcs_publish(req)

def _gcs_publish(self, req, num_retries=None, timeout=None) -> None:
count = num_retries or MAX_GCS_PUBLISH_RETRIES
while count > 0:
try:
self._stub.GcsPublish(req, timeout=timeout)
return
except _InactiveRpcError:
pass
count -= 1
if count > 0:
time.sleep(1)
raise TimeoutError(f"Failed to publish after retries: {req}")


class _SyncSubscriber(_SubscriberBase):
def __init__(
self,
Expand Down
8 changes: 3 additions & 5 deletions python/ray/_private/log_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@
import traceback
from typing import Callable, List, Set

import ray._private.gcs_pubsub as gcs_pubsub
import ray._private.ray_constants as ray_constants
import ray._private.services as services
import ray._private.utils
from ray._private.gcs_pubsub import GcsPublisher
from ray._private.ray_logging import setup_component_logger

# Logger for this module. It should be configured at the entry point
Expand Down Expand Up @@ -135,7 +133,7 @@ class LogMonitor:
def __init__(
self,
logs_dir,
gcs_publisher: gcs_pubsub.GcsPublisher,
gcs_publisher,
is_proc_alive_fn: Callable[[int], bool],
max_files_open: int = ray_constants.LOG_MONITOR_MAX_OPEN_FILES,
):
Expand Down Expand Up @@ -525,14 +523,14 @@ def is_proc_alive(pid):
)

log_monitor = LogMonitor(
args.logs_dir, gcs_pubsub.GcsPublisher(address=args.gcs_address), is_proc_alive
args.logs_dir, ray._raylet.GcsPublisher(address=args.gcs_address), is_proc_alive
)

try:
log_monitor.run()
except Exception as e:
# Something went wrong, so push an error to all drivers.
gcs_publisher = GcsPublisher(address=args.gcs_address)
gcs_publisher = ray._raylet.GcsPublisher(address=args.gcs_address)
traceback_str = ray._private.utils.format_error_message(traceback.format_exc())
message = (
f"The log monitor on node {platform.node()} "
Expand Down
29 changes: 4 additions & 25 deletions python/ray/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
import ray
import ray._private.ray_constants as ray_constants
from ray._private.tls_utils import load_certs_from_env
from ray.core.generated.gcs_pb2 import ErrorTableData
from ray.core.generated.runtime_env_common_pb2 import (
RuntimeEnvInfo as ProtoRuntimeEnvInfo,
)
Expand Down Expand Up @@ -182,27 +181,6 @@ def push_error_to_driver(
worker.core_worker.push_error(job_id, error_type, message, time.time())


def construct_error_message(job_id, error_type, message, timestamp):
"""Construct an ErrorTableData object.
Args:
job_id: The ID of the job that the error should go to. If this is
nil, then the error will go to all drivers.
error_type: The type of the error.
message: The error message.
timestamp: The time of the error.
Returns:
The ErrorTableData object.
"""
data = ErrorTableData()
data.job_id = job_id.binary()
data.type = error_type
data.error_message = message
data.timestamp = timestamp
return data


def publish_error_to_driver(
error_type: str,
message: str,
Expand All @@ -228,11 +206,12 @@ def publish_error_to_driver(
if job_id is None:
job_id = ray.JobID.nil()
assert isinstance(job_id, ray.JobID)
error_data = construct_error_message(job_id, error_type, message, time.time())
try:
gcs_publisher.publish_error(job_id.hex().encode(), error_data, num_retries)
gcs_publisher.publish_error(
job_id.hex().encode(), error_type, message, job_id, num_retries
)
except Exception:
logger.exception(f"Failed to publish error {error_data}")
logger.exception(f"Failed to publish error: {message} [type {error_type}]")


def decode(byte_str: str, allow_none: bool = False, encode_type: str = "utf-8"):
Expand Down
3 changes: 1 addition & 2 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
GcsErrorSubscriber,
GcsFunctionKeySubscriber,
GcsLogSubscriber,
GcsPublisher,
)
from ray._private.inspect_util import is_cython
from ray._private.ray_logging import (
Expand Down Expand Up @@ -2074,7 +2073,7 @@ def connect(
ray._private.state.state._initialize_global_state(
ray._raylet.GcsClientOptions.from_gcs_address(node.gcs_address)
)
worker.gcs_publisher = GcsPublisher(address=worker.gcs_client.address)
worker.gcs_publisher = ray._raylet.GcsPublisher(address=worker.gcs_client.address)
# Initialize some fields.
if mode in (WORKER_MODE, RESTORE_WORKER_MODE, SPILL_WORKER_MODE):
# We should not specify the job_id if it's `WORKER_MODE`.
Expand Down
63 changes: 63 additions & 0 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,17 @@ from ray.includes.common cimport (
CObjectReference,
CRayObject,
CRayStatus,
CErrorTableData,
CGcsClientOptions,
CGcsNodeInfo,
CJobTableData,
CLogBatch,
CTaskArg,
CTaskArgByReference,
CTaskArgByValue,
CTaskType,
CPlacementStrategy,
CPythonFunction,
CSchedulingStrategy,
CPlacementGroupSchedulingStrategy,
CNodeAffinitySchedulingStrategy,
Expand Down Expand Up @@ -1742,6 +1745,66 @@ cdef class GcsClient:
}
return result

cdef class GcsPublisher:
"""Cython wrapper class of C++ `ray::gcs::PythonGcsPublisher`."""
cdef:
shared_ptr[CPythonGcsPublisher] inner

def __cinit__(self, address):
self.inner.reset(new CPythonGcsPublisher(address))
check_status(self.inner.get().Connect())

def publish_error(self, key_id: bytes, error_type: str, message: str,
job_id=None, num_retries=None):
cdef:
CErrorTableData error_info
int64_t c_num_retries = num_retries if num_retries else -1
c_string c_key_id = key_id

if job_id is None:
job_id = ray.JobID.nil()
assert isinstance(job_id, ray.JobID)
error_info.set_job_id(job_id.binary())
error_info.set_type(error_type)
error_info.set_error_message(message)
error_info.set_timestamp(time.time())

with nogil:
check_status(
self.inner.get().PublishError(c_key_id, error_info, c_num_retries))

def publish_logs(self, log_json: dict):
cdef:
CLogBatch log_batch
c_string c_job_id

job_id = log_json.get("job")
log_batch.set_ip(log_json.get("ip") if log_json.get("ip") else b"")
log_batch.set_pid(
str(log_json.get("pid")).encode() if log_json.get("pid") else b"")
log_batch.set_job_id(job_id.encode() if job_id else b"")
log_batch.set_is_error(bool(log_json.get("is_err")))
for line in log_json.get("lines", []):
log_batch.add_lines(line)
actor_name = log_json.get("actor_name")
log_batch.set_actor_name(actor_name.encode() if actor_name else b"")
task_name = log_json.get("task_name")
log_batch.set_task_name(task_name.encode() if task_name else b"")

c_job_id = job_id.encode() if job_id else b""
with nogil:
check_status(self.inner.get().PublishLogs(c_job_id, log_batch))

def publish_function_key(self, key: bytes):
cdef:
CPythonFunction python_function

python_function.set_key(key)

with nogil:
check_status(self.inner.get().PublishFunctionKey(python_function))


cdef class CoreWorker:

def __cinit__(self, worker_type, store_socket, raylet_socket,
Expand Down
3 changes: 1 addition & 2 deletions python/ray/autoscaler/_private/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import ray._private.ray_constants as ray_constants
import ray._private.utils
from ray._private.event.event_logger import get_event_logger
from ray._private.gcs_pubsub import GcsPublisher
from ray._private.ray_logging import setup_component_logger
from ray._raylet import GcsClient
from ray.autoscaler._private.autoscaler import StandardAutoscaler
Expand Down Expand Up @@ -560,7 +559,7 @@ def _handle_failure(self, error):
_internal_kv_put(
ray_constants.DEBUG_AUTOSCALING_ERROR, message, overwrite=True
)
gcs_publisher = GcsPublisher(address=self.gcs_address)
gcs_publisher = ray._raylet.GcsPublisher(address=self.gcs_address)
from ray._private.utils import publish_error_to_driver

publish_error_to_driver(
Expand Down
38 changes: 38 additions & 0 deletions python/ray/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,21 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil:
unordered_map[c_string, double] PythonGetResourcesTotal(
const CGcsNodeInfo& node_info)

cdef extern from "ray/gcs/pubsub/gcs_pub_sub.h" nogil:

cdef cppclass CPythonGcsPublisher "ray::gcs::PythonGcsPublisher":

CPythonGcsPublisher(const c_string& gcs_address)

CRayStatus Connect()

CRayStatus PublishError(
const c_string &key_id, const CErrorTableData &data, int64_t num_retries)

CRayStatus PublishLogs(const c_string &key_id, const CLogBatch &data)

CRayStatus PublishFunctionKey(const CPythonFunction& python_function)

cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
cdef cppclass CJobConfig "ray::rpc::JobConfig":
c_string ray_namespace() const
Expand All @@ -372,6 +387,29 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
c_bool is_dead() const
CJobConfig config() const

cdef cppclass CPythonFunction "ray::rpc::PythonFunction":
void set_key(const c_string &key)

cdef cppclass CErrorTableData "ray::rpc::ErrorTableData":
c_string job_id() const
c_string type() const
c_string error_message() const
double timestamp() const

void set_job_id(const c_string &job_id)
void set_type(const c_string &type)
void set_error_message(const c_string &error_message)
void set_timestamp(double timestamp)

cdef cppclass CLogBatch "ray::rpc::LogBatch":
void set_ip(const c_string &ip)
void set_pid(const c_string &pid)
void set_job_id(const c_string &job_id)
void set_is_error(c_bool is_error)
void add_lines(const c_string &line)
void set_actor_name(const c_string &actor_name)
void set_task_name(const c_string &task_name)


cdef extern from "ray/common/task/task_spec.h" nogil:
cdef cppclass CConcurrencyGroup "ray::ConcurrencyGroup":
Expand Down
1 change: 1 addition & 0 deletions python/ray/includes/common.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ from ray.includes.common cimport (
CObjectLocation,
CGcsClientOptions,
CPythonGcsClient,
CPythonGcsPublisher,
)


Expand Down
3 changes: 3 additions & 0 deletions python/ray/tests/test_basic_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ def sys_path():
subprocess.check_call(["python", "-m", "package.module2"])


# This will be fixed on Windows once the import thread is removed, see
# https://github.com/ray-project/ray/pull/30895
@pytest.mark.skipif(sys.platform == "win32", reason="Currently fails on Windows.")
def test_worker_kv_calls(monkeypatch, shutdown_only):
monkeypatch.setenv("TEST_RAY_COLLECT_KV_FREQUENCY", "1")
ray.init()
Expand Down
Loading

0 comments on commit 81a4a5e

Please sign in to comment.