diff --git a/BUILD.bazel b/BUILD.bazel index c121daaf99c8..83660e6aa8ca 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -513,6 +513,7 @@ ray_cc_library( "@boost//:bimap", "@com_github_grpc_grpc//src/proto/grpc/health/v1:health_proto", "@com_google_absl//absl/container:btree", + "//src/ray/util:thread_checker", ], ) diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.cc b/src/ray/gcs/gcs_server/gcs_job_manager.cc index 23432d1e5171..f68a764f600c 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_job_manager.cc @@ -96,6 +96,8 @@ void GcsJobManager::HandleAddJob(rpc::AddJobRequest request, reply, send_reply_callback = std::move(send_reply_callback)](const Status &status) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + if (!status.ok()) { RAY_LOG(ERROR) << "Failed to add job, job id = " << job_id << ", driver pid = " << job_table_data.driver_pid(); @@ -136,6 +138,8 @@ void GcsJobManager::MarkJobAsFinished(rpc::JobTableData job_table_data, job_table_data.set_is_dead(true); auto on_done = [this, job_id, job_table_data, done_callback = std::move(done_callback)]( const Status &status) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + if (!status.ok()) { RAY_LOG(ERROR) << "Failed to mark job state, job id = " << job_id; } else { @@ -176,6 +180,8 @@ void GcsJobManager::HandleMarkJobFinished(rpc::MarkJobFinishedRequest request, job_id, [this, job_id, send_reply](const Status &status, const std::optional &result) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + if (status.ok() && result) { MarkJobAsFinished(*result, send_reply); return; @@ -266,6 +272,8 @@ void GcsJobManager::HandleGetAllJobInfo(rpc::GetAllJobInfoRequest request, }; auto on_done = [this, filter_ok, request, reply, send_reply_callback, limit]( const absl::flat_hash_map &&result) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + // Internal KV keys for jobs that were submitted via the Ray Job API. std::vector job_api_data_keys; @@ -447,6 +455,8 @@ void GcsJobManager::OnNodeDead(const NodeID &node_id) { << "Node failed, mark all jobs from this node as finished"; auto on_done = [this, node_id](const absl::flat_hash_map &result) { + RAY_CHECK(thread_checker_.IsOnSameThread()); + // If job is not dead and from driver in current node, then mark it as finished for (auto &data : result) { if (!data.second.is_dead() && diff --git a/src/ray/gcs/gcs_server/gcs_job_manager.h b/src/ray/gcs/gcs_server/gcs_job_manager.h index b74b24f3e1d1..95f43c7e27ad 100644 --- a/src/ray/gcs/gcs_server/gcs_job_manager.h +++ b/src/ray/gcs/gcs_server/gcs_job_manager.h @@ -31,6 +31,7 @@ #include "ray/rpc/worker/core_worker_client.h" #include "ray/rpc/worker/core_worker_client_pool.h" #include "ray/util/event.h" +#include "ray/util/thread_checker.h" namespace ray { namespace gcs { @@ -107,6 +108,10 @@ class GcsJobManager : public rpc::JobInfoHandler { void MarkJobAsFinished(rpc::JobTableData job_table_data, std::function done_callback); + // Used to validate invariants for threading; for example, all callbacks are executed on + // the same thread. + ThreadChecker thread_checker_; + // Running Job IDs, used to report metrics. absl::flat_hash_set running_job_ids_;