Skip to content

Commit

Permalink
Move datastructures in GlobalStateAccessor to Cython (#34706)
Browse files Browse the repository at this point in the history
More progress along the lines of #33769 to remove Python gRPC from Ray Core.
  • Loading branch information
pcmoritz authored Apr 26, 2023
1 parent aba2971 commit 3e04104
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 55 deletions.
8 changes: 4 additions & 4 deletions python/ray/_private/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,9 @@ def __init__(
self.gcs_address,
self._raylet_ip_address,
)
self._plasma_store_socket_name = node_info.object_store_socket_name
self._raylet_socket_name = node_info.raylet_socket_name
self._ray_params.node_manager_port = node_info.node_manager_port
self._plasma_store_socket_name = node_info["object_store_socket_name"]
self._raylet_socket_name = node_info["raylet_socket_name"]
self._ray_params.node_manager_port = node_info["node_manager_port"]
else:
# If the user specified a socket name, use it.
self._plasma_store_socket_name = self._prepare_socket_file(
Expand Down Expand Up @@ -304,7 +304,7 @@ def __init__(
self._raylet_ip_address,
)
if self._ray_params.node_manager_port == 0:
self._ray_params.node_manager_port = node_info.node_manager_port
self._ray_params.node_manager_port = node_info["node_manager_port"]

# Makes sure the Node object has valid addresses after setup.
self.validate_ip_port(self.address)
Expand Down
30 changes: 2 additions & 28 deletions python/ray/_private/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,32 +147,7 @@ def node_table(self):
"""
self._check_connected()

node_table = self.global_state_accessor.get_node_table()

results = []
for node_info_item in node_table:
item = gcs_utils.GcsNodeInfo.FromString(node_info_item)
node_info = {
"NodeID": ray._private.utils.binary_to_hex(item.node_id),
"Alive": item.state
== gcs_utils.GcsNodeInfo.GcsNodeState.Value("ALIVE"),
"NodeManagerAddress": item.node_manager_address,
"NodeManagerHostname": item.node_manager_hostname,
"NodeManagerPort": item.node_manager_port,
"ObjectManagerPort": item.object_manager_port,
"ObjectStoreSocketName": item.object_store_socket_name,
"RayletSocketName": item.raylet_socket_name,
"MetricsExportPort": item.metrics_export_port,
"NodeName": item.node_name,
}
node_info["alive"] = node_info["Alive"]
node_info["Resources"] = (
{key: value for key, value in item.resources_total.items()}
if node_info["Alive"]
else {}
)
results.append(node_info)
return results
return self.global_state_accessor.get_node_table()

def job_table(self):
"""Fetch and parse the gcs job table.
Expand Down Expand Up @@ -749,10 +724,9 @@ def get_system_config(self):
def get_node_to_connect_for_driver(self, node_ip_address):
"""Get the node to connect for a Ray driver."""
self._check_connected()
node_info_str = self.global_state_accessor.get_node_to_connect_for_driver(
return self.global_state_accessor.get_node_to_connect_for_driver(
node_ip_address
)
return gcs_utils.GcsNodeInfo.FromString(node_info_str)


state = GlobalState()
Expand Down
15 changes: 15 additions & 0 deletions python/ray/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil:
CRayStatus GetAllJobInfo(
int64_t timeout_ms, c_vector[CJobTableData]& result)

cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil:
unordered_map[c_string, double] PythonGetResourcesTotal(
const CGcsNodeInfo& node_info)

cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
cdef cppclass CJobConfig "ray::rpc::JobConfig":
c_string ray_namespace() const
Expand All @@ -351,6 +355,17 @@ cdef extern from "src/ray/protobuf/gcs.pb.h" nogil:
c_string node_id() const
c_string node_name() const
int state() const
c_string node_manager_address() const
c_string node_manager_hostname() const
int node_manager_port() const
int object_manager_port() const
c_string object_store_socket_name() const
c_string raylet_socket_name() const
int metrics_export_port() const
void ParseFromString(const c_string &serialized)

cdef enum CGcsNodeState "ray::rpc::GcsNodeInfo_GcsNodeState":
ALIVE "ray::rpc::GcsNodeInfo_GcsNodeState_ALIVE",

cdef cppclass CJobTableData "ray::rpc::JobTableData":
c_string job_id() const
Expand Down
48 changes: 42 additions & 6 deletions python/ray/includes/global_state_accessor.pxi
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ray.includes.common cimport (
CGcsClientOptions
CGcsClientOptions,
CGcsNodeState,
PythonGetResourcesTotal
)

from ray.includes.unique_ids cimport (
Expand Down Expand Up @@ -51,10 +53,38 @@ cdef class GlobalStateAccessor:
return cjob_id.ToInt()

def get_node_table(self):
cdef c_vector[c_string] result
with nogil:
result = self.inner.get().GetAllNodeInfo()
return result
cdef:
c_vector[c_string] items
c_string item
CGcsNodeInfo c_node_info
unordered_map[c_string, double] c_resources
with nogil:
items = self.inner.get().GetAllNodeInfo()
results = []
for item in items:
c_node_info.ParseFromString(item)
node_info = {
"NodeID": ray._private.utils.binary_to_hex(c_node_info.node_id()),
"Alive": c_node_info.state() == CGcsNodeState.ALIVE,
"NodeManagerAddress": c_node_info.node_manager_address().decode(),
"NodeManagerHostname": c_node_info.node_manager_hostname().decode(),
"NodeManagerPort": c_node_info.node_manager_port(),
"ObjectManagerPort": c_node_info.object_manager_port(),
"ObjectStoreSocketName":
c_node_info.object_store_socket_name().decode(),
"RayletSocketName": c_node_info.raylet_socket_name().decode(),
"MetricsExportPort": c_node_info.metrics_export_port(),
"NodeName": c_node_info.node_name().decode(),
}
node_info["alive"] = node_info["Alive"]
c_resources = PythonGetResourcesTotal(c_node_info)
node_info["Resources"] = (
{key.decode(): value for key, value in c_resources}
if node_info["Alive"]
else {}
)
results.append(node_info)
return results

def get_all_available_resources(self):
cdef c_vector[c_string] result
Expand Down Expand Up @@ -149,9 +179,15 @@ cdef class GlobalStateAccessor:
cdef CRayStatus status
cdef c_string cnode_ip_address = node_ip_address
cdef c_string cnode_to_connect
cdef CGcsNodeInfo c_node_info
with nogil:
status = self.inner.get().GetNodeToConnectForDriver(
cnode_ip_address, &cnode_to_connect)
if not status.ok():
raise RuntimeError(status.message())
return cnode_to_connect
c_node_info.ParseFromString(cnode_to_connect)
return {
"object_store_socket_name": c_node_info.object_store_socket_name().decode(),
"raylet_socket_name": c_node_info.raylet_socket_name().decode(),
"node_manager_port": c_node_info.node_manager_port(),
}
4 changes: 2 additions & 2 deletions python/ray/tests/test_component_failures_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_node_info():
cluster.head_node.node_ip_address,
)

assert get_node_info().raylet_socket_name == cluster.head_node.raylet_socket_name
assert get_node_info()["raylet_socket_name"] == cluster.head_node.raylet_socket_name

cluster.head_node.kill_raylet()
wait_for_condition(
Expand All @@ -137,7 +137,7 @@ def get_node_info():
get_node_info()

node2 = cluster.add_node()
assert get_node_info().raylet_socket_name == node2.raylet_socket_name
assert get_node_info()["raylet_socket_name"] == node2.raylet_socket_name


if __name__ == "__main__":
Expand Down
22 changes: 8 additions & 14 deletions python/ray/tests/test_global_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,11 @@ def test_node_name_cluster(ray_start_cluster):
global_state_accessor = make_global_state_accessor(head_context)
node_table = global_state_accessor.get_node_table()
assert len(node_table) == 2
for node_data in node_table:
node = gcs_utils.GcsNodeInfo.FromString(node_data)
if (
ray._private.utils.binary_to_hex(node.node_id)
== head_context.address_info["node_id"]
):
assert node.node_name == "head_node"
for node in node_table:
if node["NodeID"] == head_context.address_info["node_id"]:
assert node["NodeName"] == "head_node"
else:
assert node.node_name == "worker_node"
assert node["NodeName"] == "worker_node"

global_state_accessor.disconnect()
ray.shutdown()
Expand All @@ -188,19 +184,17 @@ def test_node_name_init():
new_head_context = ray.init(_node_name="new_head_node", include_dashboard=False)

global_state_accessor = make_global_state_accessor(new_head_context)
node_data = global_state_accessor.get_node_table()[0]
node = gcs_utils.GcsNodeInfo.FromString(node_data)
assert node.node_name == "new_head_node"
node = global_state_accessor.get_node_table()[0]
assert node["NodeName"] == "new_head_node"
ray.shutdown()


def test_no_node_name():
# Test that starting ray with no node name will result in a node_name=ip_address
new_head_context = ray.init(include_dashboard=False)
global_state_accessor = make_global_state_accessor(new_head_context)
node_data = global_state_accessor.get_node_table()[0]
node = gcs_utils.GcsNodeInfo.FromString(node_data)
assert node.node_name == ray.util.get_node_ip_address()
node = global_state_accessor.get_node_table()[0]
assert node["NodeName"] == ray.util.get_node_ip_address()
ray.shutdown()


Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/test_ray_init_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def test_ray_init_from_workers(ray_start_cluster):
node_info = ray._private.services.get_node_to_connect_for_driver(
cluster.gcs_address, "127.0.0.3"
)
assert node_info.node_manager_port == node2.node_manager_port
assert node_info["node_manager_port"] == node2.node_manager_port


def test_default_resource_not_allowed_error(shutdown_only):
Expand Down
6 changes: 6 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,5 +398,11 @@ Status PythonGcsClient::GetAllJobInfo(int64_t timeout_ms,
return Status::RpcError(status.error_message(), status.error_code());
}

/// Copy the `resources_total` entries of a GcsNodeInfo message into a
/// std::unordered_map.
///
/// \param node_info The node info message whose `resources_total` map is read.
/// \return A new map holding every (name, amount) entry of `resources_total`.
std::unordered_map<std::string, double> PythonGetResourcesTotal(
    const rpc::GcsNodeInfo &node_info) {
  // Bind the protobuf map once so both the iteration and the copy refer to
  // the same underlying container.
  const auto &totals = node_info.resources_total();
  std::unordered_map<std::string, double> copied;
  for (const auto &entry : totals) {
    copied.emplace(entry.first, entry.second);
  }
  return copied;
}

} // namespace gcs
} // namespace ray
3 changes: 3 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ class RAY_EXPORT PythonGcsClient {
std::shared_ptr<grpc::Channel> channel_;
};

/// Return the `resources_total` entries of a GcsNodeInfo message as a
/// std::unordered_map keyed by resource name.
///
/// \param node_info The node info message to read the resources from.
/// \return A map from resource name to its total amount.
std::unordered_map<std::string, double> PythonGetResourcesTotal(
const rpc::GcsNodeInfo &node_info);

} // namespace gcs

} // namespace ray

0 comments on commit 3e04104

Please sign in to comment.