Skip to content

Commit

Permalink
[aDAG] Support multi-read of the same shm channel (#47311)
Browse files Browse the repository at this point in the history
If the same method of the same actor is bound to the same node (i.e., reads from the same shared memory channel), aDAG execution hangs. This PR adds support to this case by caching results read from the channel.
  • Loading branch information
ruisearch42 authored Aug 30, 2024
1 parent eedb407 commit c9c150a
Show file tree
Hide file tree
Showing 7 changed files with 414 additions and 52 deletions.
103 changes: 87 additions & 16 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import traceback
from typing import NamedTuple

from ray.experimental.channel.cached_channel import CachedChannel
import ray
from ray.exceptions import RayTaskError, RayChannelError
from ray.util.annotations import PublicAPI
Expand Down Expand Up @@ -247,13 +248,14 @@ def __init__(
self.input_variant = input_variant
self.channel_idx = channel_idx

def resolve(self, channel_results: Any):
def resolve(self, channel_results: Any) -> Any:
"""
Resolve the input value from the channel results.
Args:
channel_results: The results from reading the input channels.
"""

if isinstance(self.input_variant, ChannelInterface):
value = channel_results[self.channel_idx]
elif isinstance(self.input_variant, DAGInputAdapter):
Expand Down Expand Up @@ -455,7 +457,11 @@ def _write(self) -> bool:
exit = True
return exit

def exec_operation(self, class_handle, op_type: _DAGNodeOperationType) -> bool:
def exec_operation(
self,
class_handle,
op_type: _DAGNodeOperationType,
) -> bool:
"""
An ExecutableTask corresponds to a DAGNode. It consists of three
operations: READ, COMPUTE, and WRITE, which should be executed in
Expand Down Expand Up @@ -1023,12 +1029,17 @@ def _get_or_compile(
)
)
else:
# Use reader_handles_set to deduplicate readers on the same actor,
# because with CachedChannel each actor will only read from the
# upstream channel once.
reader_handles_set = set()
for reader in readers:
reader_handle = reader.dag_node._get_actor_handle()
reader_and_node_list.append(
(reader_handle, self._get_node_id(reader_handle))
)

if reader_handle not in reader_handles_set:
reader_and_node_list.append(
(reader_handle, self._get_node_id(reader_handle))
)
reader_handles_set.add(reader_handle)
fn = task.dag_node._get_remote_method("__ray_call__")
task.output_channel = ray.get(
fn.remote(
Expand Down Expand Up @@ -1110,34 +1121,94 @@ def _get_or_compile(

# Create executable tasks for each actor
for actor_handle, tasks in self.actor_to_tasks.items():
executable_tasks = []
# Dict from the arg to the set of tasks that consume it.
arg_to_consumers: Dict[DAGNode, Set[CompiledTask]] = defaultdict(set)
# The number of tasks that consume InputNode (or InputAttributeNode)
# Note that _preprocess() ensures that all tasks either use InputNode
# or use InputAttributeNode, but not both.
num_input_consumers = 0

# Step 1: populate num_channel_reads and perform some validation.
for task in tasks:
resolved_args = []
has_at_least_one_channel_input = False
for arg in task.args:
if isinstance(arg, InputNode):
input_adapter = DAGInputAdapter(None, self.dag_input_channel)
resolved_args.append(input_adapter)
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
num_input_consumers = max(
num_input_consumers, len(arg_to_consumers[arg])
)
elif isinstance(arg, InputAttributeNode):
input_adapter = DAGInputAdapter(arg, self.dag_input_channel)
resolved_args.append(input_adapter)
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
num_input_consumers = max(
num_input_consumers, len(arg_to_consumers[arg])
)
elif isinstance(arg, DAGNode): # Other DAGNodes
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
arg_idx = self.dag_node_to_idx[arg]
arg_channel = self.idx_to_task[arg_idx].output_channel
assert arg_channel is not None
resolved_args.append(arg_channel)
has_at_least_one_channel_input = True
else:
resolved_args.append(arg)
# TODO: Support no-input DAGs (use an empty object to signal).
if not has_at_least_one_channel_input:
raise ValueError(
"Compiled DAGs require each task to take a "
"ray.dag.InputNode or at least one other DAGNode as an "
"input"
)

# Step 2: create cached channels if needed

# Dict from original channel to the channel to be used in execution.
# The value of this dict is either the original channel or a newly
# created CachedChannel (if the original channel is read more than once).
channel_dict: Dict[ChannelInterface, ChannelInterface] = {}
for arg, consumers in arg_to_consumers.items():
# Handle non-input args
if not isinstance(arg, InputNode) and not isinstance(
arg, InputAttributeNode
):
arg_idx = self.dag_node_to_idx[arg]
arg_channel = self.idx_to_task[arg_idx].output_channel
if len(consumers) > 1:
channel_dict[arg_channel] = CachedChannel(
len(consumers),
arg_channel,
)
else:
channel_dict[arg_channel] = arg_channel
# Handle input args
if num_input_consumers > 1:
channel_dict[self.dag_input_channel] = CachedChannel(
num_input_consumers,
self.dag_input_channel,
)
else:
channel_dict[self.dag_input_channel] = self.dag_input_channel

# Step 3: create executable tasks for the actor
executable_tasks = []
for task in tasks:
resolved_args = []
for arg in task.args:
if isinstance(arg, InputNode):
input_channel = channel_dict[self.dag_input_channel]
input_adapter = DAGInputAdapter(None, input_channel)
resolved_args.append(input_adapter)
elif isinstance(arg, InputAttributeNode):
input_channel = channel_dict[self.dag_input_channel]
input_adapter = DAGInputAdapter(arg, input_channel)
resolved_args.append(input_adapter)
elif isinstance(arg, DAGNode): # Other DAGNodes
arg_idx = self.dag_node_to_idx[arg]
arg_channel = self.idx_to_task[arg_idx].output_channel
assert arg_channel is not None
arg_channel = channel_dict[arg_channel]
resolved_args.append(arg_channel)
else:
# Constant arg
resolved_args.append(arg)
executable_task = ExecutableTask(
task,
resolved_args,
Expand Down
138 changes: 108 additions & 30 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,82 @@ def test_actor_method_bind_same_constant(ray_start_regular):
compiled_dag.teardown()


def test_actor_method_bind_same_input(ray_start_regular):
actor = Actor.remote(0)
with InputNode() as inp:
# Test binding input node to the same method
# of same actor multiple times: execution
# should not hang.
output1 = actor.inc.bind(inp)
output2 = actor.inc.bind(inp)
dag = MultiOutputNode([output1, output2])
compiled_dag = dag.experimental_compile()
expected = [[0, 0], [1, 2], [4, 6]]
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == expected[i]
compiled_dag.teardown()


def test_actor_method_bind_same_input_attr(ray_start_regular):
actor = Actor.remote(0)
with InputNode() as inp:
# Test binding input attribute node to the same method
# of same actor multiple times: execution should not
# hang.
output1 = actor.inc.bind(inp[0])
output2 = actor.inc.bind(inp[0])
dag = MultiOutputNode([output1, output2])
compiled_dag = dag.experimental_compile()
expected = [[0, 0], [1, 2], [4, 6]]
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == expected[i]
compiled_dag.teardown()


def test_actor_method_bind_same_arg(ray_start_regular):
a1 = Actor.remote(0)
a2 = Actor.remote(0)
with InputNode() as inp:
# Test binding arg to the same method
# of same actor multiple times: execution
# should not hang.
output1 = a1.echo.bind(inp)
output2 = a2.inc.bind(output1)
output3 = a2.inc.bind(output1)
dag = MultiOutputNode([output2, output3])
compiled_dag = dag.experimental_compile()
expected = [[0, 0], [1, 2], [4, 6]]
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == expected[i]
compiled_dag.teardown()


def test_mixed_bind_same_input(ray_start_regular):
a1 = Actor.remote(0)
a2 = Actor.remote(0)
with InputNode() as inp:
# Test binding input node to the same method
# of different actors multiple times: execution
# should not hang.
output1 = a1.inc.bind(inp)
output2 = a1.inc.bind(inp)
output3 = a2.inc.bind(inp)
dag = MultiOutputNode([output1, output2, output3])
compiled_dag = dag.experimental_compile()
expected = [[0, 0, 0], [1, 2, 1], [4, 6, 3]]
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == expected[i]
compiled_dag.teardown()


def test_regular_args(ray_start_regular):
# Test passing regular args to .bind in addition to DAGNode args.
a = Actor.remote(0)
Expand Down Expand Up @@ -304,14 +380,15 @@ def test_multi_args_basic(ray_start_regular):
def test_multi_args_single_actor(ray_start_regular):
c = Collector.remote()
with InputNode() as i:
dag = c.collect_two.bind(i[1], i[0])
dag = c.collect_three.bind(i[0], i[1], i[0])

compiled_dag = dag.experimental_compile()

expected = [[0, 1, 0], [0, 1, 0, 1, 2, 1], [0, 1, 0, 1, 2, 1, 2, 3, 2]]
for i in range(3):
ref = compiled_dag.execute(2, 3)
ref = compiled_dag.execute(i, i + 1)
result = ray.get(ref)
assert result == [3, 2] * (i + 1)
assert result == expected[i]

with pytest.raises(
ValueError,
Expand Down Expand Up @@ -1424,6 +1501,24 @@ def get_node_id(self):
compiled_dag.teardown()


@ray.remote
class TestWorker:
def add_one(self, value):
return value + 1

def add(self, val1, val2):
return val1 + val2

def generate_torch_tensor(self, size) -> torch.Tensor:
return torch.zeros(size)

def add_value_to_tensor(self, value: int, tensor: torch.Tensor) -> torch.Tensor:
"""
Add `value` to all elements of the tensor.
"""
return tensor + value


class TestActorInputOutput:
"""
Accelerated DAGs support the following two cases for the input/output of the graph:
Expand All @@ -1436,23 +1531,6 @@ class TestActorInputOutput:
which is an actor, needs to be the input and output of the graph.
"""

@ray.remote
class Worker:
def add_one(self, value):
return value + 1

def add(self, val1, val2):
return val1 + val2

def generate_torch_tensor(self, size) -> torch.Tensor:
return torch.zeros(size)

def add_value_to_tensor(self, value: int, tensor: torch.Tensor) -> torch.Tensor:
"""
Add `value` to all elements of the tensor.
"""
return tensor + value

def test_shared_memory_channel_only(ray_start_cluster):
"""
Replica -> Worker -> Replica
Expand All @@ -1463,7 +1541,7 @@ def test_shared_memory_channel_only(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
self.w = TestActorInputOutput.Worker.remote()
self.w = TestWorker.remote()
with InputNode() as inp:
dag = self.w.add_one.bind(inp)
self.compiled_dag = dag.experimental_compile()
Expand All @@ -1487,7 +1565,7 @@ def test_intra_process_channel(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
self.w = TestActorInputOutput.Worker.remote()
self.w = TestWorker.remote()
with InputNode() as inp:
dag = self.w.add_one.bind(inp)
dag = self.w.add_one.bind(dag)
Expand All @@ -1512,8 +1590,8 @@ def test_multiple_readers_multiple_writers(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
w1 = TestActorInputOutput.Worker.remote()
w2 = TestActorInputOutput.Worker.remote()
w1 = TestWorker.remote()
w2 = TestWorker.remote()
with InputNode() as inp:
dag = MultiOutputNode([w1.add_one.bind(inp), w2.add_one.bind(inp)])
self.compiled_dag = dag.experimental_compile()
Expand All @@ -1539,8 +1617,8 @@ def test_multiple_readers_single_writer(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
w1 = TestActorInputOutput.Worker.remote()
w2 = TestActorInputOutput.Worker.remote()
w1 = TestWorker.remote()
w2 = TestWorker.remote()
with InputNode() as inp:
branch1 = w1.add_one.bind(inp)
branch2 = w2.add_one.bind(inp)
Expand All @@ -1567,8 +1645,8 @@ def test_single_reader_multiple_writers(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
w1 = TestActorInputOutput.Worker.remote()
w2 = TestActorInputOutput.Worker.remote()
w1 = TestWorker.remote()
w2 = TestWorker.remote()
with InputNode() as inp:
dag = w1.add_one.bind(inp)
dag = MultiOutputNode([w1.add_one.bind(dag), w2.add_one.bind(dag)])
Expand All @@ -1592,8 +1670,8 @@ def test_torch_tensor_type(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
self._base = TestActorInputOutput.Worker.remote()
self._refiner = TestActorInputOutput.Worker.remote()
self._base = TestWorker.remote()
self._refiner = TestWorker.remote()

with ray.dag.InputNode() as inp:
dag = self._refiner.add_value_to_tensor.bind(
Expand Down
2 changes: 2 additions & 0 deletions python/ray/experimental/channel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ray.experimental.channel.cached_channel import CachedChannel
from ray.experimental.channel.common import ( # noqa: F401
AwaitableBackgroundReader,
AwaitableBackgroundWriter,
Expand All @@ -16,6 +17,7 @@
__all__ = [
"AwaitableBackgroundReader",
"AwaitableBackgroundWriter",
"CachedChannel",
"Channel",
"ReaderInterface",
"SynchronousReader",
Expand Down
Loading

0 comments on commit c9c150a

Please sign in to comment.