Skip to content

Commit

Permalink
[aDAG] Support multi-read of the same shm channel (ray-project#47311)
Browse files Browse the repository at this point in the history
If the same method of the same actor is bound to the same node (i.e., reads from the same shared memory channel), aDAG execution hangs. This PR adds support to this case by caching results read from the channel.

Signed-off-by: ujjawal-khare <[email protected]>
  • Loading branch information
ruisearch42 authored and ujjawal-khare committed Oct 15, 2024
1 parent a235987 commit 785a594
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 246 deletions.
151 changes: 116 additions & 35 deletions python/ray/dag/compiled_dag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import traceback

from ray.experimental.channel.cached_channel import CachedChannel
from ray.experimental.channel.gpu_communicator import GPUCommunicator
import ray
from ray.exceptions import RayTaskError, RayChannelError
from ray.experimental.compiled_dag_ref import (
Expand Down Expand Up @@ -1001,20 +1000,71 @@ def _get_or_compile(
if type_hint.requires_nccl():
type_hint.set_nccl_group_id(self._nccl_group_id)

if (
isinstance(task.dag_node, ClassMethodNode)
and task.dag_node.is_class_method_call
):
# Create output buffers for the actor method.
assert len(task.output_channels) == 0
# `output_to_readers` stores the reader tasks for each output of
# the current node. If the current node returns one output, the
# readers are the downstream nodes of the current node. If the
# current node returns multiple outputs, the readers of each
# output are the downstream nodes of the ClassMethodNode that
# is a class method output.
output_to_readers: Dict[CompiledTask, List[CompiledTask]] = defaultdict(
list
if isinstance(task.dag_node, ClassMethodNode):
# `readers` is the nodes that are ordered after the current one (`task`)
# in the DAG.
readers = [self.idx_to_task[idx] for idx in task.downstream_task_idxs]
reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]] = []
dag_nodes = [reader.dag_node for reader in readers]
read_by_multi_output_node = False
for dag_node in dag_nodes:
if isinstance(dag_node, MultiOutputNode):
read_by_multi_output_node = True
break
if read_by_multi_output_node:
if len(readers) != 1:
raise ValueError(
"DAG outputs currently can only be read by the driver or "
"the same actor that is also the InputNode, not by both "
"the driver and actors."
)
# This node is a multi-output node, which means it will only be
# read by the driver or the actor that is also the InputNode.

# TODO(jhumphri): Handle case where there is an actor, other than
# just the driver actor, also reading the output from the `task`
# node.
# For example, the following currently does not work:
# def test_blah(ray_start_regular):
# a = Actor.remote(0)
# b = Actor.remote(10)
# with InputNode() as inp:
# x = a.inc.bind(inp)
# y = b.inc.bind(x)
# dag = MultiOutputNode([x, y])

# compiled_dag = dag.experimental_compile()
# output_channel = compiled_dag.execute(1)
# result = output_channel.read()
# print(result)

# compiled_dag.teardown()
assert self._creator_or_proxy_actor is not None
reader_and_node_list.append(
(
self._creator_or_proxy_actor,
self._get_node_id(self._creator_or_proxy_actor),
)
)
else:
# Use reader_handles_set to deduplicate readers on the same actor,
# because with CachedChannel each actor will only read from the
# upstream channel once.
reader_handles_set = set()
for reader in readers:
reader_handle = reader.dag_node._get_actor_handle()
if reader_handle not in reader_handles_set:
reader_and_node_list.append(
(reader_handle, self._get_node_id(reader_handle))
)
reader_handles_set.add(reader_handle)
fn = task.dag_node._get_remote_method("__ray_call__")
task.output_channel = ray.get(
fn.remote(
do_allocate_channel,
reader_and_node_list,
typ=type_hint,
)
)
for idx in task.downstream_task_idxs:
downstream_task = self.idx_to_task[idx]
Expand Down Expand Up @@ -1229,14 +1279,30 @@ def _get_or_compile(

# Create executable tasks for each actor
for actor_handle, tasks in self.actor_to_tasks.items():
# Dict from arg to the set of tasks that consume it.
# Dict from the arg to the set of tasks that consume it.
arg_to_consumers: Dict[DAGNode, Set[CompiledTask]] = defaultdict(set)
# The number of tasks that consume InputNode (or InputAttributeNode)
# Note that _preprocess() ensures that all tasks either use InputNode
# or use InputAttributeNode, but not both.
num_input_consumers = 0

# Step 1: populate `arg_to_consumers` and perform some validation.
# Step 1: populate num_channel_reads and perform some validation.
for task in tasks:
has_at_least_one_channel_input = False
for arg in task.args:
if isinstance(arg, DAGNode):
if isinstance(arg, InputNode):
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
num_input_consumers = max(
num_input_consumers, len(arg_to_consumers[arg])
)
elif isinstance(arg, InputAttributeNode):
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
num_input_consumers = max(
num_input_consumers, len(arg_to_consumers[arg])
)
elif isinstance(arg, DAGNode): # Other DAGNodes
has_at_least_one_channel_input = True
arg_to_consumers[arg].add(task)
arg_idx = self.dag_node_to_idx[arg]
Expand All @@ -1259,29 +1325,44 @@ def _get_or_compile(
# created CachedChannel (if the original channel is read more than once).
channel_dict: Dict[ChannelInterface, ChannelInterface] = {}
for arg, consumers in arg_to_consumers.items():
arg_idx = self.dag_node_to_idx[arg]
upstream_task = self.idx_to_task[arg_idx]
assert len(upstream_task.output_channels) == 1
arg_channel = upstream_task.output_channels[0]
assert arg_channel is not None
if len(consumers) > 1:
channel_dict[arg_channel] = CachedChannel(
len(consumers),
arg_channel,
)
else:
channel_dict[arg_channel] = arg_channel
# Handle non-input args
if not isinstance(arg, InputNode) and not isinstance(
arg, InputAttributeNode
):
arg_idx = self.dag_node_to_idx[arg]
arg_channel = self.idx_to_task[arg_idx].output_channel
if len(consumers) > 1:
channel_dict[arg_channel] = CachedChannel(
len(consumers),
arg_channel,
)
else:
channel_dict[arg_channel] = arg_channel
# Handle input args
if num_input_consumers > 1:
channel_dict[self.dag_input_channel] = CachedChannel(
num_input_consumers,
self.dag_input_channel,
)
else:
channel_dict[self.dag_input_channel] = self.dag_input_channel

# Step 3: create executable tasks for the actor
executable_tasks = []
for task in tasks:
resolved_args: List[Any] = []
resolved_args = []
for arg in task.args:
if isinstance(arg, DAGNode):
if isinstance(arg, InputNode):
input_channel = channel_dict[self.dag_input_channel]
input_adapter = DAGInputAdapter(None, input_channel)
resolved_args.append(input_adapter)
elif isinstance(arg, InputAttributeNode):
input_channel = channel_dict[self.dag_input_channel]
input_adapter = DAGInputAdapter(arg, input_channel)
resolved_args.append(input_adapter)
elif isinstance(arg, DAGNode): # Other DAGNodes
arg_idx = self.dag_node_to_idx[arg]
upstream_task = self.idx_to_task[arg_idx]
assert len(upstream_task.output_channels) == 1
arg_channel = upstream_task.output_channels[0]
arg_channel = self.idx_to_task[arg_idx].output_channel
assert arg_channel is not None
arg_channel = channel_dict[arg_channel]
resolved_args.append(arg_channel)
Expand Down
Loading

0 comments on commit 785a594

Please sign in to comment.