Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core][state] Efficient get/list actors with filters on some high-cardinality fields #34348

Merged
merged 1 commit into from
Apr 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions dashboard/state_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:

"""
try:
reply = await self._client.get_all_actor_info(timeout=option.timeout)
reply = await self._client.get_all_actor_info(
timeout=option.timeout, filters=option.filters
)
except DataSourceUnavailable:
raise DataSourceUnavailable(GCS_QUERY_FAILURE_WARNING)

Expand All @@ -219,7 +221,7 @@ async def list_actors(self, *, option: ListApiOptions) -> ListApiResponse:
],
)
result.append(data)
num_after_truncation = len(result)
num_after_truncation = len(result) + reply.num_filtered
result = self._filter(result, option.filters, ActorState, option.detail)
num_filtered = len(result)

Expand Down
37 changes: 31 additions & 6 deletions python/ray/experimental/state/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from functools import wraps
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple

import grpc
from grpc.aio._call import UnaryStreamCall
Expand All @@ -12,8 +12,9 @@
from ray._private import ray_constants
from ray._private.gcs_utils import GcsAioClient
from ray._private.utils import hex_to_binary
from ray._raylet import JobID
from ray._raylet import ActorID, JobID
from ray.core.generated import gcs_service_pb2_grpc
from ray.core.generated.gcs_pb2 import ActorTableData
from ray.core.generated.gcs_service_pb2 import (
GetAllActorInfoReply,
GetAllActorInfoRequest,
Expand Down Expand Up @@ -47,7 +48,11 @@
from ray.dashboard.datacenter import DataSource
from ray.dashboard.modules.job.common import JobInfo, JobInfoStorageClient
from ray.dashboard.utils import Dict as Dictionary
from ray.experimental.state.common import RAY_MAX_LIMIT_FROM_DATA_SOURCE
from ray.experimental.state.common import (
RAY_MAX_LIMIT_FROM_DATA_SOURCE,
PredicateType,
SupportedFilterType,
)
from ray.experimental.state.exception import DataSourceUnavailable

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -221,12 +226,32 @@ def ip_to_node_id(self, ip: Optional[str]) -> Optional[str]:

@handle_grpc_network_errors
async def get_all_actor_info(
self, timeout: int = None, limit: int = None
self,
timeout: int = None,
limit: int = None,
filters: Optional[List[Tuple[str, PredicateType, SupportedFilterType]]] = None,
) -> Optional[GetAllActorInfoReply]:
if not limit:
limit = RAY_MAX_LIMIT_FROM_DATA_SOURCE

request = GetAllActorInfoRequest(limit=limit)
if filters is None:
filters = []

req_filters = GetAllActorInfoRequest.Filters()
for filter in filters:
key, predicate, value = filter
if predicate != "=":
# We only support EQUAL predicate for source side filtering.
continue
if key == "actor_id":
req_filters.actor_id = ActorID(hex_to_binary(value)).binary()
elif key == "state":
if value not in ActorTableData.ActorState.keys():
raise ValueError(f"Invalid actor state for filtering: {value}")
req_filters.state = ActorTableData.ActorState.Value(value)
elif key == "job_id":
req_filters.job_id = JobID(hex_to_binary(value)).binary()

request = GetAllActorInfoRequest(limit=limit, filters=req_filters)
reply = await self._gcs_actor_info_stub.GetAllActorInfo(
request, timeout=timeout
)
Expand Down
106 changes: 100 additions & 6 deletions python/ray/tests/test_state_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import ray.dashboard.consts as dashboard_consts
import ray._private.state as global_state
import ray._private.ray_constants as ray_constants
from ray._raylet import ActorID
from ray._private.test_utils import (
run_string_as_driver,
wait_for_condition,
async_wait_for_condition_async_predicate,
)
Expand Down Expand Up @@ -132,6 +134,18 @@ def state_api_manager():
yield manager


@pytest.fixture
def state_api_manager_e2e(ray_start_with_dashboard):
address_info = ray_start_with_dashboard
gcs_address = address_info["gcs_address"]
gcs_aio_client = GcsAioClient(address=gcs_address)
gcs_channel = gcs_aio_client.channel.channel()
state_api_data_source_client = StateDataSourceClient(gcs_channel, gcs_aio_client)
manager = StateAPIManager(state_api_data_source_client)

yield manager


def verify_schema(state, result_dict: dict, detail: bool = False):
state_fields_columns = set()
if detail:
Expand Down Expand Up @@ -495,6 +509,87 @@ def test_state_api_client_periodic_warning(shutdown_only, capsys, clear_loggers)
expected_line in lines


@pytest.mark.asyncio
async def test_api_manager_e2e_list_actors(state_api_manager_e2e):
@ray.remote
class Actor:
pass

a = Actor.remote()
script = """
import ray

ray.init("auto")

@ray.remote
class Actor:
pass

def ready(self):
pass

b = Actor.remote()
ray.get(b.ready.remote())
del b
"""

run_string_as_driver(script)

async def verify():
result = await state_api_manager_e2e.list_actors(option=create_api_options())
print(result)
assert result.total == 2
assert result.num_after_truncation == 2
return True

await async_wait_for_condition_async_predicate(verify)

async def verify():
# Test actor id filtering on source
result = await state_api_manager_e2e.list_actors(
option=create_api_options(filters=[("actor_id", "=", a._actor_id.hex())])
)
print(result)
assert result.num_after_truncation == 2
assert len(result.result) == 1
return True

await async_wait_for_condition_async_predicate(verify)

async def verify():
# Test state filtering on source
result = await state_api_manager_e2e.list_actors(
option=create_api_options(filters=[("state", "=", "ALIVE")])
)
assert result.num_after_truncation == 2
assert len(result.result) == 1
return True

await async_wait_for_condition_async_predicate(verify)

async def verify():
# Test job filtering on source
cur_job_id = ray.get_runtime_context().get_job_id()
result = await state_api_manager_e2e.list_actors(
option=create_api_options(filters=[("job_id", "=", cur_job_id)])
)
assert result.num_after_truncation == 2
assert len(result.result) == 1
return True

await async_wait_for_condition_async_predicate(verify)

async def verify():
with pytest.raises(ValueError):
await state_api_manager_e2e.list_actors(
option=create_api_options(filters=[("state", "=", "DEEEED")])
)

return True

await async_wait_for_condition_async_predicate(verify)


@pytest.mark.asyncio
async def test_api_manager_list_actors(state_api_manager):
data_source_client = state_api_manager.data_source_client
Expand Down Expand Up @@ -538,9 +633,7 @@ async def test_api_manager_list_actors(state_api_manager):
result = await state_api_manager.list_actors(
option=create_api_options(filters=[("stat", "=", "DEAD")])
)
result = await state_api_manager.list_actors(
option=create_api_options(filters=[("state", "=", "DEAD")])
)

assert len(result.result) == 1

"""
Expand Down Expand Up @@ -3307,9 +3400,10 @@ def test_get_id_not_found(shutdown_only):
"""
ray.init()
runner = CliRunner()
result = runner.invoke(ray_get, ["actors", "1234"])
assert result.exit_code == 0
assert "Resource with id=1234 not found in the cluster." in result.output
id = ActorID.from_random().hex()
result = runner.invoke(ray_get, ["actors", id])
assert result.exit_code == 0, str(result.exception) + result.output
assert f"Resource with id={id} not found in the cluster." in result.output


def test_core_state_api_usage_tags(shutdown_only):
Expand Down
43 changes: 42 additions & 1 deletion src/ray/gcs/gcs_server/gcs_actor_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,41 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request,
auto limit = request.has_limit() ? request.limit() : -1;
RAY_LOG(DEBUG) << "Getting all actor info.";
++counts_[CountType::GET_ALL_ACTOR_INFO_REQUEST];

const auto filter_fn = [](const rpc::GetAllActorInfoRequest::Filters &filters,
const rpc::ActorTableData &data) {
if (filters.has_actor_id() &&
ActorID::FromBinary(filters.actor_id()) != ActorID::FromBinary(data.actor_id())) {
return false;
}
if (filters.has_job_id() &&
JobID::FromBinary(filters.job_id()) != JobID::FromBinary(data.job_id())) {
return false;
}
if (filters.has_state() && filters.state() != data.state()) {
return false;
}
return true;
};

if (request.show_dead_jobs() == false) {
auto total_actors = registered_actors_.size() + destroyed_actors_.size();
reply->set_total(total_actors);

auto count = 0;
auto num_filtered = 0;
for (const auto &iter : registered_actors_) {
if (limit != -1 && count >= limit) {
break;
}

// With filters, skip the actor if it doesn't match the filter.
if (request.has_filters() &&
!filter_fn(request.filters(), iter.second->GetActorTableData())) {
++num_filtered;
continue;
}

count += 1;
*reply->add_actor_table_data() = iter.second->GetActorTableData();
}
Expand All @@ -355,9 +381,17 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request,
if (limit != -1 && count >= limit) {
break;
}
// With filters, skip the actor if it doesn't match the filter.
if (request.has_filters() &&
!filter_fn(request.filters(), iter.second->GetActorTableData())) {
++num_filtered;
continue;
}

count += 1;
*reply->add_actor_table_data() = iter.second->GetActorTableData();
}
reply->set_num_filtered(num_filtered);
RAY_LOG(DEBUG) << "Finished getting all actor info.";
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
return;
Expand All @@ -367,7 +401,7 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request,
// We don't maintain an in-memory cache of all actors which belong to dead
// jobs, so fetch it from redis.
Status status = gcs_table_storage_->ActorTable().GetAll(
[reply, send_reply_callback, limit](
[reply, send_reply_callback, limit, request, filter_fn](
absl::flat_hash_map<ActorID, rpc::ActorTableData> &&result) {
auto total_actors = result.size();

Expand All @@ -377,16 +411,23 @@ void GcsActorManager::HandleGetAllActorInfo(rpc::GetAllActorInfoRequest request,
auto ptr = google::protobuf::Arena::Create<
absl::flat_hash_map<ActorID, rpc::ActorTableData>>(arena, std::move(result));
auto count = 0;
auto num_filtered = 0;
for (const auto &pair : *ptr) {
if (limit != -1 && count >= limit) {
break;
}
// With filters, skip the actor if it doesn't match the filter.
if (request.has_filters() && !filter_fn(request.filters(), pair.second)) {
++num_filtered;
continue;
}
count += 1;

// TODO yic: Fix const cast
reply->mutable_actor_table_data()->UnsafeArenaAddAllocated(
const_cast<rpc::ActorTableData *>(&pair.second));
}
reply->set_num_filtered(num_filtered);
GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK());
RAY_LOG(DEBUG) << "Finished getting all actor info.";
});
Expand Down
Loading