[GCS] Optimize GetAllJobInfo API for performance #47530

Merged · 12 commits · Sep 11, 2024
14 changes: 12 additions & 2 deletions python/ray/_private/gcs_aio_client.py
@@ -1,5 +1,6 @@
import os
import logging
from functools import partial
from typing import Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor
from ray._raylet import GcsClient, NewGcsClient, JobID
@@ -86,7 +87,8 @@ def __init__(self, inner, loop, executor):

def _function_to_async(self, func):
async def wrapper(*args, **kwargs):
return await self.loop.run_in_executor(self.executor, func, *args, **kwargs)
partial_func = partial(func, *args, **kwargs)
Contributor:

Why this change?

Contributor:

+1

Let's avoid bundling unrelated changes.

Contributor Author (@liuxsh9, Sep 7, 2024):

We discovered that run_in_executor doesn't support **kwargs (ref: asyncio.loop.run_in_executor), and the implicit parameter passing was exposing that limitation. For now we've switched to explicit parameter passing, so this change has been reverted. Should we still adjust it to accommodate run_in_executor's limitation?
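For context, here is a minimal standalone sketch of the functools.partial workaround discussed above; the loop, executor, and blocking function are illustrative, not Ray code.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def function_to_async(loop, executor, func):
    # asyncio's run_in_executor only forwards positional arguments, so any
    # keyword arguments are bound into the callable with functools.partial.
    async def wrapper(*args, **kwargs):
        return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
    return wrapper

def blocking_lookup(key, timeout=None):
    # Stand-in for a blocking GCS call.
    return f"value for {key} (timeout={timeout})"

async def main():
    loop = asyncio.get_running_loop()
    async_lookup = function_to_async(loop, ThreadPoolExecutor(max_workers=1), blocking_lookup)
    # Passing timeout= as a keyword would raise TypeError without the partial().
    print(await async_lookup("job_1", timeout=2.0))

asyncio.run(main())

Binding kwargs with partial, or switching callers to explicit positional/keyword-free passing as this PR ultimately did, both avoid the TypeError that run_in_executor raises when handed keyword arguments.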

return await self.loop.run_in_executor(self.executor, partial_func)

return wrapper

@@ -204,10 +206,18 @@ async def internal_kv_keys(

async def get_all_job_info(
self,
*,
job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None,
) -> Dict[JobID, gcs_pb2.JobTableData]:
"""
Return dict key: bytes of job_id; value: JobTableData pb message.
"""
return await self._async_proxy.get_all_job_info(job_or_submission_id, timeout)
return await self._async_proxy.get_all_job_info(
job_or_submission_id=job_or_submission_id,
skip_submission_job_info_field=skip_submission_job_info_field,
skip_is_running_tasks_field=skip_is_running_tasks_field,
timeout=timeout,
)
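
A minimal usage sketch of the new keyword-only flags; the helper function below is illustrative and assumes an already-connected GcsAioClient.

async def count_live_jobs(gcs_aio_client):
    # Both skip flags tell the GCS to omit the expensive fields
    # (submission job info and the per-driver is_running_tasks check).
    job_infos = await gcs_aio_client.get_all_job_info(
        skip_submission_job_info_field=True,
        skip_is_running_tasks_field=True,
        timeout=5.0,
    )
    # Values are gcs_pb2.JobTableData messages keyed by JobID.
    return sum(1 for job in job_infos.values() if not job.is_dead)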
4 changes: 3 additions & 1 deletion python/ray/_private/state.py
@@ -181,7 +181,9 @@ def job_table(self):
"""
self._check_connected()

job_table = self.global_state_accessor.get_job_table()
job_table = self.global_state_accessor.get_job_table(
skip_submission_job_info_field=True, skip_is_running_tasks_field=True
)

results = []
for i in range(len(job_table)):
4 changes: 3 additions & 1 deletion python/ray/_private/usage/usage_lib.py
@@ -537,7 +537,9 @@ def put_cluster_metadata(gcs_client, *, ray_init_cluster) -> None:
def get_total_num_running_jobs_to_report(gcs_client) -> Optional[int]:
"""Return the total number of running jobs in the cluster excluding internal ones"""
try:
result = gcs_client.get_all_job_info()
result = gcs_client.get_all_job_info(
skip_submission_job_info_field=True, skip_is_running_tasks_field=True
)
total_num_running_jobs = 0
for job_info in result.values():
if not job_info.is_dead and not job_info.config.ray_namespace.startswith(
11 changes: 8 additions & 3 deletions python/ray/_raylet.pyx
@@ -2907,8 +2907,10 @@ cdef class OldGcsClient:
return result

@_auto_reconnect
def get_all_job_info(self, job_or_submission_id: str = None,
timeout=None) -> Dict[JobID, JobTableData]:
def get_all_job_info(
self, *, job_or_submission_id: str = None, skip_submission_job_info_field=False,
skip_is_running_tasks_field=False, timeout=None
) -> Dict[JobID, JobTableData]:
# Ideally we should use json_format.MessageToDict(job_info),
# but `job_info` is a cpp pb message not a python one.
# Manually converting each and every protobuf field is out of question,
@@ -2917,6 +2919,8 @@ cdef class OldGcsClient:
c_string c_job_or_submission_id
optional[c_string] c_optional_job_or_submission_id = nullopt
int64_t timeout_ms = round(1000 * timeout) if timeout else -1
c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field
CJobTableData c_job_info
c_vector[CJobTableData] c_job_infos
c_vector[c_string] serialized_job_infos
@@ -2926,7 +2930,8 @@
make_optional[c_string](c_job_or_submission_id)
with nogil:
check_status(self.inner.get().GetAllJobInfo(
c_optional_job_or_submission_id, timeout_ms, c_job_infos))
c_optional_job_or_submission_id, c_skip_submission_job_info_field,
c_skip_is_running_tasks_field, timeout_ms, c_job_infos))
for c_job_info in c_job_infos:
serialized_job_infos.push_back(c_job_info.SerializeAsString())
result = {}
5 changes: 4 additions & 1 deletion python/ray/dashboard/modules/job/utils.py
@@ -163,7 +163,10 @@ async def get_driver_jobs(
jobs with the job id or submission id.
"""
job_infos = await gcs_aio_client.get_all_job_info(
job_or_submission_id=job_or_submission_id, timeout=timeout
job_or_submission_id=job_or_submission_id,
skip_submission_job_info_field=True,
skip_is_running_tasks_field=True,
timeout=timeout,
)
# Sort jobs from GCS to follow convention of returning only last driver
# of submission job.
6 changes: 5 additions & 1 deletion python/ray/dashboard/modules/snapshot/snapshot_head.py
@@ -230,7 +230,11 @@ async def _get_job_activity_info(self, timeout: int) -> RayActivityResponse:
# This includes the _ray_internal_dashboard job that gets automatically
# created with every cluster
try:
reply = await self._gcs_aio_client.get_all_job_info(timeout=timeout)
reply = await self._gcs_aio_client.get_all_job_info(
skip_submission_job_info_field=True,
skip_is_running_tasks_field=True,
timeout=timeout,
)

num_active_drivers = 0
latest_job_end_time = 0
9 changes: 7 additions & 2 deletions python/ray/includes/common.pxd
@@ -407,11 +407,15 @@ cdef extern from "ray/gcs/gcs_client/accessor.h" nogil:
cdef cppclass CJobInfoAccessor "ray::gcs::JobInfoAccessor":
CRayStatus GetAll(
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field,
c_bool skip_is_running_tasks_field,
c_vector[CJobTableData] &result,
int64_t timeout_ms)

CRayStatus AsyncGetAll(
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field,
c_bool skip_is_running_tasks_field,
const MultiItemPyCallback[CJobTableData] &callback,
int64_t timeout_ms)

@@ -625,8 +629,9 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil:
CRayStatus GetAllNodeInfo(
int64_t timeout_ms, c_vector[CGcsNodeInfo]& result)
CRayStatus GetAllJobInfo(
const optional[c_string] &job_or_submission_id, int64_t timeout_ms,
c_vector[CJobTableData]& result)
const optional[c_string] &job_or_submission_id,
c_bool skip_submission_job_info_field, c_bool skip_is_running_tasks_field,
int64_t timeout_ms, c_vector[CJobTableData]& result)
CRayStatus GetAllResourceUsage(
int64_t timeout_ms, c_string& serialized_reply)
CRayStatus RequestClusterResourceConstraint(
17 changes: 14 additions & 3 deletions python/ray/includes/gcs_client.pxi
@@ -433,12 +433,16 @@ cdef class NewGcsClient:
#############################################################

def get_all_job_info(
self, job_or_submission_id: Optional[str] = None,
self, *, job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None
) -> Dict[JobID, gcs_pb2.JobTableData]:
cdef c_string c_job_or_submission_id
cdef optional[c_string] c_optional_job_or_submission_id = nullopt
cdef int64_t timeout_ms = round(1000 * timeout) if timeout else -1
cdef c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
cdef c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field
cdef CRayStatus status
cdef c_vector[CJobTableData] reply
if job_or_submission_id:
@@ -447,17 +451,22 @@ cdef class NewGcsClient:
make_optional[c_string](c_job_or_submission_id)
with nogil:
status = self.inner.get().Jobs().GetAll(
c_optional_job_or_submission_id, reply, timeout_ms)
c_optional_job_or_submission_id, c_skip_submission_job_info_field,
c_skip_is_running_tasks_field, reply, timeout_ms)
return raise_or_return((convert_get_all_job_info(status, move(reply))))

def async_get_all_job_info(
self, job_or_submission_id: Optional[str] = None,
self, *, job_or_submission_id: Optional[str] = None,
skip_submission_job_info_field: bool = False,
skip_is_running_tasks_field: bool = False,
timeout: Optional[float] = None
) -> Future[Dict[JobID, gcs_pb2.JobTableData]]:
cdef:
c_string c_job_or_submission_id
optional[c_string] c_optional_job_or_submission_id = nullopt
int64_t timeout_ms = round(1000 * timeout) if timeout else -1
c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field
fut = incremented_fut()
if job_or_submission_id:
c_job_or_submission_id = job_or_submission_id
@@ -467,6 +476,8 @@ cdef class NewGcsClient:
check_status_timeout_as_rpc_error(
self.inner.get().Jobs().AsyncGetAll(
c_optional_job_or_submission_id,
c_skip_submission_job_info_field,
c_skip_is_running_tasks_field,
MultiItemPyCallback[CJobTableData](
&convert_get_all_job_info,
assign_and_decrement_fut,
3 changes: 2 additions & 1 deletion python/ray/includes/global_state_accessor.pxd
@@ -29,7 +29,8 @@ cdef extern from "ray/gcs/gcs_client/global_state_accessor.h" nogil:
CGlobalStateAccessor(const CGcsClientOptions&)
c_bool Connect()
void Disconnect()
c_vector[c_string] GetAllJobInfo()
c_vector[c_string] GetAllJobInfo(
c_bool skip_submission_job_info_field, c_bool skip_is_running_tasks_field)
CJobID GetNextJobID()
c_vector[c_string] GetAllNodeInfo()
c_vector[c_string] GetAllAvailableResources()
10 changes: 8 additions & 2 deletions python/ray/includes/global_state_accessor.pxi
@@ -48,10 +48,16 @@ cdef class GlobalStateAccessor:
with nogil:
self.inner.get().Disconnect()

def get_job_table(self):
def get_job_table(
self, *, skip_submission_job_info_field=False, skip_is_running_tasks_field=False
):
cdef c_vector[c_string] result
cdef c_bool c_skip_submission_job_info_field = skip_submission_job_info_field
cdef c_bool c_skip_is_running_tasks_field = skip_is_running_tasks_field

with nogil:
result = self.inner.get().GetAllJobInfo()
result = self.inner.get().GetAllJobInfo(
c_skip_submission_job_info_field, c_skip_is_running_tasks_field)
return result

def get_next_job_id(self):
2 changes: 2 additions & 0 deletions src/mock/ray/gcs/gcs_client/accessor.h
@@ -106,6 +106,8 @@ class MockJobInfoAccessor : public JobInfoAccessor {
MOCK_METHOD(Status,
AsyncGetAll,
(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms),
(override));
15 changes: 13 additions & 2 deletions src/ray/gcs/gcs_client/accessor.cc
@@ -82,8 +82,11 @@ Status JobInfoAccessor::AsyncSubscribeAll(
done(status);
}
};
RAY_CHECK_OK(
AsyncGetAll(/*job_or_submission_id=*/std::nullopt, callback, /*timeout_ms=*/-1));
RAY_CHECK_OK(AsyncGetAll(/*job_or_submission_id=*/std::nullopt,
/*skip_submission_job_info_field=*/true,
/*skip_is_running_tasks_field=*/true,
callback,
/*timeout_ms=*/-1));
};
subscribe_operation_ = [this, subscribe](const StatusCallback &done) {
return client_impl_->GetGcsSubscriber().SubscribeAllJobs(subscribe, done);
@@ -108,11 +111,15 @@ void JobInfoAccessor::AsyncResubscribe() {

Status JobInfoAccessor::AsyncGetAll(
const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms) {
RAY_LOG(DEBUG) << "Getting all job info.";
RAY_CHECK(callback);
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
@@ -127,9 +134,13 @@ Status JobInfoAccessor::AsyncGetAll(
}

Status JobInfoAccessor::GetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
std::vector<rpc::JobTableData> &job_data_list,
int64_t timeout_ms) {
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_client/accessor.h
@@ -262,6 +262,8 @@ class JobInfoAccessor {
/// \param callback Callback that will be called after lookup finished.
/// \return Status
virtual Status AsyncGetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
const MultiItemCallback<rpc::JobTableData> &callback,
int64_t timeout_ms);

@@ -272,6 +274,8 @@
/// \param timeout_ms -1 means infinite.
/// \return Status
virtual Status GetAll(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
std::vector<rpc::JobTableData> &job_data_list,
int64_t timeout_ms);

4 changes: 4 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.cc
@@ -503,13 +503,17 @@ Status PythonGcsClient::GetAllNodeInfo(int64_t timeout_ms,

Status PythonGcsClient::GetAllJobInfo(
const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
int64_t timeout_ms,
std::vector<rpc::JobTableData> &result) {
grpc::ClientContext context;
PrepareContext(context, timeout_ms);

absl::ReaderMutexLock lock(&mutex_);
rpc::GetAllJobInfoRequest request;
request.set_skip_submission_job_info_field(skip_submission_job_info_field);
request.set_skip_is_running_tasks_field(skip_is_running_tasks_field);
if (job_or_submission_id.has_value()) {
request.set_job_or_submission_id(job_or_submission_id.value());
}
2 changes: 2 additions & 0 deletions src/ray/gcs/gcs_client/gcs_client.h
@@ -298,6 +298,8 @@ class RAY_EXPORT PythonGcsClient {
Status PinRuntimeEnvUri(const std::string &uri, int expiration_s, int64_t timeout_ms);
Status GetAllNodeInfo(int64_t timeout_ms, std::vector<rpc::GcsNodeInfo> &result);
Status GetAllJobInfo(const std::optional<std::string> &job_or_submission_id,
bool skip_submission_job_info_field,
bool skip_is_running_tasks_field,
int64_t timeout_ms,
std::vector<rpc::JobTableData> &result);
Status GetAllResourceUsage(int64_t timeout_ms, std::string &serialized_reply);
5 changes: 4 additions & 1 deletion src/ray/gcs/gcs_client/global_state_accessor.cc
@@ -59,7 +59,8 @@ void GlobalStateAccessor::Disconnect() {
}
}

std::vector<std::string> GlobalStateAccessor::GetAllJobInfo() {
std::vector<std::string> GlobalStateAccessor::GetAllJobInfo(
bool skip_submission_job_info_field, bool skip_is_running_tasks_field) {
// This method assumes GCS is HA and does not return any error. On GCS down, it
// retries indefinitely.
std::vector<std::string> job_table_data;
@@ -68,6 +69,8 @@ std::vector<std::string> GlobalStateAccessor::GetAllJobInfo() {
absl::ReaderMutexLock lock(&mutex_);
RAY_CHECK_OK(gcs_client_->Jobs().AsyncGetAll(
/*job_or_submission_id=*/std::nullopt,
skip_submission_job_info_field,
skip_is_running_tasks_field,
TransformForMultiItemCallback<rpc::JobTableData>(job_table_data, promise),
/*timeout_ms=*/-1));
}
4 changes: 3 additions & 1 deletion src/ray/gcs/gcs_client/global_state_accessor.h
@@ -49,7 +49,9 @@ class GlobalStateAccessor {
/// \return All job info. To support multi-language, we serialize each JobTableData and
/// return the serialized string. Where used, it needs to be deserialized with
/// protobuf function.
std::vector<std::string> GetAllJobInfo() ABSL_LOCKS_EXCLUDED(mutex_);
std::vector<std::string> GetAllJobInfo(bool skip_submission_job_info_field = false,
bool skip_is_running_tasks_field = false)
ABSL_LOCKS_EXCLUDED(mutex_);
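
On the Python side these serialized entries are decoded back into protobuf messages; below is a minimal sketch, assuming a connected GlobalStateAccessor and that the generated gcs_pb2 module lives at its usual ray.core.generated path.

from ray.core.generated import gcs_pb2

def decode_job_table(global_state_accessor):
    # get_job_table() returns serialized JobTableData strings; both skip
    # flags mirror the lightweight call used by ray._private.state.
    serialized = global_state_accessor.get_job_table(
        skip_submission_job_info_field=True,
        skip_is_running_tasks_field=True,
    )
    return [gcs_pb2.JobTableData.FromString(entry) for entry in serialized]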

/// Get next job id from GCS Service.
///