diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.cc b/src/ray/gcs/gcs_server/gcs_task_manager.cc
index 3b326d2a8185..37144bf77d32 100644
--- a/src/ray/gcs/gcs_server/gcs_task_manager.cc
+++ b/src/ray/gcs/gcs_server/gcs_task_manager.cc
@@ -252,12 +252,13 @@ GcsTaskManager::GcsTaskManagerStorage::AddOrReplaceTaskEvent(
 
   // Update the events.
   if (events_by_task.has_task_info() && !existing_events.has_task_info()) {
-    num_tasks_by_type_[events_by_task.task_info().type()]++;
+    stats_counter_.Increment(
+        kTaskTypeToCounterType.at(events_by_task.task_info().type()));
   }
-  num_bytes_task_events_ -= existing_events.ByteSizeLong();
+  stats_counter_.Decrement(kNumTaskEventsBytesStored, existing_events.ByteSizeLong());
   existing_events.MergeFrom(events_by_task);
-  num_bytes_task_events_ += existing_events.ByteSizeLong();
+  stats_counter_.Increment(kNumTaskEventsBytesStored, existing_events.ByteSizeLong());
 
   MarkTaskTreeFailedIfNeeded(task_id, parent_task_id);
   return absl::nullopt;
@@ -267,7 +268,8 @@ GcsTaskManager::GcsTaskManagerStorage::AddOrReplaceTaskEvent(
 
   // Bump the task counters by type.
   if (events_by_task.has_task_info() && events_by_task.attempt_number() == 0) {
-    num_tasks_by_type_[events_by_task.task_info().type()]++;
+    stats_counter_.Increment(
+        kTaskTypeToCounterType.at(events_by_task.task_info().type()));
   }
 
   // If limit enforced, replace one.
@@ -280,8 +282,9 @@ GcsTaskManager::GcsTaskManagerStorage::AddOrReplaceTaskEvent(
           "`RAY_task_events_max_num_task_in_gcs` to a higher value to "
           "store more.";
 
-    num_bytes_task_events_ -= task_events_[next_idx_to_overwrite_].ByteSizeLong();
-    num_bytes_task_events_ += events_by_task.ByteSizeLong();
+    stats_counter_.Decrement(kNumTaskEventsBytesStored,
+                             task_events_[next_idx_to_overwrite_].ByteSizeLong());
+    stats_counter_.Increment(kNumTaskEventsBytesStored, events_by_task.ByteSizeLong());
 
     // Change the underlying storage.
     auto &to_replaced = task_events_.at(next_idx_to_overwrite_);
@@ -341,6 +344,9 @@ GcsTaskManager::GcsTaskManagerStorage::AddOrReplaceTaskEvent(
   job_to_task_attempt_index_[job_id].insert(task_attempt);
   task_to_task_attempt_index_[task_id].insert(task_attempt);
   // Add a new task event.
+  stats_counter_.Increment(kNumTaskEventsBytesStored, events_by_task.ByteSizeLong());
+  stats_counter_.Increment(kNumTaskEventsStored);
+
   task_events_.push_back(std::move(events_by_task));
 
   MarkTaskTreeFailedIfNeeded(task_id, parent_task_id);
@@ -351,7 +357,6 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
                                          rpc::GetTaskEventsReply *reply,
                                          rpc::SendReplyCallback send_reply_callback) {
   RAY_LOG(DEBUG) << "Getting task status:" << request.ShortDebugString();
-  absl::MutexLock lock(&mutex_);
 
   // Select candidate events by indexing.
   std::vector<rpc::TaskEvents> task_events;
@@ -398,10 +403,10 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
   // TODO(rickyx): We will need to revisit the data loss semantics, to report data loss
   // on a single task retry(attempt) rather than the actual events.
   // https://github.com/ray-project/ray/issues/31280
-  reply->set_num_profile_task_events_dropped(total_num_profile_task_events_dropped_ +
-                                             num_profile_event_limit);
-  reply->set_num_status_task_events_dropped(total_num_status_task_events_dropped_ +
-                                            num_status_event_limit);
+  reply->set_num_profile_task_events_dropped(
+      stats_counter_.Get(kTotalNumProfileTaskEventsDropped) + num_profile_event_limit);
+  reply->set_num_status_task_events_dropped(
+      stats_counter_.Get(kTotalNumStatusTaskEventsDropped) + num_status_event_limit);
 
   GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
   return;
@@ -410,15 +415,15 @@ void GcsTaskManager::HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
 void GcsTaskManager::HandleAddTaskEventData(rpc::AddTaskEventDataRequest request,
                                             rpc::AddTaskEventDataReply *reply,
                                             rpc::SendReplyCallback send_reply_callback) {
-  absl::MutexLock lock(&mutex_);
-
   // Dispatch to the handler
   auto data = std::move(request.data());
   // Update counters.
-  total_num_profile_task_events_dropped_ += data.num_profile_task_events_dropped();
-  total_num_status_task_events_dropped_ += data.num_status_task_events_dropped();
+  stats_counter_.Increment(kTotalNumProfileTaskEventsDropped,
+                           data.num_profile_task_events_dropped());
+  stats_counter_.Increment(kTotalNumStatusTaskEventsDropped,
+                           data.num_status_task_events_dropped());
 
   for (auto events_by_task : *data.mutable_events_by_task()) {
-    total_num_task_events_reported_++;
+    stats_counter_.Increment(kTotalNumTaskEventsReported);
     // TODO(rickyx): add logic to handle too many profile events for a single task
     // attempt. https://github.com/ray-project/ray/issues/31279
@@ -430,11 +435,11 @@ void GcsTaskManager::HandleAddTaskEventData(rpc::AddTaskEventDataRequest request
         // TODO(rickyx): should we un-flatten the status updates into a list of
        // StatusEvents? so that we could get an accurate number of status change
         // events being dropped like profile events.
-        total_num_status_task_events_dropped_++;
+        stats_counter_.Increment(kTotalNumStatusTaskEventsDropped);
       }
       if (replaced_task_events->has_profile_events()) {
-        total_num_profile_task_events_dropped_ +=
-            replaced_task_events->profile_events().events_size();
+        stats_counter_.Increment(kTotalNumProfileTaskEventsDropped,
+                                 replaced_task_events->profile_events().events_size());
       }
     }
   }
@@ -444,57 +449,52 @@ void GcsTaskManager::HandleAddTaskEventData(rpc::AddTaskEventDataRequest request
 }
 
 std::string GcsTaskManager::DebugString() {
-  absl::MutexLock lock(&mutex_);
   std::ostringstream ss;
+  auto counters = stats_counter_.GetAll();
   ss << "GcsTaskManager: "
-     << "\n-Total num task events reported: " << total_num_task_events_reported_
+     << "\n-Total num task events reported: " << counters[kTotalNumTaskEventsReported]
      << "\n-Total num status task events dropped: "
-     << total_num_status_task_events_dropped_
-     << "\n-Total num profile events dropped: " << total_num_profile_task_events_dropped_
+     << counters[kTotalNumStatusTaskEventsDropped]
+     << "\n-Total num profile events dropped: "
+     << counters[kTotalNumProfileTaskEventsDropped]
      << "\n-Total num bytes of task event stored: "
-     << 1.0 * task_event_storage_->GetTaskEventsBytes() / 1024 / 1024 << "MiB"
-     << "\n-Current num of task events stored: "
-     << task_event_storage_->GetTaskEventsCount()
-     << "\n-Total num of actor creation tasks: "
-     << task_event_storage_->num_tasks_by_type_[rpc::TaskType::ACTOR_CREATION_TASK]
-     << "\n-Total num of actor tasks: "
-     << task_event_storage_->num_tasks_by_type_[rpc::TaskType::ACTOR_TASK]
-     << "\n-Total num of normal tasks: "
-     << task_event_storage_->num_tasks_by_type_[rpc::TaskType::NORMAL_TASK]
-     << "\n-Total num of driver tasks: "
-     << task_event_storage_->num_tasks_by_type_[rpc::TaskType::DRIVER_TASK];
+     << 1.0 * counters[kNumTaskEventsBytesStored] / 1024 / 1024 << "MiB"
+     << "\n-Current num of task events stored: " << counters[kNumTaskEventsStored]
+     << "\n-Total num of actor creation tasks: " << counters[kTotalNumActorCreationTask]
+     << "\n-Total num of actor tasks: " << counters[kTotalNumActorTask]
+     << "\n-Total num of normal tasks: " << counters[kTotalNumNormalTask]
+     << "\n-Total num of driver tasks: " << counters[kTotalNumDriverTask];
 
   return ss.str();
 }
 
 void GcsTaskManager::RecordMetrics() {
-  absl::MutexLock lock(&mutex_);
+  auto counters = stats_counter_.GetAll();
   ray::stats::STATS_gcs_task_manager_task_events_reported.Record(
-      total_num_task_events_reported_);
+      counters[kTotalNumTaskEventsReported]);
 
   ray::stats::STATS_gcs_task_manager_task_events_dropped.Record(
-      total_num_status_task_events_dropped_, ray::stats::kGcsTaskStatusEventDropped);
+      counters[kTotalNumStatusTaskEventsDropped], ray::stats::kGcsTaskStatusEventDropped);
   ray::stats::STATS_gcs_task_manager_task_events_dropped.Record(
-      total_num_profile_task_events_dropped_, ray::stats::kGcsProfileEventDropped);
+      counters[kTotalNumProfileTaskEventsDropped], ray::stats::kGcsProfileEventDropped);
 
   ray::stats::STATS_gcs_task_manager_task_events_stored.Record(
-      task_event_storage_->GetTaskEventsCount());
+      counters[kNumTaskEventsStored]);
   ray::stats::STATS_gcs_task_manager_task_events_stored_bytes.Record(
-      task_event_storage_->GetTaskEventsBytes());
-
-  if (usage_stats_client_) {
-    usage_stats_client_->RecordExtraUsageCounter(
-        usage::TagKey::NUM_ACTOR_CREATION_TASKS,
-        task_event_storage_->num_tasks_by_type_[rpc::TaskType::ACTOR_CREATION_TASK]);
-    usage_stats_client_->RecordExtraUsageCounter(
-        usage::TagKey::NUM_ACTOR_TASKS,
-        task_event_storage_->num_tasks_by_type_[rpc::TaskType::ACTOR_TASK]);
-    usage_stats_client_->RecordExtraUsageCounter(
-        usage::TagKey::NUM_NORMAL_TASKS,
-        task_event_storage_->num_tasks_by_type_[rpc::TaskType::NORMAL_TASK]);
-    usage_stats_client_->RecordExtraUsageCounter(
-        usage::TagKey::NUM_DRIVERS,
-        task_event_storage_->num_tasks_by_type_[rpc::TaskType::DRIVER_TASK]);
+      counters[kNumTaskEventsBytesStored]);
+
+  {
+    absl::MutexLock lock(&mutex_);
+    if (usage_stats_client_) {
+      usage_stats_client_->RecordExtraUsageCounter(
+          usage::TagKey::NUM_ACTOR_CREATION_TASKS, counters[kTotalNumActorCreationTask]);
+      usage_stats_client_->RecordExtraUsageCounter(usage::TagKey::NUM_ACTOR_TASKS,
+                                                   counters[kTotalNumActorTask]);
+      usage_stats_client_->RecordExtraUsageCounter(usage::TagKey::NUM_NORMAL_TASKS,
+                                                   counters[kTotalNumNormalTask]);
+      usage_stats_client_->RecordExtraUsageCounter(usage::TagKey::NUM_DRIVERS,
+                                                   counters[kTotalNumDriverTask]);
+    }
   }
 }
 
@@ -513,7 +513,6 @@ void GcsTaskManager::OnJobFinished(const JobID &job_id, int64_t job_finish_time_
     // timer canceled or aborted.
     return;
   }
-  absl::MutexLock lock(&mutex_);
   // If there are any non-terminated tasks from the job, mark them failed since all
   // workers associated with the job will be killed.
   task_event_storage_->MarkTasksFailed(job_id, job_finish_time_ms * 1000 * 1000);
diff --git a/src/ray/gcs/gcs_server/gcs_task_manager.h b/src/ray/gcs/gcs_server/gcs_task_manager.h
index b628e7d01fb2..52829e5af09c 100644
--- a/src/ray/gcs/gcs_server/gcs_task_manager.h
+++ b/src/ray/gcs/gcs_server/gcs_task_manager.h
@@ -20,6 +20,7 @@
 #include "absl/synchronization/mutex.h"
 #include "ray/gcs/gcs_client/usage_stats_client.h"
 #include "ray/rpc/gcs_server/gcs_rpc_server.h"
+#include "ray/util/counter_map.h"
 #include "src/ray/protobuf/gcs.pb.h"
 
 namespace ray {
@@ -28,6 +29,25 @@ namespace gcs {
 /// Type alias for a single task attempt, i.e. <task id, attempt number>.
 using TaskAttempt = std::pair<TaskID, int32_t>;
 
+enum GcsTaskManagerCounter {
+  kTotalNumTaskEventsReported,
+  kTotalNumStatusTaskEventsDropped,
+  kTotalNumProfileTaskEventsDropped,
+  kNumTaskEventsBytesStored,
+  kNumTaskEventsStored,
+  kTotalNumActorCreationTask,
+  kTotalNumActorTask,
+  kTotalNumNormalTask,
+  kTotalNumDriverTask,
+};
+
+const absl::flat_hash_map<rpc::TaskType, GcsTaskManagerCounter> kTaskTypeToCounterType = {
+    {rpc::TaskType::NORMAL_TASK, kTotalNumNormalTask},
+    {rpc::TaskType::ACTOR_CREATION_TASK, kTotalNumActorCreationTask},
+    {rpc::TaskType::ACTOR_TASK, kTotalNumActorTask},
+    {rpc::TaskType::DRIVER_TASK, kTotalNumDriverTask},
+};
+
 /// GcsTaskManager is responsible for capturing task state changes reported by
 /// TaskEventBuffer from other components.
 ///
@@ -42,8 +62,9 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
  public:
   /// Create a GcsTaskManager.
   GcsTaskManager()
-      : task_event_storage_(std::make_unique<GcsTaskManagerStorage>(
-            RayConfig::instance().task_events_max_num_task_in_gcs())),
+      : stats_counter_(),
+        task_event_storage_(std::make_unique<GcsTaskManagerStorage>(
+            RayConfig::instance().task_events_max_num_task_in_gcs(), stats_counter_)),
         io_service_thread_(std::make_unique<std::thread>([this] {
           SetThreadName("task_events");
           // Keep io_service_ alive.
@@ -59,8 +80,7 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
   /// \param send_reply_callback Callback to invoke when sending reply.
   void HandleAddTaskEventData(rpc::AddTaskEventDataRequest request,
                               rpc::AddTaskEventDataReply *reply,
-                              rpc::SendReplyCallback send_reply_callback)
-      LOCKS_EXCLUDED(mutex_) override;
+                              rpc::SendReplyCallback send_reply_callback) override;
 
   /// Handle GetTaskEvent request.
   ///
@@ -69,14 +89,13 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
   /// \param send_reply_callback Callback to invoke when sending reply.
   void HandleGetTaskEvents(rpc::GetTaskEventsRequest request,
                            rpc::GetTaskEventsReply *reply,
-                           rpc::SendReplyCallback send_reply_callback)
-      LOCKS_EXCLUDED(mutex_) override;
+                           rpc::SendReplyCallback send_reply_callback) override;
 
   /// Stops the event loop and the thread of the task event handler.
   ///
   /// After this is called, no more requests will be handled.
   /// This function returns when the io thread is joined.
-  void Stop() LOCKS_EXCLUDED(mutex_);
+  void Stop();
 
   /// Handler to be called when a job finishes. This marks all non-terminated tasks
   /// of the job as failed.
@@ -93,7 +112,7 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
   /// Return string of debug state.
   ///
   /// \return Debug string
-  std::string DebugString() LOCKS_EXCLUDED(mutex_);
+  std::string DebugString();
 
   /// Record metrics.
   void RecordMetrics() LOCKS_EXCLUDED(mutex_);
@@ -120,8 +139,9 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
     ///
     /// \param max_num_task_events Max number of task events stored before replacing older
     /// ones.
-    GcsTaskManagerStorage(size_t max_num_task_events)
-        : max_num_task_events_(max_num_task_events) {}
+    GcsTaskManagerStorage(size_t max_num_task_events,
+                          CounterMapThreadSafe<GcsTaskManagerCounter> &stats_counter)
+        : max_num_task_events_(max_num_task_events), stats_counter_(stats_counter) {}
 
     /// Add a new task event or replace an existing task event in the storage.
     ///
@@ -247,21 +267,12 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
     /// could be found or there's data loss.
     absl::optional<TaskAttempt> GetLatestTaskAttempt(const TaskID &task_id) const;
 
-    /// Get the number of task events stored.
-    size_t GetTaskEventsCount() const { return task_events_.size(); }
-
-    /// Get the total number of bytes of task events stored.
-    uint64_t GetTaskEventsBytes() const { return num_bytes_task_events_; }
-
     /// Max number of task events allowed in the storage.
     const size_t max_num_task_events_ = 0;
 
     /// An iterator into task_events_ that determines which element to be overwritten.
     size_t next_idx_to_overwrite_ = 0;
 
-    /// Total number of tasks by type, including ones that have been evicted/finished.
-    absl::flat_hash_map<rpc::TaskType, size_t> num_tasks_by_type_;
-
     /// TODO(rickyx): Refactor this into LRI(least recently inserted) buffer:
     /// https://github.com/ray-project/ray/issues/31158
     /// Current task events stored.
@@ -283,9 +294,8 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
     absl::flat_hash_map<TaskID, absl::flat_hash_set<TaskID>>
         parent_to_children_task_index_;
 
-    /// Counter for tracking the size of task events. This assumes task events are never
-    /// removed actively.
-    uint64_t num_bytes_task_events_ = 0;
+    /// Reference to the counter map owned by the GcsTaskManager.
+    CounterMapThreadSafe<GcsTaskManagerCounter> &stats_counter_;
 
     friend class GcsTaskManager;
     FRIEND_TEST(GcsTaskManagerTest, TestHandleAddTaskEventBasic);
@@ -295,20 +305,35 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
   };
 
  private:
-  /// Mutex guarding all fields that will be accessed by main_io as well.
-  absl::Mutex mutex_;
+  /// Test only
+  size_t GetTotalNumStatusTaskEventsDropped() {
+    return stats_counter_.Get(kTotalNumStatusTaskEventsDropped);
+  }
 
-  /// Total number of task events reported.
-  uint32_t total_num_task_events_reported_ GUARDED_BY(mutex_) = 0;
+  /// Test only
+  size_t GetTotalNumProfileTaskEventsDropped() {
+    return stats_counter_.Get(kTotalNumProfileTaskEventsDropped);
+  }
 
-  /// Total number of status task events dropped on the worker.
-  uint32_t total_num_status_task_events_dropped_ GUARDED_BY(mutex_) = 0;
+  /// Test only
+  size_t GetTotalNumTaskEventsReported() {
+    return stats_counter_.Get(kTotalNumTaskEventsReported);
+  }
 
-  /// Total number of profile task events dropped on the worker.
-  uint32_t total_num_profile_task_events_dropped_ GUARDED_BY(mutex_) = 0;
+  /// Test only
+  size_t GetNumTaskEventsStored() { return stats_counter_.Get(kNumTaskEventsStored); }
 
-  // Pointer to the underlying task events storage.
-  std::unique_ptr<GcsTaskManagerStorage> task_event_storage_ GUARDED_BY(mutex_);
+  // Mutex guarding the usage stats client.
+  absl::Mutex mutex_;
+
+  UsageStatsClient *usage_stats_client_ GUARDED_BY(mutex_) = nullptr;
+
+  /// Counter map for GcsTaskManager stats.
+  CounterMapThreadSafe<GcsTaskManagerCounter> stats_counter_;
+
+  // Pointer to the underlying task events storage. This is only accessed from
+  // the io_service_thread_. Access to it is *not* thread safe.
+  std::unique_ptr<GcsTaskManagerStorage> task_event_storage_;
 
   /// Its own separate IO service separated from the main service.
   instrumented_io_context io_service_;
@@ -319,8 +344,6 @@ class GcsTaskManager : public rpc::TaskInfoHandler {
   /// Timer for delay functions.
   boost::asio::deadline_timer timer_;
 
-  UsageStatsClient *usage_stats_client_ GUARDED_BY(mutex_) = nullptr;
-
   FRIEND_TEST(GcsTaskManagerTest, TestHandleAddTaskEventBasic);
   FRIEND_TEST(GcsTaskManagerTest, TestMergeTaskEventsSameTaskAttempt);
   FRIEND_TEST(GcsTaskManagerMemoryLimitedTest, TestLimitTaskEvents);
diff --git a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
index 294b8a6a158a..f1cebaf39ae9 100644
--- a/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
+++ b/src/ray/gcs/gcs_server/test/gcs_task_manager_test.cc
@@ -83,12 +83,17 @@ class GcsTaskManagerTest : public ::testing::Test {
     std::promise<bool> promise;
     request.mutable_data()->CopyFrom(events_data);
 
-    task_manager->HandleAddTaskEventData(
-        request,
-        &reply,
-        [&promise](Status, std::function<void()>, std::function<void()>) {
-          promise.set_value(true);
-        });
+    // Dispatch so that it runs in GcsTaskManager's io service.
+    task_manager->GetIoContext().dispatch(
+        [this, &promise, &request, &reply]() {
+          task_manager->HandleAddTaskEventData(
+              request,
+              &reply,
+              [&promise](Status, std::function<void()>, std::function<void()>) {
+                promise.set_value(true);
+              });
+        },
+        "SyncAddTaskEventData");
 
     promise.get_future().get();
 
@@ -120,13 +125,16 @@ class GcsTaskManagerTest : public ::testing::Test {
     }
 
     request.set_exclude_driver(exclude_driver);
-
-    task_manager->HandleGetTaskEvents(
-        request,
-        &reply,
-        [&promise](Status, std::function<void()>, std::function<void()>) {
-          promise.set_value(true);
-        });
+    task_manager->GetIoContext().dispatch(
+        [this, &promise, &request, &reply]() {
+          task_manager->HandleGetTaskEvents(
+              request,
+              &reply,
+              [&promise](Status, std::function<void()>, std::function<void()>) {
+                promise.set_value(true);
+              });
+        },
+        "SyncGetTaskEvents");
 
     promise.get_future().get();
 
@@ -248,12 +256,11 @@ TEST_F(GcsTaskManagerTest, TestHandleAddTaskEventBasic) {
 
   // Assert on actual data.
   {
-    absl::MutexLock lock(&task_manager->mutex_);
     EXPECT_EQ(task_manager->task_event_storage_->task_events_.size(), num_task_events);
-    EXPECT_EQ(task_manager->total_num_task_events_reported_, num_task_events);
-    EXPECT_EQ(task_manager->total_num_profile_task_events_dropped_,
+    EXPECT_EQ(task_manager->GetTotalNumTaskEventsReported(), num_task_events);
+    EXPECT_EQ(task_manager->GetTotalNumProfileTaskEventsDropped(),
               num_profile_events_dropped);
-    EXPECT_EQ(task_manager->total_num_status_task_events_dropped_,
+    EXPECT_EQ(task_manager->GetTotalNumStatusTaskEventsDropped(),
               num_status_events_dropped);
   }
 }
@@ -274,7 +281,6 @@ TEST_F(GcsTaskManagerTest, TestMergeTaskEventsSameTaskAttempt) {
 
   // Assert on actual data
   {
-    absl::MutexLock lock(&task_manager->mutex_);
     EXPECT_EQ(task_manager->task_event_storage_->task_events_.size(), 1);
     // Assert on events
     auto task_events = task_manager->task_event_storage_->task_events_[0];
@@ -643,10 +649,8 @@ TEST_F(GcsTaskManagerMemoryLimitedTest, TestIndexNoLeak) {
   }
 
   {
-    absl::MutexLock lock(&task_manager->mutex_);
-    EXPECT_EQ(
-        task_manager->task_event_storage_->num_tasks_by_type_[rpc::TaskType::NORMAL_TASK],
-        task_ids.size());
+    EXPECT_EQ(task_manager->task_event_storage_->stats_counter_.Get(kTotalNumNormalTask),
+              task_ids.size());
   }
 
   // Evict all of them with tasks with single attempt, no parent, same job.
@@ -668,11 +672,9 @@ TEST_F(GcsTaskManagerMemoryLimitedTest, TestIndexNoLeak) {
   }
   // Assert on the indexes and the storage
   {
-    absl::MutexLock lock(&task_manager->mutex_);
     EXPECT_EQ(task_manager->task_event_storage_->task_events_.size(), num_limit);
-    EXPECT_EQ(
-        task_manager->task_event_storage_->num_tasks_by_type_[rpc::TaskType::NORMAL_TASK],
-        task_ids.size() + num_limit);
+    EXPECT_EQ(task_manager->task_event_storage_->stats_counter_.Get(kTotalNumNormalTask),
+              task_ids.size() + num_limit);
 
     // No task has parent.
     EXPECT_EQ(task_manager->task_event_storage_->parent_to_children_task_index_.size(),
              0);
@@ -729,9 +731,8 @@ TEST_F(GcsTaskManagerMemoryLimitedTest, TestLimitTaskEvents) {
 
   // Assert on actual data.
   {
-    absl::MutexLock lock(&task_manager->mutex_);
-    EXPECT_EQ(task_manager->task_event_storage_->task_events_.size(), num_limit);
-    EXPECT_EQ(task_manager->total_num_task_events_reported_, num_batch1 + num_batch2);
+    EXPECT_EQ(task_manager->GetNumTaskEventsStored(), num_limit);
+    EXPECT_EQ(task_manager->GetTotalNumTaskEventsReported(), num_batch1 + num_batch2);
 
     std::sort(expected_events.begin(), expected_events.end(), SortByTaskAttempt);
     auto actual_events = task_manager->task_event_storage_->task_events_;
@@ -743,9 +744,9 @@ TEST_F(GcsTaskManagerMemoryLimitedTest, TestLimitTaskEvents) {
     }
 
     // Assert on drop counts.
-    EXPECT_EQ(task_manager->total_num_status_task_events_dropped_,
+    EXPECT_EQ(task_manager->GetTotalNumStatusTaskEventsDropped(),
               num_status_events_to_drop + num_status_events_dropped_on_worker);
-    EXPECT_EQ(task_manager->total_num_profile_task_events_dropped_,
+    EXPECT_EQ(task_manager->GetTotalNumProfileTaskEventsDropped(),
               num_profile_events_to_drop + num_profile_events_dropped_on_worker);
   }
 }
diff --git a/src/ray/util/counter_map.h b/src/ray/util/counter_map.h
index 2a758c99f207..3351940e8e40 100644
--- a/src/ray/util/counter_map.h
+++ b/src/ray/util/counter_map.h
@@ -18,6 +18,7 @@
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "absl/synchronization/mutex.h"
 #include "ray/util/logging.h"
 
 /// \class CounterMap
@@ -118,9 +119,79 @@ class CounterMap {
     }
   }
 
+  /// Return a snapshot of all the counters.
+  absl::flat_hash_map<K, int64_t> GetAll() const { return counters_; }
+
 private:
  absl::flat_hash_map<K, int64_t> counters_;
  absl::flat_hash_set<K> pending_changes_;
  std::function<void(const K &)> on_change_;
  size_t total_ = 0;
};
+
+/// \class A thread-safe version of CounterMap with all methods guarded by a mutex.
+template <typename K>
+class CounterMapThreadSafe {
+ public:
+  CounterMapThreadSafe() = default;
+
+  void SetOnChangeCallback(std::function<void(const K &)> on_change)
+      LOCKS_EXCLUDED(mutex_) {
+    absl::WriterMutexLock lock(&mutex_);
+    counter_map_.SetOnChangeCallback(std::move(on_change));
+  }
+
+  void FlushOnChangeCallbacks() LOCKS_EXCLUDED(mutex_) {
+    absl::WriterMutexLock lock(&mutex_);
+    counter_map_.FlushOnChangeCallbacks();
+  }
+
+  void Increment(const K &key, int64_t val = 1) LOCKS_EXCLUDED(mutex_) {
+    absl::WriterMutexLock lock(&mutex_);
+    counter_map_.Increment(key, val);
+  }
+
+  void Decrement(const K &key, int64_t val = 1) LOCKS_EXCLUDED(mutex_) {
+    absl::WriterMutexLock lock(&mutex_);
+    counter_map_.Decrement(key, val);
+  }
+
+  int64_t Get(const K &key) {
+    absl::ReaderMutexLock lock(&mutex_);
+    return counter_map_.Get(key);
+  }
+
+  void Swap(const K &old_key, const K &new_key, int64_t val = 1) LOCKS_EXCLUDED(mutex_) {
+    absl::WriterMutexLock lock(&mutex_);
+    counter_map_.Swap(old_key, new_key, val);
+  }
+
+  size_t Size() {
+    absl::ReaderMutexLock lock(&mutex_);
+    return counter_map_.Size();
+  }
+
+  size_t Total() {
+    absl::ReaderMutexLock lock(&mutex_);
+    return counter_map_.Total();
+  }
+
+  size_t NumPendingCallbacks() {
+    absl::ReaderMutexLock lock(&mutex_);
+    return counter_map_.NumPendingCallbacks();
+  }
+
+  void ForEachEntry(std::function<void(const K &, int64_t)> callback) {
+    absl::ReaderMutexLock lock(&mutex_);
+    counter_map_.ForEachEntry(std::move(callback));
+  }
+
+  absl::flat_hash_map<K, int64_t> GetAll() {
+    absl::ReaderMutexLock lock(&mutex_);
+    return counter_map_.GetAll();
+  }
+
+ private:
+  absl::Mutex mutex_;
+  CounterMap<K> counter_map_;
+};
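
Reviewer note: the core pattern of this refactor is a single coarse mutex wrapping a plain CounterMap, with readers taking a consistent snapshot via GetAll() instead of holding a manager-wide lock. The following is a minimal, self-contained sketch of that pattern using only the standard library. It is illustrative only: ThreadSafeCounters and DemoCounter are hypothetical names, not Ray APIs, and the real implementation is the absl-based CounterMapThreadSafe above.

// Minimal standalone sketch of the thread-safe counter pattern adopted by this PR.
// Mirrors CounterMapThreadSafe's Increment/Decrement/Get/GetAll semantics with
// std::mutex and std::unordered_map; not Ray code.
#include <cstdint>
#include <iostream>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

enum DemoCounter { kEventsReported, kBytesStored };

template <typename K>
class ThreadSafeCounters {
 public:
  void Increment(const K &key, int64_t val = 1) {
    std::lock_guard<std::mutex> lock(mutex_);
    counters_[key] += val;
  }
  void Decrement(const K &key, int64_t val = 1) {
    std::lock_guard<std::mutex> lock(mutex_);
    counters_[key] -= val;
  }
  int64_t Get(const K &key) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = counters_.find(key);
    return it == counters_.end() ? 0 : it->second;
  }
  // Snapshot all counters under one lock acquisition, as DebugString() and
  // RecordMetrics() do above, so the returned values are mutually consistent.
  std::unordered_map<K, int64_t> GetAll() {
    std::lock_guard<std::mutex> lock(mutex_);
    return counters_;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<K, int64_t> counters_;
};

int main() {
  ThreadSafeCounters<DemoCounter> counters;
  // Writers (e.g. the io_service thread) and readers (e.g. metrics recording)
  // can touch the counters concurrently without a manager-wide mutex.
  std::vector<std::thread> writers;
  for (int i = 0; i < 4; i++) {
    writers.emplace_back([&counters] {
      for (int j = 0; j < 1000; j++) {
        counters.Increment(kEventsReported);
        counters.Increment(kBytesStored, 128);
      }
    });
  }
  for (auto &t : writers) t.join();
  auto snapshot = counters.GetAll();
  std::cout << "reported=" << snapshot[kEventsReported]
            << " bytes=" << snapshot[kBytesStored] << std::endl;  // 4000 512000
  return 0;
}

Because GetAll() copies the map while holding the lock, callers read a mutually consistent set of values and never hold the lock while formatting debug strings or recording metrics, which is why DebugString() and RecordMetrics() no longer need the GcsTaskManager-wide mutex.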