diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index a3b588fd2912..68025c802a09 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -8,6 +8,7 @@ import traceback from typing import NamedTuple +from ray.experimental.channel.cached_channel import CachedChannel import ray from ray.exceptions import RayTaskError, RayChannelError from ray.util.annotations import PublicAPI @@ -247,13 +248,14 @@ def __init__( self.input_variant = input_variant self.channel_idx = channel_idx - def resolve(self, channel_results: Any): + def resolve(self, channel_results: Any) -> Any: """ Resolve the input value from the channel results. Args: channel_results: The results from reading the input channels. """ + if isinstance(self.input_variant, ChannelInterface): value = channel_results[self.channel_idx] elif isinstance(self.input_variant, DAGInputAdapter): @@ -455,7 +457,11 @@ def _write(self) -> bool: exit = True return exit - def exec_operation(self, class_handle, op_type: _DAGNodeOperationType) -> bool: + def exec_operation( + self, + class_handle, + op_type: _DAGNodeOperationType, + ) -> bool: """ An ExecutableTask corresponds to a DAGNode. It consists of three operations: READ, COMPUTE, and WRITE, which should be executed in @@ -1023,12 +1029,17 @@ def _get_or_compile( ) ) else: + # Use reader_handles_set to deduplicate readers on the same actor, + # because with CachedChannel each actor will only read from the + # upstream channel once. + reader_handles_set = set() for reader in readers: reader_handle = reader.dag_node._get_actor_handle() - reader_and_node_list.append( - (reader_handle, self._get_node_id(reader_handle)) - ) - + if reader_handle not in reader_handles_set: + reader_and_node_list.append( + (reader_handle, self._get_node_id(reader_handle)) + ) + reader_handles_set.add(reader_handle) fn = task.dag_node._get_remote_method("__ray_call__") task.output_channel = ray.get( fn.remote( @@ -1110,27 +1121,35 @@ def _get_or_compile( # Create executable tasks for each actor for actor_handle, tasks in self.actor_to_tasks.items(): - executable_tasks = [] + # Dict from the arg to the set of tasks that consume it. + arg_to_consumers: Dict[DAGNode, Set[CompiledTask]] = defaultdict(set) + # The number of tasks that consume InputNode (or InputAttributeNode). + # Note that _preprocess() ensures that all tasks either use InputNode + # or use InputAttributeNode, but not both. + num_input_consumers = 0 + + # Step 1: populate arg_to_consumers and num_input_consumers, and + # perform some validation.
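+ # For example (hypothetical DAG, for illustration only): if tasks t1 = + # a.f.bind(inp) and t2 = a.g.bind(inp) run on the same actor, the loop below + # ends with arg_to_consumers[inp] == {t1, t2} and num_input_consumers == 2, + # so Step 2 wraps the shared input channel in a CachedChannel with num_reads=2.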
for task in tasks: - resolved_args = [] has_at_least_one_channel_input = False for arg in task.args: if isinstance(arg, InputNode): - input_adapter = DAGInputAdapter(None, self.dag_input_channel) - resolved_args.append(input_adapter) has_at_least_one_channel_input = True + arg_to_consumers[arg].add(task) + num_input_consumers = max( + num_input_consumers, len(arg_to_consumers[arg]) + ) elif isinstance(arg, InputAttributeNode): - input_adapter = DAGInputAdapter(arg, self.dag_input_channel) - resolved_args.append(input_adapter) has_at_least_one_channel_input = True + arg_to_consumers[arg].add(task) + num_input_consumers = max( + num_input_consumers, len(arg_to_consumers[arg]) + ) elif isinstance(arg, DAGNode): # Other DAGNodes + has_at_least_one_channel_input = True + arg_to_consumers[arg].add(task) arg_idx = self.dag_node_to_idx[arg] arg_channel = self.idx_to_task[arg_idx].output_channel assert arg_channel is not None - resolved_args.append(arg_channel) - has_at_least_one_channel_input = True - else: - resolved_args.append(arg) # TODO: Support no-input DAGs (use an empty object to signal). if not has_at_least_one_channel_input: raise ValueError( @@ -1138,6 +1157,58 @@ def _get_or_compile( "ray.dag.InputNode or at least one other DAGNode as an " "input" ) + + # Step 2: create cached channels if needed + + # Dict from original channel to the channel to be used in execution. + # The value of this dict is either the original channel or a newly + # created CachedChannel (if the original channel is read more than once). + channel_dict: Dict[ChannelInterface, ChannelInterface] = {} + for arg, consumers in arg_to_consumers.items(): + # Handle non-input args + if not isinstance(arg, InputNode) and not isinstance( + arg, InputAttributeNode + ): + arg_idx = self.dag_node_to_idx[arg] + arg_channel = self.idx_to_task[arg_idx].output_channel + if len(consumers) > 1: + channel_dict[arg_channel] = CachedChannel( + len(consumers), + arg_channel, + ) + else: + channel_dict[arg_channel] = arg_channel + # Handle input args + if num_input_consumers > 1: + channel_dict[self.dag_input_channel] = CachedChannel( + num_input_consumers, + self.dag_input_channel, + ) + else: + channel_dict[self.dag_input_channel] = self.dag_input_channel + + # Step 3: create executable tasks for the actor + executable_tasks = [] + for task in tasks: + resolved_args = [] + for arg in task.args: + if isinstance(arg, InputNode): + input_channel = channel_dict[self.dag_input_channel] + input_adapter = DAGInputAdapter(None, input_channel) + resolved_args.append(input_adapter) + elif isinstance(arg, InputAttributeNode): + input_channel = channel_dict[self.dag_input_channel] + input_adapter = DAGInputAdapter(arg, input_channel) + resolved_args.append(input_adapter) + elif isinstance(arg, DAGNode): # Other DAGNodes + arg_idx = self.dag_node_to_idx[arg] + arg_channel = self.idx_to_task[arg_idx].output_channel + assert arg_channel is not None + arg_channel = channel_dict[arg_channel] + resolved_args.append(arg_channel) + else: + # Constant arg + resolved_args.append(arg) executable_task = ExecutableTask( task, resolved_args, diff --git a/python/ray/dag/tests/experimental/test_accelerated_dag.py b/python/ray/dag/tests/experimental/test_accelerated_dag.py index ae76ed8db3d5..812a6b359129 100644 --- a/python/ray/dag/tests/experimental/test_accelerated_dag.py +++ b/python/ray/dag/tests/experimental/test_accelerated_dag.py @@ -267,6 +267,82 @@ def test_actor_method_bind_same_constant(ray_start_regular): compiled_dag.teardown() +def 
test_actor_method_bind_same_input(ray_start_regular): + actor = Actor.remote(0) + with InputNode() as inp: + # Test binding the input node to the same method + # of the same actor multiple times: execution + # should not hang. + output1 = actor.inc.bind(inp) + output2 = actor.inc.bind(inp) + dag = MultiOutputNode([output1, output2]) + compiled_dag = dag.experimental_compile() + expected = [[0, 0], [1, 2], [4, 6]] + for i in range(3): + ref = compiled_dag.execute(i) + result = ray.get(ref) + assert result == expected[i] + compiled_dag.teardown() + + +def test_actor_method_bind_same_input_attr(ray_start_regular): + actor = Actor.remote(0) + with InputNode() as inp: + # Test binding the input attribute node to the same method + # of the same actor multiple times: execution should not + # hang. + output1 = actor.inc.bind(inp[0]) + output2 = actor.inc.bind(inp[0]) + dag = MultiOutputNode([output1, output2]) + compiled_dag = dag.experimental_compile() + expected = [[0, 0], [1, 2], [4, 6]] + for i in range(3): + ref = compiled_dag.execute(i) + result = ray.get(ref) + assert result == expected[i] + compiled_dag.teardown() + + +def test_actor_method_bind_same_arg(ray_start_regular): + a1 = Actor.remote(0) + a2 = Actor.remote(0) + with InputNode() as inp: + # Test binding an arg to the same method + # of the same actor multiple times: execution + # should not hang. + output1 = a1.echo.bind(inp) + output2 = a2.inc.bind(output1) + output3 = a2.inc.bind(output1) + dag = MultiOutputNode([output2, output3]) + compiled_dag = dag.experimental_compile() + expected = [[0, 0], [1, 2], [4, 6]] + for i in range(3): + ref = compiled_dag.execute(i) + result = ray.get(ref) + assert result == expected[i] + compiled_dag.teardown() + + +def test_mixed_bind_same_input(ray_start_regular): + a1 = Actor.remote(0) + a2 = Actor.remote(0) + with InputNode() as inp: + # Test binding the input node to the same method + # of different actors multiple times: execution + # should not hang. + output1 = a1.inc.bind(inp) + output2 = a1.inc.bind(inp) + output3 = a2.inc.bind(inp) + dag = MultiOutputNode([output1, output2, output3]) + compiled_dag = dag.experimental_compile() + expected = [[0, 0, 0], [1, 2, 1], [4, 6, 3]] + for i in range(3): + ref = compiled_dag.execute(i) + result = ray.get(ref) + assert result == expected[i] + compiled_dag.teardown() + + def test_regular_args(ray_start_regular): # Test passing regular args to .bind in addition to DAGNode args. a = Actor.remote(0) @@ -304,14 +380,15 @@ def test_multi_args_basic(ray_start_regular): def test_multi_args_single_actor(ray_start_regular): c = Collector.remote() with InputNode() as i: - dag = c.collect_two.bind(i[1], i[0]) + dag = c.collect_three.bind(i[0], i[1], i[0]) compiled_dag = dag.experimental_compile() + expected = [[0, 1, 0], [0, 1, 0, 1, 2, 1], [0, 1, 0, 1, 2, 1, 2, 3, 2]] for i in range(3): - ref = compiled_dag.execute(2, 3) + ref = compiled_dag.execute(i, i + 1) result = ray.get(ref) - assert result == [3, 2] * (i + 1) + assert result == expected[i] with pytest.raises( ValueError, @@ -1424,6 +1501,24 @@ def get_node_id(self): compiled_dag.teardown() + +@ray.remote +class TestWorker: + def add_one(self, value): + return value + 1 + + def add(self, val1, val2): + return val1 + val2 + + def generate_torch_tensor(self, size) -> torch.Tensor: + return torch.zeros(size) + + def add_value_to_tensor(self, value: int, tensor: torch.Tensor) -> torch.Tensor: + """ + Add `value` to all elements of the tensor.
+ """ + return tensor + value + + class TestActorInputOutput: """ Accelerated DAGs support the following two cases for the input/output of the graph: @@ -1436,23 +1531,6 @@ class TestActorInputOutput: which is an actor, needs to be the input and output of the graph. """ - @ray.remote - class Worker: - def add_one(self, value): - return value + 1 - - def add(self, val1, val2): - return val1 + val2 - - def generate_torch_tensor(self, size) -> torch.Tensor: - return torch.zeros(size) - - def add_value_to_tensor(self, value: int, tensor: torch.Tensor) -> torch.Tensor: - """ - Add `value` to all elements of the tensor. - """ - return tensor + value - def test_shared_memory_channel_only(ray_start_cluster): """ Replica -> Worker -> Replica @@ -1463,7 +1541,7 @@ def test_shared_memory_channel_only(ray_start_cluster): @ray.remote class Replica: def __init__(self): - self.w = TestActorInputOutput.Worker.remote() + self.w = TestWorker.remote() with InputNode() as inp: dag = self.w.add_one.bind(inp) self.compiled_dag = dag.experimental_compile() @@ -1487,7 +1565,7 @@ def test_intra_process_channel(ray_start_cluster): @ray.remote class Replica: def __init__(self): - self.w = TestActorInputOutput.Worker.remote() + self.w = TestWorker.remote() with InputNode() as inp: dag = self.w.add_one.bind(inp) dag = self.w.add_one.bind(dag) @@ -1512,8 +1590,8 @@ def test_multiple_readers_multiple_writers(ray_start_cluster): @ray.remote class Replica: def __init__(self): - w1 = TestActorInputOutput.Worker.remote() - w2 = TestActorInputOutput.Worker.remote() + w1 = TestWorker.remote() + w2 = TestWorker.remote() with InputNode() as inp: dag = MultiOutputNode([w1.add_one.bind(inp), w2.add_one.bind(inp)]) self.compiled_dag = dag.experimental_compile() @@ -1539,8 +1617,8 @@ def test_multiple_readers_single_writer(ray_start_cluster): @ray.remote class Replica: def __init__(self): - w1 = TestActorInputOutput.Worker.remote() - w2 = TestActorInputOutput.Worker.remote() + w1 = TestWorker.remote() + w2 = TestWorker.remote() with InputNode() as inp: branch1 = w1.add_one.bind(inp) branch2 = w2.add_one.bind(inp) @@ -1567,8 +1645,8 @@ def test_single_reader_multiple_writers(ray_start_cluster): @ray.remote class Replica: def __init__(self): - w1 = TestActorInputOutput.Worker.remote() - w2 = TestActorInputOutput.Worker.remote() + w1 = TestWorker.remote() + w2 = TestWorker.remote() with InputNode() as inp: dag = w1.add_one.bind(inp) dag = MultiOutputNode([w1.add_one.bind(dag), w2.add_one.bind(dag)]) @@ -1592,8 +1670,8 @@ def test_torch_tensor_type(ray_start_cluster): @ray.remote class Replica: def __init__(self): - self._base = TestActorInputOutput.Worker.remote() - self._refiner = TestActorInputOutput.Worker.remote() + self._base = TestWorker.remote() + self._refiner = TestWorker.remote() with ray.dag.InputNode() as inp: dag = self._refiner.add_value_to_tensor.bind( diff --git a/python/ray/experimental/channel/__init__.py b/python/ray/experimental/channel/__init__.py index 76d75c938026..bcca146bd8f6 100644 --- a/python/ray/experimental/channel/__init__.py +++ b/python/ray/experimental/channel/__init__.py @@ -1,3 +1,4 @@ +from ray.experimental.channel.cached_channel import CachedChannel from ray.experimental.channel.common import ( # noqa: F401 AwaitableBackgroundReader, AwaitableBackgroundWriter, @@ -16,6 +17,7 @@ __all__ = [ "AwaitableBackgroundReader", "AwaitableBackgroundWriter", + "CachedChannel", "Channel", "ReaderInterface", "SynchronousReader", diff --git a/python/ray/experimental/channel/cached_channel.py 
b/python/ray/experimental/channel/cached_channel.py new file mode 100644 index 000000000000..9f7feb4bdab7 --- /dev/null +++ b/python/ray/experimental/channel/cached_channel.py @@ -0,0 +1,109 @@ +import uuid +from typing import Any, Optional + +from ray.experimental.channel.common import ChannelInterface + + +class CachedChannel(ChannelInterface): + """ + CachedChannel wraps an inner channel and caches the data read from it until + `num_reads` reads have completed. If the inner channel is None, the data + is written to the serialization context and retrieved from there. This is useful + when passing data within the same actor and a shared memory channel can be + avoided. + + Args: + num_reads: The number of reads from this channel that must happen before + writing again. Readers must be methods of the same actor. + inner_channel: The inner channel to cache data from. If None, the data is + read from the serialization context. + _channel_id: The unique ID for the channel. If None, a new ID is generated. + + def __init__( + self, + num_reads: int, + inner_channel: Optional[ChannelInterface] = None, + _channel_id: Optional[str] = None, + ): + assert num_reads > 0, "num_reads must be greater than 0." + self._num_reads = num_reads + self._inner_channel = inner_channel + # Generate a unique ID for the channel. The writer and reader will use + # this ID to store and retrieve data from the _SerializationContext. + self._channel_id = _channel_id + if self._channel_id is None: + self._channel_id = str(uuid.uuid4()) + + def ensure_registered_as_writer(self) -> None: + if self._inner_channel is not None: + self._inner_channel.ensure_registered_as_writer() + + def ensure_registered_as_reader(self) -> None: + if self._inner_channel is not None: + self._inner_channel.ensure_registered_as_reader() + + def __reduce__(self): + return CachedChannel, ( + self._num_reads, + self._inner_channel, + self._channel_id, + ) + + def __str__(self) -> str: + return ( + f"CachedChannel(channel_id={self._channel_id}, " + f"num_reads={self._num_reads}, " + f"inner_channel={self._inner_channel})" + ) + + def write(self, value: Any, timeout: Optional[float] = None): + # TODO: better organize the imports + from ray.experimental.channel import ChannelContext + + if self._inner_channel is not None: + self._inner_channel.write(value, timeout) + return + + # Otherwise no need to check timeout as the operation is non-blocking. + + # Because both the reader and writer are in the same worker process, + # we can directly store the data in the context instead of storing + # it in the channel object. This removes the serialization overhead of `value`. + ctx = ChannelContext.get_current().serialization_context + ctx.set_data(self._channel_id, value, self._num_reads) + + def read(self, timeout: Optional[float] = None) -> Any: + # TODO: better organize the imports + from ray.experimental.channel import ChannelContext + + ctx = ChannelContext.get_current().serialization_context + if ctx.has_data(self._channel_id): + # No need to check timeout as the operation is non-blocking. + return ctx.get_data(self._channel_id) + + assert ( + self._inner_channel is not None + ), "Cannot read from the serialization context while inner channel is None." + value = self._inner_channel.read(timeout) + ctx.set_data(self._channel_id, value, self._num_reads) + # NOTE: Currently we make a contract with aDAG users that the + # channel results should not be mutated by the actor methods.
+ # When the user needs to modify the channel results, they should + # make a copy of the channel results and modify the copy. + # This is the same contract as used in IntraProcessChannel. + # This contract is NOT enforced right now in either case. + # TODO(rui): introduce a flag to control the behavior: + # for example, by default we make a deep copy of the channel + # result, but the user can turn off the deep copy for performance + # improvements. + # https://github.com/ray-project/ray/issues/47409 + return ctx.get_data(self._channel_id) + + def close(self) -> None: + from ray.experimental.channel import ChannelContext + + if self._inner_channel is not None: + self._inner_channel.close() + ctx = ChannelContext.get_current().serialization_context + ctx.reset_data(self._channel_id) diff --git a/python/ray/experimental/channel/serialization_context.py b/python/ray/experimental/channel/serialization_context.py index 92599a0f8ee8..7364599ed6b5 100644 --- a/python/ray/experimental/channel/serialization_context.py +++ b/python/ray/experimental/channel/serialization_context.py @@ -34,6 +34,9 @@ def set_data(self, channel_id: str, value: Any, num_readers: int) -> None: self.intra_process_channel_buffers[channel_id] = value self.channel_id_to_num_readers[channel_id] = num_readers + def has_data(self, channel_id: str) -> bool: + return channel_id in self.intra_process_channel_buffers + def get_data(self, channel_id: str) -> Any: assert ( channel_id in self.intra_process_channel_buffers diff --git a/python/ray/experimental/channel/shared_memory_channel.py b/python/ray/experimental/channel/shared_memory_channel.py index 58e18a531704..5dbee7859425 100644 --- a/python/ray/experimental/channel/shared_memory_channel.py +++ b/python/ray/experimental/channel/shared_memory_channel.py @@ -518,7 +518,7 @@ def __init__( remote_reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]] = [] for reader, node in self._reader_and_node_list: - if self._writer != reader: + if reader != self._writer: remote_reader_and_node_list.append((reader, node)) # There are some local readers which are the same worker process as the writer. # Create a local channel for the writer and the local readers. @@ -526,7 +526,10 @@ remote_reader_and_node_list ) if num_local_readers > 0: - local_channel = IntraProcessChannel(num_local_readers) + # Use num_readers = 1 when creating the local channel, + # because we have a channel cache to support reading + # from the same channel multiple times.
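+ # (Repeated reads by methods of the same actor are instead served by the + # CachedChannel created in compiled_dag_node.py, which caches the value in + # the per-process serialization context until num_reads reads complete.)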
+ local_channel = IntraProcessChannel(num_readers=1) self._channels.add(local_channel) actor_id = self._get_actor_id(self._writer) self._channel_dict[actor_id] = local_channel diff --git a/python/ray/tests/test_channel.py b/python/ray/tests/test_channel.py index 4f0c1bc2d284..97e119d87df5 100644 --- a/python/ray/tests/test_channel.py +++ b/python/ray/tests/test_channel.py @@ -785,6 +785,102 @@ def get_ctx_buffer_size(self): ray.get(actor.write.remote("world hello")) + +@pytest.mark.skipif( + sys.platform != "linux" and sys.platform != "darwin", + reason="Requires Linux or Mac.", +) +def test_cached_channel_single_reader(): + ray.init() + + @ray.remote + class Actor: + def __init__(self): + pass + + def pass_channel(self, channel): + self._chan = channel + + def read(self): + return self._chan.read() + + def get_ctx_buffer_size(self): + ctx = ray_channel.ChannelContext.get_current().serialization_context + return len(ctx.intra_process_channel_buffers) + + actor = Actor.remote() + inner_channel = ray_channel.Channel( + None, + [ + (actor, get_actor_node_id(actor)), + ], + 1000, + ) + channel = ray_channel.CachedChannel(num_reads=1, inner_channel=inner_channel) + ray.get(actor.pass_channel.remote(channel)) + + channel.write("hello") + assert ray.get(actor.read.remote()) == "hello" + + # The _SerializationContext should clean up the data after a read. + assert ray.get(actor.get_ctx_buffer_size.remote()) == 0 + + # Write again after reading num_reads times. + channel.write("world") + assert ray.get(actor.read.remote()) == "world" + + # The _SerializationContext should clean up the data after a read. + assert ray.get(actor.get_ctx_buffer_size.remote()) == 0 + + +@pytest.mark.skipif( + sys.platform != "linux" and sys.platform != "darwin", + reason="Requires Linux or Mac.", +) +def test_cached_channel_multi_readers(ray_start_cluster): + @ray.remote + class Actor: + def __init__(self): + pass + + def pass_channel(self, channel): + self._chan = channel + + def read(self): + return self._chan.read() + + def get_ctx_buffer_size(self): + ctx = ray_channel.ChannelContext.get_current().serialization_context + return len(ctx.intra_process_channel_buffers) + + actor = Actor.remote() + inner_channel = ray_channel.Channel( + None, + [ + (actor, get_actor_node_id(actor)), + ], + 1000, + ) + channel = ray_channel.CachedChannel(num_reads=2, inner_channel=inner_channel) + ray.get(actor.pass_channel.remote(channel)) + + channel.write("hello") + # first read + assert ray.get(actor.read.remote()) == "hello" + assert ray.get(actor.get_ctx_buffer_size.remote()) == 1 + # second read + assert ray.get(actor.read.remote()) == "hello" + assert ray.get(actor.get_ctx_buffer_size.remote()) == 0 + + # Write again after reading num_reads times. + channel.write("world") + # first read + assert ray.get(actor.read.remote()) == "world" + assert ray.get(actor.get_ctx_buffer_size.remote()) == 1 + # second read + assert ray.get(actor.read.remote()) == "world" + assert ray.get(actor.get_ctx_buffer_size.remote()) == 0 + + @pytest.mark.skipif( sys.platform != "linux" and sys.platform != "darwin", reason="Requires Linux or Mac.", ) @@ -865,7 +961,7 @@ def test_composite_channel_multiple_readers(ray_start_cluster): (1) The driver can write data to CompositeChannel and two actors can read it. (2) An actor can write data to CompositeChannel and another actor, as well as itself, can read it.
- (3) An actor writes data to CompositeChannel and two Ray tasks on the same + (3) An actor writes data to CompositeChannel and two actor methods on the same actor read it. This is not supported and should raise an exception. """ # This node is for both the driver and the Reader actors. @@ -922,11 +1018,11 @@ def write(self, value): ) ray.get(actor1.write.remote("hello world")) assert ray.get(actor1.read.remote()) == "hello world" - assert ray.get(actor1.read.remote()) == "hello world" with pytest.raises(ray.exceptions.RayTaskError): - # actor1_output_channel has two readers, so it can only be read twice. - # The third read should raise an exception. + # actor1_output_channel can be read only once when the readers + # are on the same actor. Note that reading the channel multiple + # times is supported via the channel cache mechanism. ray.get(actor1.read.remote()) """ TODO (kevin85421): Add tests for the following cases:
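A minimal usage sketch of the caching semantics exercised above (the SelfReader actor and roundtrip method are illustrative names, not part of this diff): with no inner channel, CachedChannel keeps the value in the per-process serialization context, so the write and all num_reads reads must happen in the same worker process.

import ray
from ray.experimental.channel import CachedChannel

@ray.remote
class SelfReader:
    def roundtrip(self, value):
        # No inner channel: the value lives in this process's
        # _SerializationContext rather than in shared memory.
        chan = CachedChannel(num_reads=2)
        chan.write(value)
        first = chan.read()   # entry stays cached after the first read
        second = chan.read()  # the last allowed read frees the cached entry
        return first, second

actor = SelfReader.remote()
assert ray.get(actor.roundtrip.remote("hello")) == ("hello", "hello")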