Skip to content

Commit

Permalink
[GCS] Optimize GetAllJobInfo API for performance (ray-project#47530)
Browse files Browse the repository at this point in the history
Signed-off-by: liuxsh9 <[email protected]>
Signed-off-by: ujjawal-khare <[email protected]>
  • Loading branch information
liuxsh9 authored and ujjawal-khare committed Oct 15, 2024
1 parent d430d6f commit 98ce570
Show file tree
Hide file tree
Showing 19 changed files with 176 additions and 87 deletions.
14 changes: 12 additions & 2 deletions python/ray/_private/gcs_aio_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import logging
from functools import partial
from typing import Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor
from ray._raylet import GcsClient, NewGcsClient, JobID
Expand Down Expand Up @@ -86,7 +87,8 @@ def __init__(self, inner, loop, executor):

def _function_to_async(self, func):
async def wrapper(*args, **kwargs):
return await self.loop.run_in_executor(self.executor, func, *args, **kwargs)
partial_func = partial(func, *args, **kwargs)
return await self.loop.run_in_executor(self.executor, partial_func)

return wrapper

Expand Down Expand Up @@ -204,10 +206,18 @@ async def internal_kv_keys(

async def get_all_job_info(
self,
*,
job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None,
) -> Dict[JobID, gcs_pb2.JobTableData]:
"""
Return dict key: bytes of job_id; value: JobTableData pb message.
"""
return await self._async_proxy.get_all_job_info(job_or_submission_id, timeout)
return await self._async_proxy.get_all_job_info(
job_or_submission_id=job_or_submission_id,
skip_submission_job_info_field=skip_submission_job_info_field,
skip_is_running_tasks_field=skip_is_running_tasks_field,
timeout=timeout,
)
4 changes: 3 additions & 1 deletion python/ray/_private/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ def job_table(self):
"""
self._check_connected()

job_table = self.global_state_accessor.get_job_table()
job_table = self.global_state_accessor.get_job_table(
skip_submission_job_info_field=True, skip_is_running_tasks_field=True
)

results = []
for i in range(len(job_table)):
Expand Down
4 changes: 3 additions & 1 deletion python/ray/_private/usage/usage_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ def put_cluster_metadata(gcs_client, *, ray_init_cluster) -> None:
def get_total_num_running_jobs_to_report(gcs_client) -> Optional[int]:
"""Return the total number of running jobs in the cluster excluding internal ones"""
try:
result = gcs_client.get_all_job_info()
result = gcs_client.get_all_job_info(
skip_submission_job_info_field=True, skip_is_running_tasks_field=True
)
total_num_running_jobs = 0
for job_info in result.values():
if not job_info.is_dead and not job_info.config.ray_namespace.startswith(
Expand Down
11 changes: 8 additions & 3 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2907,8 +2907,10 @@ cdef class OldGcsClient:
return result

@_auto_reconnect
def get_all_job_info(self, job_or_submission_id: str = None,
timeout=None) -> Dict[JobID, JobTableData]:
def get_all_job_info(
self, *, job_or_submission_id: str = None, skip_submission_job_info_field=False,
skip_is_running_tasks_field=False, timeout=None
) -> Dict[JobID, JobTableData]:
# Ideally we should use json_format.MessageToDict(job_info),
# but `job_info` is a cpp pb message not a python one.
# Manually converting each and every protobuf field is out of the question,
Expand All @@ -2917,6 +2919,8 @@ cdef class OldGcsClient:
c_string c_job_or_submission_id
optional[c_string] c_optional_job_or_submission_id = nullopt
int64_t timeout_ms = round(1000 * timeout) if timeout else -1
c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field
CJobTableData c_job_info
c_vector[CJobTableData] c_job_infos
c_vector[c_string] serialized_job_infos
Expand All @@ -2926,7 +2930,8 @@ cdef class OldGcsClient:
make_optional[c_string](c_job_or_submission_id)
with nogil:
check_status(self.inner.get().GetAllJobInfo(
c_optional_job_or_submission_id, timeout_ms, c_job_infos))
c_optional_job_or_submission_id, c_skip_submission_job_info_field,
c_skip_is_running_tasks_field, timeout_ms, c_job_infos))
for c_job_info in c_job_infos:
serialized_job_infos.push_back(c_job_info.SerializeAsString())
result = {}
Expand Down
5 changes: 4 additions & 1 deletion python/ray/dashboard/modules/job/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ async def get_driver_jobs(
jobs with the job id or submission id.
"""
job_infos = await gcs_aio_client.get_all_job_info(
job_or_submission_id=job_or_submission_id, timeout=timeout
job_or_submission_id=job_or_submission_id,
skip_submission_job_info_field=True,
skip_is_running_tasks_field=True,
timeout=timeout,
)
# Sort jobs from GCS to follow convention of returning only last driver
# of submission job.
Expand Down
6 changes: 5 additions & 1 deletion python/ray/dashboard/modules/snapshot/snapshot_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,11 @@ async def _get_job_activity_info(self, timeout: int) -> RayActivityResponse:
# This includes the _ray_internal_dashboard job that gets automatically
# created with every cluster
try:
reply = await self._gcs_aio_client.get_all_job_info(timeout=timeout)
reply = await self._gcs_aio_client.get_all_job_info(
skip_submission_job_info_field=True,
skip_is_running_tasks_field=True,
timeout=timeout,
)

num_active_drivers = 0
latest_job_end_time = 0
Expand Down
9 changes: 7 additions & 2 deletions python/ray/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,15 @@ cdef extern from "ray/gcs/gcs_client/accessor.h" nogil:
cdef cppclass CJobInfoAccessor "ray::gcs::JobInfoAccessor":
CRayStatus GetAll(
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field,
c_bool skip_is_running_tasks_field,
c_vector[CJobTableData] &result,
int64_t timeout_ms)

CRayStatus AsyncGetAll(
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field,
c_bool skip_is_running_tasks_field,
const MultiItemPyCallback[CJobTableData] &callback,
int64_t timeout_ms)

Expand Down Expand Up @@ -625,8 +629,9 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil:
CRayStatus GetAllNodeInfo(
int64_t timeout_ms, c_vector[CGcsNodeInfo]& result)
CRayStatus GetAllJobInfo(
const optional[c_string] &job_or_submission_id, int64_t timeout_ms,
c_vector[CJobTableData]& result)
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field, c_bool skip_is_running_tasks_field,
int64_t timeout_ms, c_vector[CJobTableData]& result)
CRayStatus GetAllResourceUsage(
int64_t timeout_ms, c_string& serialized_reply)
CRayStatus RequestClusterResourceConstraint(
Expand Down
17 changes: 14 additions & 3 deletions python/ray/includes/gcs_client.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,16 @@ cdef class NewGcsClient:
#############################################################

def get_all_job_info(
self, job_or_submission_id: Optional[str] = None,
self, *, job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None
) -> Dict[JobID, gcs_pb2.JobTableData]:
cdef c_string c_job_or_submission_id
cdef optional[c_string] c_optional_job_or_submission_id = nullopt
cdef int64_t timeout_ms = round(1000 * timeout) if timeout else -1
cdef c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
cdef c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field
cdef CRayStatus status
cdef c_vector[CJobTableData] reply
if job_or_submission_id:
Expand All @@ -447,17 +451,22 @@ cdef class NewGcsClient:
make_optional[c_string](c_job_or_submission_id)
with nogil:
status = self.inner.get().Jobs().GetAll(
c_optional_job_or_submission_id, reply, timeout_ms)
c_optional_job_or_submission_id, c_skip_submission_job_info_field,
c_skip_is_running_tasks_field, reply, timeout_ms)
return raise_or_return((convert_get_all_job_info(status, move(reply))))

def async_get_all_job_info(
self, job_or_submission_id: Optional[str] = None,
self, *, job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None
) -> Future[Dict[JobID, gcs_pb2.JobTableData]]:
cdef:
c_string c_job_or_submission_id
optional[c_string] c_optional_job_or_submission_id = nullopt
int64_t timeout_ms = round(1000 * timeout) if timeout else -1
c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field
fut = incremented_fut()
if job_or_submission_id:
c_job_or_submission_id = job_or_submission_id
Expand All @@ -467,6 +476,8 @@ cdef class NewGcsClient:
check_status_timeout_as_rpc_error(
self.inner.get().Jobs().AsyncGetAll(
c_optional_job_or_submission_id,
c_skip_submission_job_info_field,
c_skip_is_running_tasks_field,
MultiItemPyCallback[CJobTableData](
&convert_get_all_job_info,
assign_and_decrement_fut,
Expand Down
3 changes: 2 additions & 1 deletion python/ray/includes/global_state_accessor.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
CGlobalStateAccessor(const CGcsClientOptions&)
c_bool Connect()
void Disconnect()
c_vector[c_string] GetAllJobInfo()
c_vector[c_string] GetAllJobInfo(
c_bool skip_submission_job_info_field, c_bool skip_is_running_tasks_field)
CJobID GetNextJobID()
c_vector[c_string] GetAllNodeInfo()
c_vector[c_string] GetAllAvailableResources()
Expand Down
10 changes: 8 additions & 2 deletions python/ray/includes/global_state_accessor.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,16 @@ cdef class GlobalStateAccessor:
with nogil:
self.inner.get().Disconnect()

def get_job_table(self):
def get_job_table(
self, *, skip_submission_job_info_field=False, skip_is_running_tasks_field=False
):
cdef c_vector[c_string] result
cdef c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
cdef c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field

with nogil:
result = self.inner.get().GetAllJobInfo()
result = self.inner.get().GetAllJobInfo(
c_skip_submission_job_info_field, c_skip_is_running_tasks_field)
return result

def get_next_job_id(self):
Expand Down
2 changes: 2 additions & 0 deletions src/mock/ray/gcs/gcs_client/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class MockJobInfoAccessor : public JobInfoAccessor {
MOCK_METHOD(Status,
AsyncGetAll,
(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms),
(override));
Expand Down
15 changes: 13 additions & 2 deletions src/ray/gcs/gcs_client/accessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,11 @@ Status JobInfoAccessor::AsyncSubscribeAll(
done(status);
}
};
RAY_CHECK_OK(
AsyncGetAll(/*job_or_submission_id=*/std::nullopt, callback, /*timeout_ms=*/-1));
RAY_CHECK_OK(AsyncGetAll(/*job_or_submission_id=*/std::nullopt,
/*skip_submission_job_info_field=*/true,
/*skip_is_running_tasks_field=*/true,
callback,
/*timeout_ms=*/-1));
};
subscribe_operation_ = [this, subscribe](const StatusCallback &done) {
return client_impl_->GetGcsSubscriber().SubscribeAllJobs(subscribe, done);
Expand All @@ -108,11 +111,15 @@ void JobInfoAccessor::AsyncResubscribe() {

Status JobInfoAccessor::AsyncGetAll(
const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms) {
RAY_LOG(DEBUG) << "Getting all job info.";
RAY_CHECK(callback);
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
Expand All @@ -127,9 +134,13 @@ Status JobInfoAccessor::AsyncGetAll(
}

Status JobInfoAccessor::GetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
std::vector<rpc::JobTableData> &job_data_list,
int64_t timeout_ms) {
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
Expand Down
4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_client/accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ class JobInfoAccessor {
/// \param callback Callback that will be called after lookup finished.
/// \return Status
virtual Status AsyncGetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms);

Expand All @@ -272,6 +274,8 @@ class JobInfoAccessor {
/// \param timeout_ms -1 means infinite.
/// \return Status
virtual Status GetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
std::vector<rpc::JobTableData> &job_data_list,
int64_t timeout_ms);

Expand Down
4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,13 +503,17 @@ Status PythonGcsClient::GetAllNodeInfo(int64_t timeout_ms,

Status PythonGcsClient::GetAllJobInfo(
const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
int64_t timeout_ms,
std::vector<rpc::JobTableData> &result) {
grpc::ClientContext context;
PrepareContext(context, timeout_ms);

absl::ReaderMutexLock lock(&mutex_);
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
Expand Down
2 changes: 2 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ class RAY_EXPORT PythonGcsClient {
Status PinRuntimeEnvUri(const std::string &uri, int expiration_s, int64_t timeout_ms);
Status GetAllNodeInfo(int64_t timeout_ms, std::vector<rpc::GcsNodeInfo> &result);
Status GetAllJobInfo(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
int64_t timeout_ms,
std::vector<rpc::JobTableData> &result);
Status GetAllResourceUsage(int64_t timeout_ms, std::string &serialized_reply);
Expand Down
5 changes: 4 additions & 1 deletion src/ray/gcs/gcs_client/global_state_accessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ void GlobalStateAccessor::Disconnect() {
}
}

std::vector<std::string> GlobalStateAccessor::GetAllJobInfo() {
std::vector<std::string> GlobalStateAccessor::GetAllJobInfo(
bool skip_submission_job_info_field, bool skip_is_running_tasks_field) {
// This method assumes GCS is HA and does not return any error. If GCS is down, it
// retries indefinitely.
std::vector<std::string> job_table_data;
Expand All @@ -68,6 +69,8 @@ std::vector<std::string> GlobalStateAccessor::GetAllJobInfo() {
absl::ReaderMutexLock lock(&mutex_);
RAY_CHECK_OK(gcs_client_->Jobs().AsyncGetAll(
/*job_or_submission_id=*/std::nullopt,
skip_submission_job_info_field,
skip_is_running_tasks_field,
TransformForMultiItemCallback<rpc::JobTableData>(job_table_data, promise),
/*timeout_ms=*/-1));
}
Expand Down
4 changes: 3 additions & 1 deletion src/ray/gcs/gcs_client/global_state_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ class GlobalStateAccessor {
/// \return All job info. To support multiple languages, each JobTableData is
/// serialized and returned as a string. Callers must deserialize it with the
/// corresponding protobuf parsing function.
std::vector<std::string> GetAllJobInfo() ABSL_LOCKS_EXCLUDED(mutex_);
std::vector<std::string> GetAllJobInfo(bool skip_submission_job_info_field = false,
bool skip_is_running_tasks_field = false)
ABSL_LOCKS_EXCLUDED(mutex_);

/// Get next job id from GCS Service.
///
Expand Down
Loading

0 comments on commit 98ce570

Please sign in to comment.