diff --git a/dashboard/state_aggregator.py b/dashboard/state_aggregator.py index e3aa6f45ba4f..922644a77bfd 100644 --- a/dashboard/state_aggregator.py +++ b/dashboard/state_aggregator.py @@ -202,7 +202,9 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse: """ try: - reply = await self._client.get_all_actor_info(timeout=option.timeout) + reply = await self._client.get_all_actor_info( + timeout=option.timeout, filters=option.filters + ) except DataSourceUnavailable: raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING) @@ -219,7 +221,7 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse: ], ) result.append(data) - num_after_truncation = len(result) + num_after_truncation = len(result) + reply.num_filtered result = self._filter(result, option.filters, ActorState, option.detail) num_filtered = len(result) diff --git a/python/ray/experimental/state/state_manager.py b/python/ray/experimental/state/state_manager.py index 70d22b8fa070..11ea98b89c4c 100644 --- a/python/ray/experimental/state/state_manager.py +++ b/python/ray/experimental/state/state_manager.py @@ -2,7 +2,7 @@ import logging from collections import defaultdict from functools import wraps -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import grpc from grpc.aio._call import UnaryStreamCall @@ -12,8 +12,9 @@ from ray._private import ray_constants from ray._private.gcs_utils import GcsAioClient from ray._private.utils import hex_to_binary -from ray._raylet import JobID +from ray._raylet import ActorID, JobID from ray.core.generated import gcs_service_pb2_grpc +from ray.core.generated.gcs_pb2 import ActorTableData from ray.core.generated.gcs_service_pb2 import ( GetAllActorInfoReply, GetAllActorInfoRequest, @@ -47,7 +48,11 @@ from ray.dashboard.datacenter import DataSource from ray.dashboard.modules.job.common import JobInfo, JobInfoStorageClient from ray.dashboard.utils import Dict as Dictionary -from 
@handle_grpc_network_errors
async def get_all_actor_info(
    self,
    timeout: Optional[int] = None,
    limit: Optional[int] = None,
    filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
) -> Optional[GetAllActorInfoReply]:
    """Fetch actor table entries from GCS, filtering at the source when possible.

    Only ``=`` predicates on the ``actor_id``, ``state`` and ``job_id`` keys
    are translated into the request's ``Filters`` message. Every other filter
    tuple is skipped here and is therefore not applied by the data source.

    Args:
        timeout: Forwarded as the timeout of the ``GetAllActorInfo`` call.
        limit: Maximum number of entries to request. A falsy value (``None``
            or ``0``) falls back to ``RAY_MAX_LIMIT_FROM_DATA_SOURCE``.
        filters: ``(key, predicate, value)`` tuples; ``None`` means no
            filters. If the same key appears more than once with ``=``, the
            last occurrence wins in the request.

    Returns:
        The ``GetAllActorInfoReply`` produced by the GCS actor info stub.

    Raises:
        ValueError: If a ``state`` filter value is not the name of an
            ``ActorTableData.ActorState`` member.
    """
    if not limit:
        limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE

    req_filters = GetAllActorInfoRequest.Filters()
    # `filters or ()` handles None without rebinding the argument, and the
    # tuple unpacking avoids shadowing the `filter` builtin.
    for key, predicate, value in filters or ():
        if predicate != "=":
            # We only support EQUAL predicate for source side filtering.
            continue
        if key == "actor_id":
            req_filters.actor_id = ActorID(hex_to_binary(value)).binary()
        elif key == "state":
            if value not in ActorTableData.ActorState.keys():
                raise ValueError(f"Invalid actor state for filtering: {value}")
            req_filters.state = ActorTableData.ActorState.Value(value)
        elif key == "job_id":
            req_filters.job_id = JobID(hex_to_binary(value)).binary()

    request = GetAllActorInfoRequest(limit=limit, filters=req_filters)
    reply = await self._gcs_actor_info_stub.GetAllActorInfo(
        request, timeout=timeout
    )
    return reply
@pytest.fixture
def state_api_manager_e2e(ray_start_with_dashboard):
    """Yield a StateAPIManager backed by the GCS of the started cluster."""
    gcs_aio_client = GcsAioClient(address=ray_start_with_dashboard["gcs_address"])
    data_source_client = StateDataSourceClient(
        gcs_aio_client.channel.channel(), gcs_aio_client
    )
    yield StateAPIManager(data_source_client)


@pytest.mark.asyncio
async def test_api_manager_e2e_list_actors(state_api_manager_e2e):
    """List actors end to end and exercise the filters pushed down to GCS.

    Two actors exist: `a`, created by this driver and kept alive, and `b`,
    created and deleted by a second driver script. Each filter below is
    expected to leave exactly one actor in the result while
    `num_after_truncation` still reports both.
    """

    @ray.remote
    class Actor:
        pass

    a = Actor.remote()
    script = """
import ray

ray.init("auto")

@ray.remote
class Actor:
    pass

    def ready(self):
        pass

b = Actor.remote()
ray.get(b.ready.remote())
del b
    """

    run_string_as_driver(script)

    manager = state_api_manager_e2e

    def expect_single_match(listing):
        # Both actors were seen by the source, only one survived the filter.
        assert listing.num_after_truncation == 2
        assert len(listing.result) == 1
        return True

    async def all_actors_listed():
        listing = await manager.list_actors(option=create_api_options())
        print(listing)
        assert listing.total == 2
        assert listing.num_after_truncation == 2
        return True

    async def actor_id_filter_applied():
        # Test actor id filtering on source
        listing = await manager.list_actors(
            option=create_api_options(filters=[("actor_id", "=", a._actor_id.hex())])
        )
        print(listing)
        return expect_single_match(listing)

    async def state_filter_applied():
        # Test state filtering on source
        listing = await manager.list_actors(
            option=create_api_options(filters=[("state", "=", "ALIVE")])
        )
        return expect_single_match(listing)

    async def job_id_filter_applied():
        # Test job filtering on source
        cur_job_id = ray.get_runtime_context().get_job_id()
        listing = await manager.list_actors(
            option=create_api_options(filters=[("job_id", "=", cur_job_id)])
        )
        return expect_single_match(listing)

    async def invalid_state_rejected():
        with pytest.raises(ValueError):
            await manager.list_actors(
                option=create_api_options(filters=[("state", "=", "DEEEED")])
            )

        return True

    # The predicates run strictly in this order, each retried until it holds.
    for predicate in (
        all_actors_listed,
        actor_id_filter_applied,
        state_filter_applied,
        job_id_filter_applied,
        invalid_state_rejected,
    ):
        await async_wait_for_condition_async_predicate(predicate)
in result.output def test_core_state_api_usage_tags(shutdown_only): diff --git a/src/ray/gcs/gcs_server/gcs_actor_manager.cc b/src/ray/gcs/gcs_server/gcs_actor_manager.cc index fe8bcb220497..ee328510ea82 100644 --- a/src/ray/gcs/gcs_server/gcs_actor_manager.cc +++ b/src/ray/gcs/gcs_server/gcs_actor_manager.cc @@ -338,15 +338,41 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request, auto limit = request.has_limit() ? request.limit() : -1; RAY_LOG(DEBUG) << "Getting all actor info."; ++counts_[CountType::GET_ALL_ACTOR_INFO_REQUEST]; + + const auto filter_fn = [](const rpc::GetAllActorInfoRequest::Filters &filters, + const rpc::ActorTableData &data) { + if (filters.has_actor_id() && + ActorID::FromBinary(filters.actor_id()) != ActorID::FromBinary(data.actor_id())) { + return false; + } + if (filters.has_job_id() && + JobID::FromBinary(filters.job_id()) != JobID::FromBinary(data.job_id())) { + return false; + } + if (filters.has_state() && filters.state() != data.state()) { + return false; + } + return true; + }; + if (request.show_dead_jobs() == false) { auto total_actors = registered_actors_.size() + destroyed_actors_.size(); reply->set_total(total_actors); auto count = 0; + auto num_filtered = 0; for (const auto &iter : registered_actors_) { if (limit != -1 && count >= limit) { break; } + + // With filters, skip the actor if it doesn't match the filter. + if (request.has_filters() && + !filter_fn(request.filters(), iter.second->GetActorTableData())) { + ++num_filtered; + continue; + } + count += 1; *reply->add_actor_table_data() = iter.second->GetActorTableData(); } @@ -355,9 +381,17 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request, if (limit != -1 && count >= limit) { break; } + // With filters, skip the actor if it doesn't match the filter. 
+ if (request.has_filters() && + !filter_fn(request.filters(), iter.second->GetActorTableData())) { + ++num_filtered; + continue; + } + count += 1; *reply->add_actor_table_data() = iter.second->GetActorTableData(); } + reply->set_num_filtered(num_filtered); RAY_LOG(DEBUG) << "Finished getting all actor info."; GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); return; @@ -367,7 +401,7 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request, // We don't maintain an in-memory cache of all actors which belong to dead // jobs, so fetch it from redis. Status status = gcs_table_storage_->ActorTable().GetAll( - [reply, send_reply_callback, limit]( + [reply, send_reply_callback, limit, request, filter_fn]( absl::flat_hash_map &&result) { auto total_actors = result.size(); @@ -377,16 +411,23 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request, auto ptr = google::protobuf::Arena::Create< absl::flat_hash_map>(arena, std::move(result)); auto count = 0; + auto num_filtered = 0; for (const auto &pair : *ptr) { if (limit != -1 && count >= limit) { break; } + // With filters, skip the actor if it doesn't match the filter. 
+ if (request.has_filters() && !filter_fn(request.filters(), pair.second)) { + ++num_filtered; + continue; + } count += 1; // TODO yic: Fix const cast reply->mutable_actor_table_data()->UnsafeArenaAddAllocated( const_cast(&pair.second)); } + reply->set_num_filtered(num_filtered); GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); RAY_LOG(DEBUG) << "Finished getting all actor info."; }); diff --git a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc index 77eb088620d0..9e99f2270027 100644 --- a/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc +++ b/src/ray/gcs/gcs_server/test/gcs_actor_manager_test.cc @@ -242,6 +242,32 @@ class GcsActorManagerTest : public ::testing::Test { promise.get_future().get(); } + std::shared_ptr CreateActorAndWaitTilAlive(const JobID &job_id) { + auto registered_actor = RegisterActor(job_id); + rpc::CreateActorRequest create_actor_request; + create_actor_request.mutable_task_spec()->CopyFrom( + registered_actor->GetCreationTaskSpecification().GetMessage()); + std::vector> finished_actors; + Status status = gcs_actor_manager_->CreateActor( + create_actor_request, + [&finished_actors](const std::shared_ptr &actor, + const rpc::PushTaskReply &reply, + const Status &status) { + finished_actors.emplace_back(actor); + }); + + auto actor = mock_actor_scheduler_->actors.back(); + mock_actor_scheduler_->actors.pop_back(); + + // Check that the actor is in state `ALIVE`. 
+ actor->UpdateAddress(RandomAddress()); + gcs_actor_manager_->OnActorCreationSuccess(actor, rpc::PushTaskReply()); + WaitActorCreated(actor->GetActorID()); + RAY_CHECK_EQ(gcs_actor_manager_->CountFor(rpc::ActorTableData::ALIVE, ""), 1); + RAY_CHECK_EQ(actor->GetState(), rpc::ActorTableData::ALIVE); + return actor; + } + instrumented_io_context io_service_; std::unique_ptr thread_io_service_; std::shared_ptr store_client_; @@ -1178,6 +1204,88 @@ TEST_F(GcsActorManagerTest, TestReuseActorNameInNamespace) { } } +TEST_F(GcsActorManagerTest, TestGetAllActorInfoFilters) { + google::protobuf::Arena arena; + // The target filter actor. + auto job_id = JobID::FromInt(1); + auto actor = CreateActorAndWaitTilAlive(job_id); + + // Just register some other actors. + auto job_id_other = JobID::FromInt(2); + auto num_other_actors = 3; + for (int i = 0; i < num_other_actors; i++) { + auto request1 = Mocker::GenRegisterActorRequest(job_id_other, + /*max_restarts=*/0, + /*detached=*/false); + Status status = gcs_actor_manager_->RegisterActor( + request1, [](std::shared_ptr actor) {}); + ASSERT_TRUE(status.ok()); + } + + auto callback = + [](Status status, std::function success, std::function failure) {}; + // Filter with actor id + { + rpc::GetAllActorInfoRequest request; + request.mutable_filters()->set_actor_id(actor->GetActorID().Binary()); + + auto &reply = + *google::protobuf::Arena::CreateMessage(&arena); + gcs_actor_manager_->HandleGetAllActorInfo(request, &reply, callback); + ASSERT_EQ(reply.actor_table_data().size(), 1); + ASSERT_EQ(reply.total(), 1 + num_other_actors); + ASSERT_EQ(reply.num_filtered(), num_other_actors); + } + + // Filter with job id + { + rpc::GetAllActorInfoRequest request; + request.mutable_filters()->set_job_id(job_id.Binary()); + + auto &reply = + *google::protobuf::Arena::CreateMessage(&arena); + gcs_actor_manager_->HandleGetAllActorInfo(request, &reply, callback); + ASSERT_EQ(reply.actor_table_data().size(), 1); + ASSERT_EQ(reply.num_filtered(), 
num_other_actors); + } + + // Filter with states + { + rpc::GetAllActorInfoRequest request; + request.mutable_filters()->set_state(rpc::ActorTableData::ALIVE); + + auto &reply = + *google::protobuf::Arena::CreateMessage(&arena); + gcs_actor_manager_->HandleGetAllActorInfo(request, &reply, callback); + ASSERT_EQ(reply.actor_table_data().size(), 1); + ASSERT_EQ(reply.num_filtered(), num_other_actors); + } + + // Simple test AND + { + rpc::GetAllActorInfoRequest request; + request.mutable_filters()->set_state(rpc::ActorTableData::ALIVE); + request.mutable_filters()->set_job_id(job_id.Binary()); + + auto &reply = + *google::protobuf::Arena::CreateMessage(&arena); + gcs_actor_manager_->HandleGetAllActorInfo(request, &reply, callback); + ASSERT_EQ(reply.actor_table_data().size(), 1); + ASSERT_EQ(reply.num_filtered(), num_other_actors); + } + { + rpc::GetAllActorInfoRequest request; + request.mutable_filters()->set_state(rpc::ActorTableData::DEAD); + request.mutable_filters()->set_job_id(job_id.Binary()); + + auto &reply = + *google::protobuf::Arena::CreateMessage(&arena); + gcs_actor_manager_->HandleGetAllActorInfo(request, &reply, callback); + ASSERT_EQ(reply.num_filtered(), num_other_actors + 1); + ASSERT_EQ(reply.actor_table_data().size(), 0); + } +} + TEST_F(GcsActorManagerTest, TestGetAllActorInfoLimit) { google::protobuf::Arena arena; auto job_id_1 = JobID::FromInt(1); diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index 43b18d9c0a50..9c63d76c4130 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -120,6 +120,17 @@ message GetAllActorInfoRequest { // Maximum number of entries to return. // If not specified, return the whole entries without truncation. optional int64 limit = 2; + + // The filter to apply to the returned entries. 
  // Equality filters applied by the GCS before entries are added to the reply.
  // Every field that is set must match for an actor to be returned; unset
  // fields are ignored.
  message Filters {
    // Binary actor ID; when set, only the actor with this ID matches.
    optional bytes actor_id = 1;
    // Binary job ID; when set, only actors belonging to this job match.
    optional bytes job_id = 2;
    // Actor state; when set, only actors currently in this state match.
    optional ActorTableData.ActorState state = 3;
  }
  // Source-side filters for this request. When absent, no entry is filtered.
  optional Filters filters = 3;