[aDAG] Support multi-read of the same shm channel
Signed-off-by: Rui Qiao <[email protected]>
ruisearch42 committed Aug 30, 2024
1 parent f0a81a6 commit 7eff3cc
Showing 7 changed files with 414 additions and 52 deletions.
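
For orientation, the pattern this commit enables is binding the same upstream value (an InputNode or another task's output) to multiple methods of the same actor; per the new tests, execution should no longer hang because each actor reads the underlying shm channel only once and serves later reads from a CachedChannel. A minimal sketch mirroring the new test_actor_method_bind_same_input test below; the Actor class here is a hypothetical stand-in for the test helper, assumed to add its argument to internal state and return the running total:

```python
import ray
from ray.dag import InputNode, MultiOutputNode


# Hypothetical stand-in for the Actor helper used in
# test_accelerated_dag.py: inc() adds its argument to internal
# state and returns the running total.
@ray.remote
class Actor:
    def __init__(self, init_value):
        self.i = init_value

    def inc(self, x):
        self.i += x
        return self.i


actor = Actor.remote(0)
with InputNode() as inp:
    # The DAG input is consumed twice by the same actor; with this
    # commit the underlying shm channel is read once per execution
    # and the second read is served from a CachedChannel.
    dag = MultiOutputNode([actor.inc.bind(inp), actor.inc.bind(inp)])

compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(1)
assert ray.get(ref) == [1, 2]  # inc(1) -> 1, then inc(1) -> 2
compiled_dag.teardown()
```

The same pattern is exercised for InputAttributeNode (inp[0]) and for intermediate outputs in the tests added below.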
103 changes: 87 additions & 16 deletions python/ray/dag/compiled_dag_node.py
@@ -8,6 +8,7 @@
import traceback
from typing import NamedTuple

from ray.experimental.channel.cached_channel import CachedChannel
import ray
from ray.exceptions import RayTaskError, RayChannelError
from ray.util.annotations import PublicAPI
@@ -247,13 +248,14 @@ def __init__(
self.input_variant = input_variant
self.channel_idx = channel_idx

def resolve(self, channel_results: Any):
def resolve(self, channel_results: Any) -> Any:
"""
Resolve the input value from the channel results.
Args:
channel_results: The results from reading the input channels.
"""

if isinstance(self.input_variant, ChannelInterface):
value = channel_results[self.channel_idx]
elif isinstance(self.input_variant, DAGInputAdapter):
@@ -455,7 +457,11 @@ def _write(self) -> bool:
exit = True
return exit

def exec_operation(self, class_handle, op_type: _DAGNodeOperationType) -> bool:
def exec_operation(
self,
class_handle,
op_type: _DAGNodeOperationType,
) -> bool:
"""
An ExecutableTask corresponds to a DAGNode. It consists of three
operations: READ, COMPUTE, and WRITE, which should be executed in
@@ -1023,12 +1029,17 @@ def _get_or_compile(
)
)
else:
# Use reader_handles_set to deduplicate readers on the same actor,
# because with CachedChannel each actor will only read from the
# upstream channel once.
reader_handles_set = set()
for reader in readers:
reader_handle = reader.dag_node._get_actor_handle()
reader_and_node_list.append(
(reader_handle, self._get_node_id(reader_handle))
)

if reader_handle not in reader_handles_set:
reader_and_node_list.append(
(reader_handle, self._get_node_id(reader_handle))
)
reader_handles_set.add(reader_handle)
fn = task.dag_node._get_remote_method("__ray_call__")
task.output_channel = ray.get(
fn.remote(
@@ -1110,34 +1121,94 @@ def _get_or_compile(

# Create executable tasks for each actor
for actor_handle, tasks in self.actor_to_tasks.items():
executable_tasks = []
# Dict from the arg to the set of tasks that consume it.
arg_to_consumers: Dict[DAGNode, Set[CompiledTask]] = defaultdict(set)
# The number of tasks that consume InputNode (or InputAttributeNode)
# Note that _preprocess() ensures that all tasks either use InputNode
# or use InputAttributeNode, but not both.
num_input_consumers = 0

# Step 1: populate num_channel_reads and perform some validation.
for task in tasks:
resolved_args = []
has_at_least_one_channel_input = False
for arg in task.args:
if isinstance(arg, InputNode):
input_adapter = DAGInputAdapter(None, self.dag_input_channel)
resolved_args.append(input_adapter)
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
num_input_consumers = max(
num_input_consumers, len(arg_to_consumers[arg])
)
elif isinstance(arg, InputAttributeNode):
input_adapter = DAGInputAdapter(arg, self.dag_input_channel)
resolved_args.append(input_adapter)
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
num_input_consumers = max(
num_input_consumers, len(arg_to_consumers[arg])
)
elif isinstance(arg, DAGNode): # Other DAGNodes
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
arg_idx = self.dag_node_to_idx[arg]
arg_channel = self.idx_to_task[arg_idx].output_channel
assert arg_channel is not None
resolved_args.append(arg_channel)
has_at_least_one_channel_input = True
else:
resolved_args.append(arg)
# TODO: Support no-input DAGs (use an empty object to signal).
if not has_at_least_one_channel_input:
raise ValueError(
"Compiled DAGs require each task to take a "
"ray.dag.InputNode or at least one other DAGNode as an "
"input"
)

# Step 2: create cached channels if needed

# Dict from original channel to the channel to be used in execution.
# The value of this dict is either the original channel or a newly
# created CachedChannel (if the original channel is read more than once).
channel_dict: Dict[ChannelInterface, ChannelInterface] = {}
for arg, consumers in arg_to_consumers.items():
# Handle non-input args
if not isinstance(arg, InputNode) and not isinstance(
arg, InputAttributeNode
):
arg_idx = self.dag_node_to_idx[arg]
arg_channel = self.idx_to_task[arg_idx].output_channel
if len(consumers) > 1:
channel_dict[arg_channel] = CachedChannel(
len(consumers),
arg_channel,
)
else:
channel_dict[arg_channel] = arg_channel
# Handle input args
if num_input_consumers > 1:
channel_dict[self.dag_input_channel] = CachedChannel(
num_input_consumers,
self.dag_input_channel,
)
else:
channel_dict[self.dag_input_channel] = self.dag_input_channel

# Step 3: create executable tasks for the actor
executable_tasks = []
for task in tasks:
resolved_args = []
for arg in task.args:
if isinstance(arg, InputNode):
input_channel = channel_dict[self.dag_input_channel]
input_adapter = DAGInputAdapter(None, input_channel)
resolved_args.append(input_adapter)
elif isinstance(arg, InputAttributeNode):
input_channel = channel_dict[self.dag_input_channel]
input_adapter = DAGInputAdapter(arg, input_channel)
resolved_args.append(input_adapter)
elif isinstance(arg, DAGNode): # Other DAGNodes
arg_idx = self.dag_node_to_idx[arg]
arg_channel = self.idx_to_task[arg_idx].output_channel
assert arg_channel is not None
arg_channel = channel_dict[arg_channel]
resolved_args.append(arg_channel)
else:
# Constant arg
resolved_args.append(arg)
executable_task = ExecutableTask(
task,
resolved_args,
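
The new python/ray/experimental/channel/cached_channel.py itself is not expanded in this view, so only its usage is visible: CachedChannel(num_reads, inner_channel) wraps a channel that more than one task on the same actor reads. A rough, self-contained sketch of that idea, not Ray's implementation (the ChannelLike protocol and the cache-reset behavior are assumptions):

```python
from typing import Any, Optional, Protocol


class ChannelLike(Protocol):
    """Minimal read/write interface assumed for this illustration."""

    def read(self) -> Any:
        ...

    def write(self, value: Any) -> None:
        ...


class CachedChannelSketch:
    """Serve one upstream value to `num_reads` readers in the same process.

    The first read() pulls from the inner channel and caches the value;
    the remaining num_reads - 1 reads return the cached value. Once all
    expected reads are served, the cache is dropped so the next value
    can be read from the inner channel.
    """

    def __init__(self, num_reads: int, inner_channel: ChannelLike):
        assert num_reads >= 1
        self._num_reads = num_reads
        self._inner = inner_channel
        self._cached: Optional[Any] = None
        self._reads_left = 0

    def read(self) -> Any:
        if self._reads_left == 0:
            # First reader for this value: hit the underlying channel once.
            self._cached = self._inner.read()
            self._reads_left = self._num_reads
        self._reads_left -= 1
        value = self._cached
        if self._reads_left == 0:
            # All expected readers were served; release the reference.
            self._cached = None
        return value
```

This is also why readers are deduplicated per actor in _get_or_compile above: the shared-memory channel registers a single reader per actor, and the per-actor fan-out to multiple consuming tasks happens through the cache.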
138 changes: 108 additions & 30 deletions python/ray/dag/tests/experimental/test_accelerated_dag.py
@@ -267,6 +267,82 @@ def test_actor_method_bind_same_constant(ray_start_regular):
compiled_dag.teardown()


def test_actor_method_bind_same_input(ray_start_regular):
actor = Actor.remote(0)
with InputNode() as inp:
# Test binding the input node to the same method of the
# same actor multiple times: execution should not hang.
output1 = actor.inc.bind(inp)
output2 = actor.inc.bind(inp)
dag = MultiOutputNode([output1, output2])
compiled_dag = dag.experimental_compile()
expected = [[0, 0], [1, 2], [4, 6]]
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == expected[i]
compiled_dag.teardown()


def test_actor_method_bind_same_input_attr(ray_start_regular):
actor = Actor.remote(0)
with InputNode() as inp:
# Test binding the input attribute node to the same method
# of the same actor multiple times: execution should not hang.
output1 = actor.inc.bind(inp[0])
output2 = actor.inc.bind(inp[0])
dag = MultiOutputNode([output1, output2])
compiled_dag = dag.experimental_compile()
expected = [[0, 0], [1, 2], [4, 6]]
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == expected[i]
compiled_dag.teardown()


def test_actor_method_bind_same_arg(ray_start_regular):
a1 = Actor.remote(0)
a2 = Actor.remote(0)
with InputNode() as inp:
# Test binding the same upstream output to the same method
# of the same actor multiple times: execution should not hang.
output1 = a1.echo.bind(inp)
output2 = a2.inc.bind(output1)
output3 = a2.inc.bind(output1)
dag = MultiOutputNode([output2, output3])
compiled_dag = dag.experimental_compile()
expected = [[0, 0], [1, 2], [4, 6]]
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == expected[i]
compiled_dag.teardown()


def test_mixed_bind_same_input(ray_start_regular):
a1 = Actor.remote(0)
a2 = Actor.remote(0)
with InputNode() as inp:
# Test binding the input node to methods of different actors
# (including the same method of the same actor twice):
# execution should not hang.
output1 = a1.inc.bind(inp)
output2 = a1.inc.bind(inp)
output3 = a2.inc.bind(inp)
dag = MultiOutputNode([output1, output2, output3])
compiled_dag = dag.experimental_compile()
expected = [[0, 0, 0], [1, 2, 1], [4, 6, 3]]
for i in range(3):
ref = compiled_dag.execute(i)
result = ray.get(ref)
assert result == expected[i]
compiled_dag.teardown()


def test_regular_args(ray_start_regular):
# Test passing regular args to .bind in addition to DAGNode args.
a = Actor.remote(0)
@@ -304,14 +380,15 @@ def test_multi_args_basic(ray_start_regular):
def test_multi_args_single_actor(ray_start_regular):
c = Collector.remote()
with InputNode() as i:
dag = c.collect_two.bind(i[1], i[0])
dag = c.collect_three.bind(i[0], i[1], i[0])

compiled_dag = dag.experimental_compile()

expected = [[0, 1, 0], [0, 1, 0, 1, 2, 1], [0, 1, 0, 1, 2, 1, 2, 3, 2]]
for i in range(3):
ref = compiled_dag.execute(2, 3)
ref = compiled_dag.execute(i, i + 1)
result = ray.get(ref)
assert result == [3, 2] * (i + 1)
assert result == expected[i]

with pytest.raises(
ValueError,
@@ -1424,6 +1501,24 @@ def get_node_id(self):
compiled_dag.teardown()


@ray.remote
class TestWorker:
def add_one(self, value):
return value + 1

def add(self, val1, val2):
return val1 + val2

def generate_torch_tensor(self, size) -> torch.Tensor:
return torch.zeros(size)

def add_value_to_tensor(self, value: int, tensor: torch.Tensor) -> torch.Tensor:
"""
Add `value` to all elements of the tensor.
"""
return tensor + value


class TestActorInputOutput:
"""
Accelerated DAGs support the following two cases for the input/output of the graph:
@@ -1436,23 +1531,6 @@ class TestActorInputOutput:
which is an actor, needs to be the input and output of the graph.
"""

@ray.remote
class Worker:
def add_one(self, value):
return value + 1

def add(self, val1, val2):
return val1 + val2

def generate_torch_tensor(self, size) -> torch.Tensor:
return torch.zeros(size)

def add_value_to_tensor(self, value: int, tensor: torch.Tensor) -> torch.Tensor:
"""
Add `value` to all elements of the tensor.
"""
return tensor + value

def test_shared_memory_channel_only(ray_start_cluster):
"""
Replica -> Worker -> Replica
@@ -1463,7 +1541,7 @@ def test_shared_memory_channel_only(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
self.w = TestActorInputOutput.Worker.remote()
self.w = TestWorker.remote()
with InputNode() as inp:
dag = self.w.add_one.bind(inp)
self.compiled_dag = dag.experimental_compile()
@@ -1487,7 +1565,7 @@ def test_intra_process_channel(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
self.w = TestActorInputOutput.Worker.remote()
self.w = TestWorker.remote()
with InputNode() as inp:
dag = self.w.add_one.bind(inp)
dag = self.w.add_one.bind(dag)
@@ -1512,8 +1590,8 @@ def test_multiple_readers_multiple_writers(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
w1 = TestActorInputOutput.Worker.remote()
w2 = TestActorInputOutput.Worker.remote()
w1 = TestWorker.remote()
w2 = TestWorker.remote()
with InputNode() as inp:
dag = MultiOutputNode([w1.add_one.bind(inp), w2.add_one.bind(inp)])
self.compiled_dag = dag.experimental_compile()
@@ -1539,8 +1617,8 @@ def test_multiple_readers_single_writer(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
w1 = TestActorInputOutput.Worker.remote()
w2 = TestActorInputOutput.Worker.remote()
w1 = TestWorker.remote()
w2 = TestWorker.remote()
with InputNode() as inp:
branch1 = w1.add_one.bind(inp)
branch2 = w2.add_one.bind(inp)
@@ -1567,8 +1645,8 @@ def test_single_reader_multiple_writers(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
w1 = TestActorInputOutput.Worker.remote()
w2 = TestActorInputOutput.Worker.remote()
w1 = TestWorker.remote()
w2 = TestWorker.remote()
with InputNode() as inp:
dag = w1.add_one.bind(inp)
dag = MultiOutputNode([w1.add_one.bind(dag), w2.add_one.bind(dag)])
@@ -1592,8 +1670,8 @@ def test_torch_tensor_type(ray_start_cluster):
@ray.remote
class Replica:
def __init__(self):
self._base = TestActorInputOutput.Worker.remote()
self._refiner = TestActorInputOutput.Worker.remote()
self._base = TestWorker.remote()
self._refiner = TestWorker.remote()

with ray.dag.InputNode() as inp:
dag = self._refiner.add_value_to_tensor.bind(
2 changes: 2 additions & 0 deletions python/ray/experimental/channel/__init__.py
@@ -1,3 +1,4 @@
from ray.experimental.channel.cached_channel import CachedChannel
from ray.experimental.channel.common import ( # noqa: F401
AwaitableBackgroundReader,
AwaitableBackgroundWriter,
@@ -16,6 +17,7 @@
__all__ = [
"AwaitableBackgroundReader",
"AwaitableBackgroundWriter",
"CachedChannel",
"Channel",
"ReaderInterface",
"SynchronousReader",
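
With the __all__ addition above, the class should be importable from the package root as well as from its submodule; a small check, assuming a Ray build that contains this commit:

```python
# Both import paths should resolve to the same class after this change.
from ray.experimental.channel import CachedChannel as FromPackage
from ray.experimental.channel.cached_channel import CachedChannel as FromModule

assert FromPackage is FromModule
```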
