Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] runtime context resource ids getter #26907

Merged
merged 9 commits into from
Jul 24, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ def get_gpu_ids():
return assigned_ids


@Deprecated
@Deprecated(message="Use ray.get_runtime_context().assigned_resources instead.")
def get_resource_ids():
"""Get the IDs of the resources that are available to the worker.

Expand Down
22 changes: 21 additions & 1 deletion python/ray/runtime_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Any, Dict

import ray._private.worker
from ray._private.client_mode_hook import client_mode_hook
Expand All @@ -16,7 +17,7 @@ def __init__(self, worker):
assert worker is not None
self.worker = worker

def get(self):
def get(self) -> Dict[str, Any]:
"""Get a dictionary of the current context.

Returns:
Expand Down Expand Up @@ -160,6 +161,25 @@ def should_capture_child_tasks_in_placement_group(self):
"""
return self.worker.should_capture_child_tasks_in_placement_group

def get_assigned_resources(self):
"""Get the assigned resources to this worker.

By default for tasks, this will return {"CPU": 1}.
By default for actors, this will return {}.

Returns:
A dictionary mapping the name of a resource to a float, where
the float represents the amount of that resource reserved
for this worker.
"""
self.worker.check_connected()
resource_id_map = self.worker.core_worker.resource_ids()
richardliaw marked this conversation as resolved.
Show resolved Hide resolved
resource_map = {
res: sum(amt for _, amt in mapping)
for res, mapping in resource_id_map.items()
}
return resource_map

def get_runtime_env_string(self):
"""Get the runtime env string used for the current driver or worker.

Expand Down
6 changes: 3 additions & 3 deletions python/ray/tests/test_placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def test_placement_group_actor_resource_ids(ray_start_cluster, connect_to_client
@ray.remote(num_cpus=1)
class F:
def f(self):
return ray._private.worker.get_resource_ids()
return ray.get_runtime_context().assigned_resources

cluster = ray_start_cluster
num_nodes = 1
Expand All @@ -391,7 +391,7 @@ def f(self):
def test_placement_group_task_resource_ids(ray_start_cluster, connect_to_client):
@ray.remote(num_cpus=1)
def f():
return ray._private.worker.get_resource_ids()
return ray.get_runtime_context().assigned_resources

cluster = ray_start_cluster
num_nodes = 1
Expand Down Expand Up @@ -423,7 +423,7 @@ def f():
def test_placement_group_hang(ray_start_cluster, connect_to_client):
@ray.remote(num_cpus=1)
def f():
return ray._private.worker.get_resource_ids()
return ray.get_runtime_context().assigned_resources

cluster = ray_start_cluster
num_nodes = 1
Expand Down
28 changes: 28 additions & 0 deletions python/ray/tests/test_runtime_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,34 @@ def echo2(self, s):
assert ray.get(ray.get(obj)) == "hello"


def test_get_assigned_resources(ray_start_10_cpus):
@ray.remote
class Echo:
def check(self):
return ray.get_runtime_context().get_assigned_resources()

e = Echo.remote()
result = e.check.remote()
print(ray.get(result))
assert ray.get(result).get("CPU") is None
ray.kill(e)

e = Echo.options(num_cpus=4).remote()
result = e.check.remote()
assert ray.get(result)["CPU"] == 4.0
ray.kill(e)

@ray.remote
def check():
return ray.get_runtime_context().get_assigned_resources()

result = check.remote()
assert ray.get(result)["CPU"] == 1.0

result = check.options(num_cpus=2).remote()
assert ray.get(result)["CPU"] == 2.0


def test_actor_stats_normal_task(ray_start_regular):
# Because it works at the core worker level, this API works for tasks.
@ray.remote
Expand Down