diff --git a/BUILD.bazel b/BUILD.bazel
index 10d3b0e621fc..7f0884f348aa 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1651,9 +1651,21 @@ cc_test(
     ],
     copts = COPTS,
     tags = ["team:core"],
-    target_compatible_with = [
-        "@platforms//os:linux",
+    deps = [
+        ":ray_common",
+        ":raylet_lib",
+        "@com_google_googletest//:gtest_main",
+    ],
+)
+
+cc_test(
+    name = "worker_killing_policy_group_by_owner_test",
+    size = "small",
+    srcs = [
+        "src/ray/raylet/worker_killing_policy_group_by_owner_test.cc",
     ],
+    copts = COPTS,
+    tags = ["team:core"],
     deps = [
         ":ray_common",
         ":raylet_lib",
diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h
index a50b84b41281..04373d1caf6d 100644
--- a/src/ray/common/ray_config_def.h
+++ b/src/ray/common/ray_config_def.h
@@ -99,6 +99,9 @@ RAY_CONFIG(uint64_t, task_failure_entry_ttl_ms, 15 * 60 * 1000)
 /// that is not related to running out of memory. Retries indefinitely if the value is -1.
 RAY_CONFIG(uint64_t, task_oom_retries, 15)
 
+/// The worker killing policy to use, as defined in worker_killing_policy.h.
+RAY_CONFIG(std::string, worker_killing_policy, "retriable_lifo")
+
 /// If the raylet fails to get agent info, we will retry after this interval.
 RAY_CONFIG(uint64_t, raylet_get_agent_info_interval_ms, 1)
 
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index dce51d440e4d..77cf389381af 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -283,6 +283,8 @@ NodeManager::NodeManager(instrumented_io_context &io_service,
       next_resource_seq_no_(0),
       ray_syncer_(io_service_, self_node_id_.Binary()),
       ray_syncer_service_(ray_syncer_),
+      worker_killing_policy_(
+          CreateWorkerKillingPolicy(RayConfig::instance().worker_killing_policy())),
       memory_monitor_(std::make_unique<MemoryMonitor>(
           io_service,
           RayConfig::instance().memory_usage_threshold(),
@@ -2869,9 +2871,8 @@ MemoryUsageRefreshCallback NodeManager::CreateMemoryUsageRefreshCallback() {
                  << "idle worker are occupying most of the memory.";
           return;
         }
-        RetriableLIFOWorkerKillingPolicy worker_killing_policy;
         auto worker_to_kill_and_should_retry =
-            worker_killing_policy.SelectWorkerToKill(workers, system_memory);
+            worker_killing_policy_->SelectWorkerToKill(workers, system_memory);
         auto worker_to_kill = worker_to_kill_and_should_retry.first;
         bool should_retry = worker_to_kill_and_should_retry.second;
         if (worker_to_kill == nullptr) {
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index 6816d58493bd..e470ac218e6d 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -46,6 +46,7 @@
 #include "ray/common/asio/instrumented_io_context.h"
 #include "ray/common/bundle_spec.h"
 #include "ray/raylet/placement_group_resource_manager.h"
+#include "ray/raylet/worker_killing_policy.h"
 // clang-format on
 
 namespace ray {
@@ -833,6 +834,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
   /// RaySyncerService for gRPC
   syncer::RaySyncerService ray_syncer_service_;
 
+  /// The policy for selecting the worker to kill when the node runs out of memory.
+  std::shared_ptr<WorkerKillingPolicy> worker_killing_policy_;
+
   /// Monitors and reports node memory usage and whether it is above threshold.
   std::unique_ptr<MemoryMonitor> memory_monitor_;
 };
diff --git a/src/ray/raylet/test/util.h b/src/ray/raylet/test/util.h
index 55515fbe6aef..5007a9f491e1 100644
--- a/src/ray/raylet/test/util.h
+++ b/src/ray/raylet/test/util.h
@@ -24,7 +24,8 @@ class MockWorker : public WorkerInterface {
       : worker_id_(worker_id),
         port_(port),
         is_detached_actor_(false),
-        runtime_env_hash_(runtime_env_hash) {}
+        runtime_env_hash_(runtime_env_hash),
+        job_id_(JobID::FromInt(859)) {}
 
   WorkerID WorkerId() const override { return worker_id_; }
 
@@ -34,16 +35,14 @@ class MockWorker : public WorkerInterface {
 
   void SetOwnerAddress(const rpc::Address &address) override { address_ = address; }
 
-  void AssignTaskId(const TaskID &task_id) override {}
+  void AssignTaskId(const TaskID &task_id) override { task_id_ = task_id; }
 
   void SetAssignedTask(const RayTask &assigned_task) override {
     task_ = assigned_task;
-    task_assign_time_ = std::chrono::steady_clock::now();
+    task_assign_time_ = absl::Now();
   };
 
-  const std::chrono::steady_clock::time_point GetAssignedTaskTime() const override {
-    return task_assign_time_;
-  };
+  absl::Time GetAssignedTaskTime() const override { return task_assign_time_; };
 
   const std::string IpAddress() const override { return address_.ip_address(); }
 
@@ -95,10 +94,7 @@ class MockWorker : public WorkerInterface {
     return -1;
   }
   void SetAssignedPort(int port) override { RAY_CHECK(false) << "Method unused"; }
-  const TaskID &GetAssignedTaskId() const override {
-    RAY_CHECK(false) << "Method unused";
-    return TaskID::Nil();
-  }
+  const TaskID &GetAssignedTaskId() const override { return task_id_; }
   bool AddBlockedTaskId(const TaskID &task_id) override {
     RAY_CHECK(false) << "Method unused";
     return false;
@@ -112,10 +108,7 @@ class MockWorker : public WorkerInterface {
     auto *t = new std::unordered_set<TaskID>();
     return *t;
   }
-  const JobID &GetAssignedJobId() const override {
-    RAY_CHECK(false) << "Method unused";
-    return JobID::Nil();
-  }
+  const JobID &GetAssignedJobId() const override { return job_id_; }
   int GetRuntimeEnvHash() const override { return runtime_env_hash_; }
   void AssignActorId(const ActorID &actor_id) override {
     RAY_CHECK(false) << "Method unused";
@@ -189,8 +182,10 @@ class MockWorker : public WorkerInterface {
   BundleID bundle_id_;
   bool blocked_ = false;
   RayTask task_;
-  std::chrono::steady_clock::time_point task_assign_time_;
+  absl::Time task_assign_time_;
   int runtime_env_hash_;
+  TaskID task_id_;
+  JobID job_id_;
 };
 
 }  // namespace raylet
diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h
index a3654fba8831..f44c30fc4793 100644
--- a/src/ray/raylet/worker.h
+++ b/src/ray/raylet/worker.h
@@ -17,6 +17,8 @@
 #include <memory>
 
 #include "absl/memory/memory.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
 #include "gtest/gtest_prod.h"
 #include "ray/common/client_connection.h"
 #include "ray/common/id.h"
@@ -109,7 +111,7 @@ class WorkerInterface {
   virtual bool IsAvailableForScheduling() const = 0;
 
   /// Time when the last task was assigned to this worker.
-  virtual const std::chrono::steady_clock::time_point GetAssignedTaskTime() const = 0;
+  virtual absl::Time GetAssignedTaskTime() const = 0;
 
  protected:
   virtual void SetStartupToken(StartupToken startup_token) = 0;
@@ -213,12 +215,10 @@ class Worker : public WorkerInterface {
 
   void SetAssignedTask(const RayTask &assigned_task) {
     assigned_task_ = assigned_task;
-    task_assign_time_ = std::chrono::steady_clock::now();
-  };
+    task_assign_time_ = absl::Now();
+  }
 
-  const std::chrono::steady_clock::time_point GetAssignedTaskTime() const {
-    return task_assign_time_;
-  };
+  absl::Time GetAssignedTaskTime() const { return task_assign_time_; };
 
   bool IsRegistered() { return rpc_client_ != nullptr; }
 
@@ -298,7 +298,7 @@ class Worker : public WorkerInterface {
   /// RayTask being assigned to this worker.
   RayTask assigned_task_;
   /// Time when the last task was assigned to this worker.
-  std::chrono::steady_clock::time_point task_assign_time_;
+  absl::Time task_assign_time_;
   /// If true, a RPC need to be sent to notify the worker about GCS restarting.
   bool notify_gcs_restarted_ = false;
 };
diff --git a/src/ray/raylet/worker_killing_policy.cc b/src/ray/raylet/worker_killing_policy.cc
index f5e738f0193d..7fe218224a89 100644
--- a/src/ray/raylet/worker_killing_policy.cc
+++ b/src/ray/raylet/worker_killing_policy.cc
@@ -19,6 +19,7 @@
 #include "ray/common/asio/instrumented_io_context.h"
 #include "ray/common/asio/periodical_runner.h"
 #include "ray/raylet/worker.h"
+#include "ray/raylet/worker_killing_policy_group_by_owner.h"
 #include "ray/raylet/worker_pool.h"
 
 namespace ray {
@@ -76,10 +77,12 @@ std::string WorkerKillingPolicy::WorkersDebugString(
       RAY_LOG_EVERY_MS(INFO, 60000)
           << "Can't find memory usage for PID, reporting zero. PID: " << pid;
     }
-    result << "Worker " << index << ": task assigned time counter "
-           << worker->GetAssignedTaskTime().time_since_epoch().count() << " worker id "
-           << worker->WorkerId() << " memory used " << used_memory << " task spec "
+    result << "Worker " << index << ": task assigned time "
+           << absl::FormatTime(worker->GetAssignedTaskTime(), absl::UTCTimeZone())
+           << " worker id " << worker->WorkerId() << " memory used " << used_memory
+           << " task spec "
            << worker->GetAssignedTask().GetTaskSpecification().DebugString() << "\n";
+
     index += 1;
     if (index > num_workers) {
       break;
@@ -88,6 +91,22 @@ std::string WorkerKillingPolicy::WorkersDebugString(
   return result.str();
 }
 
+std::shared_ptr<WorkerKillingPolicy> CreateWorkerKillingPolicy(
+    std::string killing_policy_str) {
+  if (killing_policy_str == kLifoPolicy) {
+    RAY_LOG(INFO) << "Running RetriableLIFO policy.";
+    return std::make_shared<RetriableLIFOWorkerKillingPolicy>();
+  } else if (killing_policy_str == kGroupByOwner) {
+    RAY_LOG(INFO) << "Running GroupByOwner policy.";
+    return std::make_shared<GroupByOwnerIdWorkerKillingPolicy>();
+  } else {
+    RAY_LOG(ERROR)
+        << killing_policy_str
+        << " is an invalid killing policy. Defaulting to RetriableLIFO policy.";
+    return std::make_shared<RetriableLIFOWorkerKillingPolicy>();
+  }
+}
+
 }  // namespace raylet
 
 }  // namespace ray
diff --git a/src/ray/raylet/worker_killing_policy.h b/src/ray/raylet/worker_killing_policy.h
index 5f4f0901fa71..d8e1ef742da9 100644
--- a/src/ray/raylet/worker_killing_policy.h
+++ b/src/ray/raylet/worker_killing_policy.h
@@ -26,6 +26,9 @@ namespace ray {
 
 namespace raylet {
 
+constexpr char kLifoPolicy[] = "retriable_lifo";
+constexpr char kGroupByOwner[] = "group_by_owner";
+
 /// Provides the policy on which worker to prioritize killing.
 class WorkerKillingPolicy {
  public:
@@ -65,6 +68,9 @@ class RetriableLIFOWorkerKillingPolicy : public WorkerKillingPolicy {
                                           const MemorySnapshot &system_memory) const;
 };
 
+std::shared_ptr<WorkerKillingPolicy> CreateWorkerKillingPolicy(
+    std::string killing_policy_str);
+
 }  // namespace raylet
 
 }  // namespace ray
diff --git a/src/ray/raylet/worker_killing_policy_group_by_owner.cc b/src/ray/raylet/worker_killing_policy_group_by_owner.cc
new file mode 100644
index 000000000000..e455808ab3a1
--- /dev/null
+++ b/src/ray/raylet/worker_killing_policy_group_by_owner.cc
@@ -0,0 +1,174 @@
+// Copyright 2022 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ray/raylet/worker_killing_policy_group_by_owner.h"
+
+#include <gtest/gtest_prod.h>
+
+#include <unordered_map>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/time/time.h"
+#include "ray/common/asio/instrumented_io_context.h"
+#include "ray/common/asio/periodical_runner.h"
+#include "ray/raylet/worker.h"
+#include "ray/raylet/worker_killing_policy.h"
+#include "ray/raylet/worker_pool.h"
+
+namespace ray {
+
+namespace raylet {
+
+GroupByOwnerIdWorkerKillingPolicy::GroupByOwnerIdWorkerKillingPolicy() {}
+
+const std::pair<std::shared_ptr<WorkerInterface>, bool>
+GroupByOwnerIdWorkerKillingPolicy::SelectWorkerToKill(
+    const std::vector<std::shared_ptr<WorkerInterface>> &workers,
+    const MemorySnapshot &system_memory) const {
+  if (workers.empty()) {
+    RAY_LOG_EVERY_MS(INFO, 5000) << "Worker list is empty. Nothing can be killed";
+    return std::make_pair(nullptr, /*should retry*/ false);
+  }
+
+  TaskID non_retriable_owner_id = TaskID::Nil();
+  std::unordered_map<TaskID, Group> group_map;
+  for (auto worker : workers) {
+    bool retriable = worker->GetAssignedTask().GetTaskSpecification().IsRetriable();
+    TaskID owner_id =
+        retriable ? worker->GetAssignedTask().GetTaskSpecification().ParentTaskId()
+                  : non_retriable_owner_id;
+
+    auto it = group_map.find(owner_id);
+
+    if (it == group_map.end()) {
+      Group group(owner_id, retriable);
+      group.AddToGroup(worker);
+      group_map.emplace(owner_id, std::move(group));
+    } else {
+      auto &group = it->second;
+      group.AddToGroup(worker);
+    }
+  }
+
+  std::vector<Group> sorted;
+  for (auto it = group_map.begin(); it != group_map.end(); ++it) {
+    sorted.push_back(it->second);
+  }
+
+  /// Prioritizes killing groups that are retriable, else it picks the largest group,
+  /// else it picks the newest group.
+  std::sort(
+      sorted.begin(), sorted.end(), [](const Group &left, const Group &right) -> bool {
+        int left_retriable = left.IsRetriable() ? 0 : 1;
+        int right_retriable = right.IsRetriable() ? 0 : 1;
+
+        if (left_retriable == right_retriable) {
+          if (left.GetAllWorkers().size() == right.GetAllWorkers().size()) {
+            return left.GetAssignedTaskTime() > right.GetAssignedTaskTime();
+          }
+          return left.GetAllWorkers().size() > right.GetAllWorkers().size();
+        }
+        return left_retriable < right_retriable;
+      });
+
+  Group selected_group = sorted.front();
+  bool should_retry =
+      selected_group.GetAllWorkers().size() > 1 && selected_group.IsRetriable();
+  auto worker_to_kill = selected_group.SelectWorkerToKill();
+
+  RAY_LOG(INFO) << "Sorted list of tasks based on the policy:\n"
+                << PolicyDebugString(sorted, system_memory);
+
+  return std::make_pair(worker_to_kill, should_retry);
+}
+
+std::string GroupByOwnerIdWorkerKillingPolicy::PolicyDebugString(
+    const std::vector<Group> &groups, const MemorySnapshot &system_memory) {
+  std::stringstream result;
+  int32_t group_index = 0;
+  for (auto &group : groups) {
+    result << "Tasks (retriable: " << group.IsRetriable()
+           << ") (parent task id: " << group.OwnerId() << ") (Earliest assigned time: "
+           << absl::FormatTime(group.GetAssignedTaskTime(), absl::UTCTimeZone())
+           << "):\n";
+
+    int64_t worker_index = 0;
+    for (auto &worker : group.GetAllWorkers()) {
+      auto pid = worker->GetProcess().GetId();
+      int64_t used_memory = 0;
+      const auto pid_entry = system_memory.process_used_bytes.find(pid);
+      if (pid_entry != system_memory.process_used_bytes.end()) {
+        used_memory = pid_entry->second;
+      } else {
+        RAY_LOG_EVERY_MS(INFO, 60000)
+            << "Can't find memory usage for PID, reporting zero. PID: " << pid;
+      }
+      result << "Task assigned time "
+             << absl::FormatTime(worker->GetAssignedTaskTime(), absl::UTCTimeZone())
+             << " worker id " << worker->WorkerId() << " memory used " << used_memory
+             << " task spec "
+             << worker->GetAssignedTask().GetTaskSpecification().DebugString() << "\n";
+
+      worker_index += 1;
+      if (worker_index > 10) {
+        break;
+      }
+    }
+
+    group_index += 1;
+    if (group_index > 10) {
+      break;
+    }
+  }
+
+  return result.str();
+}
+
+const TaskID &Group::OwnerId() const { return owner_id_; }
+
+const bool Group::IsRetriable() const { return retriable_; }
+
+const absl::Time Group::GetAssignedTaskTime() const { return earliest_task_time_; }
+
+void Group::AddToGroup(std::shared_ptr<WorkerInterface> worker) {
+  if (worker->GetAssignedTaskTime() < earliest_task_time_) {
+    earliest_task_time_ = worker->GetAssignedTaskTime();
+  }
+  bool retriable = worker->GetAssignedTask().GetTaskSpecification().IsRetriable();
+  RAY_CHECK_EQ(retriable_, retriable);
+  workers_.push_back(worker);
+}
+
+const std::shared_ptr<WorkerInterface> Group::SelectWorkerToKill() const {
+  RAY_CHECK(!workers_.empty());
+  std::vector<std::shared_ptr<WorkerInterface>> sorted(workers_.begin(),
+                                                       workers_.end());
+
+  std::sort(sorted.begin(),
+            sorted.end(),
+            [](std::shared_ptr<WorkerInterface> const &left,
+               std::shared_ptr<WorkerInterface> const &right) -> bool {
+              return left->GetAssignedTaskTime() > right->GetAssignedTaskTime();
+            });
+
+  return sorted.front();
+}
+
+const std::vector<std::shared_ptr<WorkerInterface>> Group::GetAllWorkers() const {
+  return workers_;
+}
+
+}  // namespace raylet
+
+}  // namespace ray
diff --git a/src/ray/raylet/worker_killing_policy_group_by_owner.h b/src/ray/raylet/worker_killing_policy_group_by_owner.h
new file mode 100644
index 000000000000..a2c79fd60207
--- /dev/null
+++ b/src/ray/raylet/worker_killing_policy_group_by_owner.h
@@ -0,0 +1,100 @@
+// Copyright 2022 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <gtest/gtest_prod.h>
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "ray/common/memory_monitor.h"
+#include "ray/raylet/worker.h"
+#include "ray/raylet/worker_killing_policy.h"
+
+namespace ray {
+
+namespace raylet {
+
+/// Key that groups tasks by their owner id. For a non-retriable task the owner
+/// id is the task's own id, since each non-retriable task forms its own group.
+struct GroupKey {
+  GroupKey(const TaskID &owner_id) : owner_id(owner_id) {}
+  const TaskID &owner_id;
+};
+
+struct Group {
+ public:
+  Group(const TaskID &owner_id, bool retriable)
+      : owner_id_(owner_id), retriable_(retriable) {}
+
+  /// The parent task id of the tasks belonging to this group.
+  const TaskID &OwnerId() const;
+
+  /// Whether tasks in this group are retriable.
+  const bool IsRetriable() const;
+
+  /// Gets the assigned time of the earliest task of this group, used
+  /// for group priority.
+  const absl::Time GetAssignedTaskTime() const;
+
+  /// Returns the worker to be killed in this group, in LIFO order.
+  const std::shared_ptr<WorkerInterface> SelectWorkerToKill() const;
+
+  /// The workers whose tasks belong to this group.
+  const std::vector<std::shared_ptr<WorkerInterface>> GetAllWorkers() const;
+
+  /// Adds the worker that the task belongs to to the group.
+  void AddToGroup(std::shared_ptr<WorkerInterface> worker);
+
+ private:
+  /// The workers whose tasks belong to this group.
+  std::vector<std::shared_ptr<WorkerInterface>> workers_;
+
+  /// The earliest assigned time among the tasks of this group.
+  absl::Time earliest_task_time_ = absl::Now();
+
+  /// The owner id shared by tasks of this group.
+  /// TODO(clarng): make this const and implement move / swap.
+  TaskID owner_id_;
+
+  /// Whether the tasks are retriable.
+  /// TODO(clarng): make this const and implement move / swap.
+  bool retriable_;
+};
+
+/// Groups tasks by their owner id. A non-retriable task (whether a task or an
+/// actor) forms its own group. Prioritizes killing groups that are retriable,
+/// else it picks the largest group, else it picks the newest group. The "age"
+/// of a group is based on the assigned time of its earliest task. When a group
+/// is selected for killing, its last-assigned task is selected.
+///
+/// When selecting a worker / task to be killed, the task to-be-killed is marked
+/// non-retriable if it is the last member of its group, and retriable otherwise.
+class GroupByOwnerIdWorkerKillingPolicy : public WorkerKillingPolicy {
+ public:
+  GroupByOwnerIdWorkerKillingPolicy();
+  const std::pair<std::shared_ptr<WorkerInterface>, bool> SelectWorkerToKill(
+      const std::vector<std::shared_ptr<WorkerInterface>> &workers,
+      const MemorySnapshot &system_memory) const;
+
+ private:
+  /// Creates the debug string of the groups created by the policy.
+  static std::string PolicyDebugString(const std::vector<Group> &groups,
+                                       const MemorySnapshot &system_memory);
+};
+
+}  // namespace raylet
+
+}  // namespace ray
diff --git a/src/ray/raylet/worker_killing_policy_group_by_owner_test.cc b/src/ray/raylet/worker_killing_policy_group_by_owner_test.cc
new file mode 100644
index 000000000000..fcde36e326e2
--- /dev/null
+++ b/src/ray/raylet/worker_killing_policy_group_by_owner_test.cc
@@ -0,0 +1,233 @@
+// Copyright 2022 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ray/raylet/worker_killing_policy_group_by_owner.h"
+
+#include "gtest/gtest.h"
+#include "ray/common/task/task_spec.h"
+#include "ray/raylet/test/util.h"
+#include "ray/raylet/worker_killing_policy.h"
+
+namespace ray {
+
+namespace raylet {
+
+class WorkerKillingGroupByOwnerTest : public ::testing::Test {
+ protected:
+  instrumented_io_context io_context_;
+  int32_t port_ = 2389;
+  JobID job_id_ = JobID::FromInt(75);
+  bool should_retry_ = true;
+  bool should_not_retry_ = false;
+  int32_t no_retry_ = 0;
+  int32_t has_retry_ = 1;
+  GroupByOwnerIdWorkerKillingPolicy worker_killing_policy_;
+
+  std::shared_ptr<WorkerInterface> CreateActorCreationWorker(TaskID owner_id,
+                                                             int32_t max_restarts) {
+    rpc::TaskSpec message;
+    message.set_task_id(TaskID::FromRandom(job_id_).Binary());
+    message.set_parent_task_id(owner_id.Binary());
+    message.mutable_actor_creation_task_spec()->set_max_actor_restarts(max_restarts);
+    message.set_type(ray::rpc::TaskType::ACTOR_CREATION_TASK);
+    TaskSpecification task_spec(message);
+    RayTask task(task_spec);
+    auto worker = std::make_shared<MockWorker>(ray::WorkerID::FromRandom(), port_);
+    worker->SetAssignedTask(task);
+    worker->AssignTaskId(task.GetTaskSpecification().TaskId());
+    return worker;
+  }
+
+  std::shared_ptr<WorkerInterface> CreateTaskWorker(TaskID owner_id,
+                                                    int32_t max_retries) {
+    rpc::TaskSpec message;
+    message.set_task_id(TaskID::FromRandom(job_id_).Binary());
+    message.set_parent_task_id(owner_id.Binary());
+    message.set_max_retries(max_retries);
+    message.set_type(ray::rpc::TaskType::NORMAL_TASK);
+    TaskSpecification task_spec(message);
+    RayTask task(task_spec);
+    auto worker = std::make_shared<MockWorker>(ray::WorkerID::FromRandom(), port_);
+    worker->SetAssignedTask(task);
+    worker->AssignTaskId(task.GetTaskSpecification().TaskId());
+    return worker;
+  }
+};
+
+TEST_F(WorkerKillingGroupByOwnerTest, TestEmptyWorkerPoolSelectsNullWorker) {
+  std::vector<std::shared_ptr<WorkerInterface>> workers;
+  auto worker_to_kill_and_should_retry_ =
+      worker_killing_policy_.SelectWorkerToKill(workers, MemorySnapshot());
+  auto worker_to_kill = worker_to_kill_and_should_retry_.first;
+  ASSERT_TRUE(worker_to_kill == nullptr);
+}
+
+TEST_F(WorkerKillingGroupByOwnerTest, TestLastWorkerInGroupShouldNotRetry) {
+  std::vector<std::shared_ptr<WorkerInterface>> workers;
+
+  auto owner_id = TaskID::ForDriverTask(job_id_);
+  auto first_submitted =
+      WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(owner_id, has_retry_);
+  auto second_submitted =
+      WorkerKillingGroupByOwnerTest::CreateTaskWorker(owner_id, has_retry_);
+
+  workers.push_back(first_submitted);
+  workers.push_back(second_submitted);
+
+  std::vector<std::pair<std::shared_ptr<WorkerInterface>, bool>> expected;
+  expected.push_back(std::make_pair(second_submitted, should_retry_));
+  expected.push_back(std::make_pair(first_submitted, should_not_retry_));
+
+  for (const auto &entry : expected) {
+    auto worker_to_kill_and_should_retry_ =
+        worker_killing_policy_.SelectWorkerToKill(workers, MemorySnapshot());
+    auto worker_to_kill = worker_to_kill_and_should_retry_.first;
+    bool retry = worker_to_kill_and_should_retry_.second;
+    ASSERT_EQ(worker_to_kill->WorkerId(), entry.first->WorkerId());
+    ASSERT_EQ(retry, entry.second);
+    workers.erase(std::remove(workers.begin(), workers.end(), worker_to_kill),
+                  workers.end());
+  }
+}
+
+TEST_F(WorkerKillingGroupByOwnerTest, TestNonRetriableBelongsToItsOwnGroupAndLIFOKill) {
+  auto owner_id = TaskID::ForDriverTask(job_id_);
+
+  std::vector<std::shared_ptr<WorkerInterface>> workers;
+  auto first_submitted =
+      WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(owner_id, no_retry_);
+  auto second_submitted =
+      WorkerKillingGroupByOwnerTest::CreateTaskWorker(owner_id, no_retry_);
+  workers.push_back(first_submitted);
+  workers.push_back(second_submitted);
+
+  std::vector<std::pair<std::shared_ptr<WorkerInterface>, bool>> expected;
+  expected.push_back(std::make_pair(second_submitted, should_not_retry_));
+
+  auto worker_to_kill_and_should_retry_ =
+      worker_killing_policy_.SelectWorkerToKill(workers, MemorySnapshot());
+
+  auto worker_to_kill = worker_to_kill_and_should_retry_.first;
+  bool retry = worker_to_kill_and_should_retry_.second;
+  ASSERT_EQ(worker_to_kill->WorkerId(), second_submitted->WorkerId());
+  ASSERT_EQ(retry, should_not_retry_);
+}
+
+TEST_F(WorkerKillingGroupByOwnerTest, TestGroupSortedByGroupSizeThenFirstSubmittedTask) {
+  auto first_group_owner_id = TaskID::FromRandom(job_id_);
+  auto second_group_owner_id = TaskID::FromRandom(job_id_);
+
+  std::vector<std::shared_ptr<WorkerInterface>> workers;
+  auto first_submitted = WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(
+      first_group_owner_id, has_retry_);
+  auto second_submitted =
+      WorkerKillingGroupByOwnerTest::CreateTaskWorker(second_group_owner_id, has_retry_);
+  auto third_submitted = WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(
+      second_group_owner_id, has_retry_);
+  auto fourth_submitted = WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(
+      second_group_owner_id, has_retry_);
+  auto fifth_submitted =
+      WorkerKillingGroupByOwnerTest::CreateTaskWorker(first_group_owner_id, has_retry_);
+  auto sixth_submitted =
+      WorkerKillingGroupByOwnerTest::CreateTaskWorker(first_group_owner_id, has_retry_);
+  workers.push_back(first_submitted);
+  workers.push_back(second_submitted);
+  workers.push_back(third_submitted);
+  workers.push_back(fourth_submitted);
+  workers.push_back(fifth_submitted);
+  workers.push_back(sixth_submitted);
+
+  std::vector<std::pair<std::shared_ptr<WorkerInterface>, bool>> expected;
+  expected.push_back(std::make_pair(fourth_submitted, should_retry_));
+  expected.push_back(std::make_pair(sixth_submitted, should_retry_));
+  expected.push_back(std::make_pair(third_submitted, should_retry_));
+  expected.push_back(std::make_pair(fifth_submitted, should_retry_));
+  expected.push_back(std::make_pair(second_submitted, should_not_retry_));
+  expected.push_back(std::make_pair(first_submitted, should_not_retry_));
+
+  for (const auto &entry : expected) {
+    auto worker_to_kill_and_should_retry_ =
+        worker_killing_policy_.SelectWorkerToKill(workers, MemorySnapshot());
+    auto worker_to_kill = worker_to_kill_and_should_retry_.first;
+    bool retry = worker_to_kill_and_should_retry_.second;
+    ASSERT_EQ(worker_to_kill->WorkerId(), entry.first->WorkerId());
+    ASSERT_EQ(retry, entry.second);
+    workers.erase(std::remove(workers.begin(), workers.end(), worker_to_kill),
+                  workers.end());
+  }
+}
+
+TEST_F(WorkerKillingGroupByOwnerTest, TestGroupSortedByRetriableLifo) {
+  std::vector<std::shared_ptr<WorkerInterface>> workers;
+  auto first_submitted = WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(
+      TaskID::FromRandom(job_id_), has_retry_);
+  auto second_submitted = WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(
+      TaskID::FromRandom(job_id_), has_retry_);
+  auto third_submitted = WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(
+      TaskID::FromRandom(job_id_), no_retry_);
+  workers.push_back(first_submitted);
+  workers.push_back(second_submitted);
+  workers.push_back(third_submitted);
+
+  std::vector<std::pair<std::shared_ptr<WorkerInterface>, bool>> expected;
+  expected.push_back(std::make_pair(second_submitted, should_not_retry_));
+  expected.push_back(std::make_pair(first_submitted, should_not_retry_));
+  expected.push_back(std::make_pair(third_submitted, should_not_retry_));
+
+  for (const auto &entry : expected) {
+    auto worker_to_kill_and_should_retry_ =
+        worker_killing_policy_.SelectWorkerToKill(workers, MemorySnapshot());
+    auto worker_to_kill = worker_to_kill_and_should_retry_.first;
+    bool retry = worker_to_kill_and_should_retry_.second;
+    ASSERT_EQ(worker_to_kill->WorkerId(), entry.first->WorkerId());
+    ASSERT_EQ(retry, entry.second);
+    workers.erase(std::remove(workers.begin(), workers.end(), worker_to_kill),
+                  workers.end());
+  }
+}
+
+TEST_F(WorkerKillingGroupByOwnerTest,
+       TestMultipleNonRetriableTaskSameGroupAndNotRetried) {
+  std::vector<std::shared_ptr<WorkerInterface>> workers;
+  auto first_submitted = WorkerKillingGroupByOwnerTest::CreateActorCreationWorker(
+      TaskID::FromRandom(job_id_), no_retry_);
+  auto second_submitted = WorkerKillingGroupByOwnerTest::CreateTaskWorker(
+      TaskID::FromRandom(job_id_), no_retry_);
+  workers.push_back(first_submitted);
+  workers.push_back(second_submitted);
+
+  std::vector<std::pair<std::shared_ptr<WorkerInterface>, bool>> expected;
+  expected.push_back(std::make_pair(second_submitted, should_not_retry_));
+  expected.push_back(std::make_pair(first_submitted, should_not_retry_));
+
+  for (const auto &entry : expected) {
+    auto worker_to_kill_and_should_retry_ =
+        worker_killing_policy_.SelectWorkerToKill(workers, MemorySnapshot());
+    auto worker_to_kill = worker_to_kill_and_should_retry_.first;
+    bool retry = worker_to_kill_and_should_retry_.second;
+    ASSERT_EQ(worker_to_kill->WorkerId(), entry.first->WorkerId());
+    ASSERT_EQ(retry, entry.second);
+    workers.erase(std::remove(workers.begin(), workers.end(), worker_to_kill),
+                  workers.end());
+  }
+}
+
+}  // namespace raylet
+
+}  // namespace ray
+
+int main(int argc, char **argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/src/ray/raylet/worker_killing_policy_test.cc b/src/ray/raylet/worker_killing_policy_test.cc
index ce2edbb0edd4..d5588491f6eb 100644
--- a/src/ray/raylet/worker_killing_policy_test.cc
+++ b/src/ray/raylet/worker_killing_policy_test.cc
@@ -14,8 +14,6 @@
 
 #include "ray/raylet/worker_killing_policy.h"
 
-#include <chrono>
-
 #include "gtest/gtest.h"
 #include "ray/common/task/task_spec.h"
 #include "ray/raylet/test/util.h"
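
Reviewer note: to make the group-ordering semantics easy to check at a glance, below is a minimal standalone sketch of the comparator introduced in worker_killing_policy_group_by_owner.cc. The ToyGroup struct and sample data are invented for this illustration and are not part of the patch; the real policy sorts Group objects holding WorkerInterface pointers and additionally picks the last-assigned worker inside the chosen group via Group::SelectWorkerToKill.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Toy stand-in for a Group: retriability, group size, and the assigned time of
// its earliest task (the patch uses absl::Time; an integer tick sorts the same).
struct ToyGroup {
  std::string owner;
  bool retriable;
  int num_workers;
  int64_t earliest_task_time;
};

int main() {
  std::vector<ToyGroup> groups = {
      {"non-retriable group", false, 2, 100},
      {"owner A", true, 3, 50},
      {"owner B", true, 3, 80},  // same size as A, but its earliest task is newer
      {"owner C", true, 1, 10},
  };

  // Same ordering as the patch: retriable groups first, then larger groups,
  // then the group whose earliest task was assigned most recently.
  std::sort(groups.begin(), groups.end(),
            [](const ToyGroup &left, const ToyGroup &right) {
              int left_retriable = left.retriable ? 0 : 1;
              int right_retriable = right.retriable ? 0 : 1;
              if (left_retriable == right_retriable) {
                if (left.num_workers == right.num_workers) {
                  return left.earliest_task_time > right.earliest_task_time;
                }
                return left.num_workers > right.num_workers;
              }
              return left_retriable < right_retriable;
            });

  // Prints: owner B, owner A, owner C, non-retriable group.
  for (const auto &group : groups) {
    std::cout << group.owner << "\n";
  }
  return 0;
}

Assuming the usual RAY_<config_name> environment override applies to entries in ray_config_def.h (an assumption of this note, not something shown in the diff), the new policy would be enabled by setting RAY_worker_killing_policy=group_by_owner on the raylet; the default remains retriable_lifo.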