Skip to content

Commit

Permalink
Add Cython wrapper for GcsClient (#33769)
Browse files Browse the repository at this point in the history
This is a step toward the eventual goal of removing Python gRPC calls from Ray Core / Python workers. As a first cut, I'm removing the Python GcsClient.

This PR introduces a Cython GcsClient that wraps a simple C++ synchronous GCS client. As a result, the code for the GcsClient moves from `ray._private.gcs_utils` to `ray._raylet`. The existing Python-level reconnection logic `_auto_reconnect` is reused almost without changes.

This new Cython client can support the full use cases of the old pure Python `GcsClient` and is (almost) a drop-in replacement. To make sure this is indeed the case, this PR also switches over all the uses of the old client and removes the old code.

We also introduce a new exception type `ray.exceptions.RpcError` which is a replacement for `grpc.RpcError` and allows the Python-level code that does exception handling to keep working.
  • Loading branch information
pcmoritz authored Apr 12, 2023
1 parent b01f7bc commit 7c9da5c
Show file tree
Hide file tree
Showing 43 changed files with 681 additions and 312 deletions.
3 changes: 2 additions & 1 deletion dashboard/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import ray.dashboard.utils as dashboard_utils
from ray.dashboard.consts import _PARENT_DEATH_THREASHOLD
from ray._private.gcs_pubsub import GcsAioPublisher, GcsPublisher
from ray._private.gcs_utils import GcsAioClient, GcsClient
from ray._raylet import GcsClient
from ray._private.gcs_utils import GcsAioClient
from ray._private.ray_logging import setup_component_logger
from ray.core.generated import agent_manager_pb2, agent_manager_pb2_grpc
from ray.experimental.internal_kv import (
Expand Down
3 changes: 2 additions & 1 deletion dashboard/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag
from ray._private import ray_constants
from ray.dashboard.utils import DashboardHeadModule
from ray._private.gcs_utils import GcsClient, GcsAioClient, check_health
from ray._raylet import GcsClient
from ray._private.gcs_utils import GcsAioClient, check_health
from ray.dashboard.datacenter import DataOrganizer
from ray.dashboard.utils import async_loop_forever
from ray.dashboard.consts import DASHBOARD_METRIC_PORT
Expand Down
2 changes: 1 addition & 1 deletion dashboard/http_server_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# installation must be included in this file. This allows us to determine if
# the agent has the necessary dependencies to be started.
from ray.dashboard.optional_deps import aiohttp, hdrs
from ray._private.gcs_utils import GcsClient
from ray._raylet import GcsClient


# Logger for this module. It should be configured at the entry point
Expand Down
2 changes: 1 addition & 1 deletion dashboard/tests/test_dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@

def make_gcs_client(address_info):
address = address_info["gcs_address"]
gcs_client = ray._private.gcs_utils.GcsClient(address=address)
gcs_client = ray._raylet.GcsClient(address=address)
return gcs_client


Expand Down
2 changes: 1 addition & 1 deletion dashboard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import ray
import ray._private.ray_constants as ray_constants
import ray._private.services as services
from ray._private.gcs_utils import GcsClient
from ray._raylet import GcsClient
from ray._private.utils import split_address

import aiosignal # noqa: F401
Expand Down
225 changes: 4 additions & 221 deletions python/ray/_private/gcs_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import enum
import logging
import time
import traceback
import inspect
import os
Expand Down Expand Up @@ -192,42 +191,10 @@ async def wrapper(self, *args, **kwargs):
return wrapper
else:

@wraps(f)
def wrapper(self, *args, **kwargs):
if "TEST_RAY_COLLECT_KV_FREQUENCY" in os.environ:
global _called_freq
name = f.__name__
if name not in _called_freq:
_called_freq[name] = 0
_called_freq[name] += 1
remaining_retry = self._nums_reconnect_retry
while True:
try:
return f(self, *args, **kwargs)
except grpc.RpcError as e:
if e.code() in (
grpc.StatusCode.UNAVAILABLE,
grpc.StatusCode.UNKNOWN,
):
if remaining_retry <= 0:
logger.error(
"Failed to connect to GCS. Please check"
" `gcs_server.out` for more details."
)
raise
logger.debug(
"Failed to send request to gcs, reconnecting. " f"Error {e}"
)
try:
self._connect()
except Exception:
logger.error(f"Connecting to gcs failed. Error {e}")
time.sleep(1)
remaining_retry -= 1
continue
raise

return wrapper
raise NotImplementedError(
"This code moved to Cython, see "
"https://github.com/ray-project/ray/pull/33769"
)


class GcsChannel:
Expand Down Expand Up @@ -256,190 +223,6 @@ class GcsCode(enum.IntEnum):
GrpcUnavailable = 26


class GcsClient:
"""Client to GCS using GRPC"""

def __init__(
self,
channel: Optional[GcsChannel] = None,
address: Optional[str] = None,
nums_reconnect_retry: int = 5,
):
if channel is None:
assert isinstance(address, str)
channel = GcsChannel(gcs_address=address)
assert isinstance(channel, GcsChannel)
assert channel._aio is False
self._channel = channel
self._connect()
self._nums_reconnect_retry = nums_reconnect_retry

def _connect(self):
self._channel.connect()
self._kv_stub = gcs_service_pb2_grpc.InternalKVGcsServiceStub(
self._channel.channel()
)
self._runtime_env_stub = gcs_service_pb2_grpc.RuntimeEnvGcsServiceStub(
self._channel.channel()
)
self._node_info_stub = gcs_service_pb2_grpc.NodeInfoGcsServiceStub(
self._channel.channel()
)
self._job_info_stub = gcs_service_pb2_grpc.JobInfoGcsServiceStub(
self._channel.channel()
)

@property
def address(self):
return self._channel._gcs_address

@_auto_reconnect
def internal_kv_get(
self, key: bytes, namespace: Optional[bytes], timeout: Optional[float] = None
) -> Optional[bytes]:
logger.debug(f"internal_kv_get {key!r} {namespace!r}")
req = gcs_service_pb2.InternalKVGetRequest(namespace=namespace, key=key)
reply = self._kv_stub.InternalKVGet(req, timeout=timeout)
if reply.status.code == GcsCode.OK:
return reply.value
elif reply.status.code == GcsCode.NotFound:
return None
else:
raise RuntimeError(
f"Failed to get value for key {key!r} "
f"due to error {reply.status.message}"
)

@_auto_reconnect
def internal_kv_multi_get(
self,
keys: List[bytes],
namespace: Optional[bytes],
timeout: Optional[float] = None,
) -> Dict[bytes, bytes]:
logger.debug(f"internal_kv_multi_get {keys!r} {namespace!r}")
req = gcs_service_pb2.InternalKVMultiGetRequest(namespace=namespace, keys=keys)
reply = self._kv_stub.InternalKVMultiGet(req, timeout=timeout)
if reply.status.code == GcsCode.OK:
return {entry.key: entry.value for entry in reply.results}
elif reply.status.code == GcsCode.NotFound:
return {}
else:
raise RuntimeError(
f"Failed to get value for key {keys!r} "
f"due to error {reply.status.message}"
)

@_auto_reconnect
def internal_kv_put(
self,
key: bytes,
value: bytes,
overwrite: bool,
namespace: Optional[bytes],
timeout: Optional[float] = None,
) -> int:
logger.debug(f"internal_kv_put {key!r} {value!r} {overwrite} {namespace!r}")
req = gcs_service_pb2.InternalKVPutRequest(
namespace=namespace,
key=key,
value=value,
overwrite=overwrite,
)
reply = self._kv_stub.InternalKVPut(req, timeout=timeout)
if reply.status.code == GcsCode.OK:
return reply.added_num
else:
raise RuntimeError(
f"Failed to put value {value!r} to key {key!r} "
f"due to error {reply.status.message}"
)

@_auto_reconnect
def internal_kv_del(
self,
key: bytes,
del_by_prefix: bool,
namespace: Optional[bytes],
timeout: Optional[float] = None,
) -> int:
logger.debug(f"internal_kv_del {key!r} {del_by_prefix} {namespace!r}")
req = gcs_service_pb2.InternalKVDelRequest(
namespace=namespace, key=key, del_by_prefix=del_by_prefix
)
reply = self._kv_stub.InternalKVDel(req, timeout=timeout)
if reply.status.code == GcsCode.OK:
return reply.deleted_num
else:
raise RuntimeError(
f"Failed to delete key {key!r} " f"due to error {reply.status.message}"
)

@_auto_reconnect
def internal_kv_exists(
self, key: bytes, namespace: Optional[bytes], timeout: Optional[float] = None
) -> bool:
logger.debug(f"internal_kv_exists {key!r} {namespace!r}")
req = gcs_service_pb2.InternalKVExistsRequest(namespace=namespace, key=key)
reply = self._kv_stub.InternalKVExists(req, timeout=timeout)
if reply.status.code == GcsCode.OK:
return reply.exists
else:
raise RuntimeError(
f"Failed to check existence of key {key!r} "
f"due to error {reply.status.message}"
)

@_auto_reconnect
def internal_kv_keys(
self, prefix: bytes, namespace: Optional[bytes], timeout: Optional[float] = None
) -> List[bytes]:
logger.debug(f"internal_kv_keys {prefix!r} {namespace!r}")
req = gcs_service_pb2.InternalKVKeysRequest(namespace=namespace, prefix=prefix)
reply = self._kv_stub.InternalKVKeys(req, timeout=timeout)
if reply.status.code == GcsCode.OK:
return reply.results
else:
raise RuntimeError(
f"Failed to list prefix {prefix!r} "
f"due to error {reply.status.message}"
)

@_auto_reconnect
def pin_runtime_env_uri(self, uri: str, expiration_s: int) -> None:
"""Makes a synchronous call to the GCS to temporarily pin the URI."""
req = gcs_service_pb2.PinRuntimeEnvURIRequest(
uri=uri, expiration_s=expiration_s
)
reply = self._runtime_env_stub.PinRuntimeEnvURI(req)
if reply.status.code == GcsCode.GrpcUnavailable:
raise RuntimeError(
f"Failed to pin URI reference {uri} due to the GCS being "
f"unavailable, most likely it has crashed: {reply.status.message}."
)
elif reply.status.code != GcsCode.OK:
raise RuntimeError(
f"Failed to pin URI reference for {uri} "
f"due to unexpected error {reply.status.message}."
)

@_auto_reconnect
def get_all_node_info(
self, timeout: Optional[float] = None
) -> gcs_service_pb2.GetAllNodeInfoReply:
req = gcs_service_pb2.GetAllNodeInfoRequest()
reply = self._node_info_stub.GetAllNodeInfo(req, timeout=timeout)
return reply

@_auto_reconnect
def get_all_job_info(
self, timeout: Optional[float] = None
) -> gcs_service_pb2.GetAllJobInfoReply:
req = gcs_service_pb2.GetAllJobInfoRequest()
reply = self._job_info_stub.GetAllJobInfo(req, timeout=timeout)
return reply


class GcsAioClient:
def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion python/ray/_private/metrics_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from opencensus.tags import tag_value as tag_value_module

import ray
from ray._private.gcs_utils import GcsClient
from ray._raylet import GcsClient

from ray.core.generated.metrics_pb2 import Metric

Expand Down
2 changes: 1 addition & 1 deletion python/ray/_private/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import ray._private.services
import ray._private.utils
from ray._private import storage
from ray._private.gcs_utils import GcsClient
from ray._raylet import GcsClient
from ray._private.resource_spec import ResourceSpec
from ray._private.utils import open_log, try_to_create_directory, try_to_symlink

Expand Down
3 changes: 1 addition & 2 deletions python/ray/_private/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
# Ray modules
import ray
import ray._private.ray_constants as ray_constants
from ray._private.gcs_utils import GcsClient
from ray._raylet import GcsClientOptions
from ray._raylet import GcsClient, GcsClientOptions
from ray.core.generated.common_pb2 import Language

resource = None
Expand Down
14 changes: 7 additions & 7 deletions python/ray/_private/usage/usage_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,10 @@ def get_total_num_running_jobs_to_report(gcs_client) -> Optional[int]:
try:
result = gcs_client.get_all_job_info()
total_num_running_jobs = 0
for job in result.job_info_list:
if not job.is_dead and not job.config.ray_namespace.startswith(
"_ray_internal"
):
for job_id, job_info in result.items():
if not job_info["is_dead"] and not job_info["config"][
"ray_namespace"
].startswith("_ray_internal"):
total_num_running_jobs += 1
return total_num_running_jobs
except Exception as e:
Expand All @@ -485,8 +485,8 @@ def get_total_num_nodes_to_report(gcs_client, timeout=None) -> Optional[int]:
try:
result = gcs_client.get_all_node_info(timeout=timeout)
total_num_nodes = 0
for node in result.node_info_list:
if node.state == gcs_utils.GcsNodeInfo.GcsNodeState.ALIVE:
for node_id, node_info in result.items():
if node_info["state"] == gcs_utils.GcsNodeInfo.GcsNodeState.ALIVE:
total_num_nodes += 1
return total_num_nodes
except Exception as e:
Expand Down Expand Up @@ -728,7 +728,7 @@ def generate_report_data(
Returns:
UsageStats
"""
gcs_client = gcs_utils.GcsClient(address=gcs_address, nums_reconnect_retry=20)
gcs_client = ray._raylet.GcsClient(address=gcs_address, nums_reconnect_retry=20)

cluster_metadata = get_cluster_metadata(gcs_client)
cluster_status_to_report = get_cluster_status_to_report(gcs_client)
Expand Down
Loading

0 comments on commit 7c9da5c

Please sign in to comment.