[Core][aDag] Support multi node multi reader (ray-project#47480)
This PR supports multiple readers across multiple nodes. It also adds tests verifying that the feature works with large gRPC payloads and buffer resizing.

Multiple readers across multiple nodes did not work because the code only allowed registering a remote reader reference on a single node. This PR fixes the issue by allowing remote reader references to be registered on multiple nodes.
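For context, here is a minimal sketch of the multi-node, multi-reader pattern this change enables, modeled on the new test below (the EchoActor class is illustrative, not part of this PR):

import ray
from ray.dag import InputNode, MultiOutputNode

@ray.remote
class EchoActor:  # Illustrative actor; any actor method bound to the input works.
    def echo(self, x):
        return x

# Readers of the same compiled DAG output may now live on several nodes, not just one.
actors = [EchoActor.remote() for _ in range(4)]
with InputNode() as inp:
    dag = MultiOutputNode([a.echo.bind(inp) for a in actors])

compiled_dag = dag.experimental_compile()
ref = compiled_dag.execute(b"x" * 1024)
assert ray.get(ref) == [b"x" * 1024] * 4
compiled_dag.teardown()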

Signed-off-by: ujjawal-khare <[email protected]>
rkooo567 authored and ujjawal-khare committed Oct 15, 2024
1 parent 8d20388 commit 6625ee2
Showing 6 changed files with 78 additions and 188 deletions.
1 change: 0 additions & 1 deletion python/ray/dag/compiled_dag_node.py
@@ -112,7 +112,6 @@ def do_exec_tasks(
if done:
break
for operation in schedule:
print("SANG-TODO operation: ", operation)
done = tasks[operation.exec_task_idx].exec_operation(
self, operation.type
)
55 changes: 52 additions & 3 deletions python/ray/dag/tests/experimental/test_multi_node_dag.py
@@ -233,13 +233,10 @@ def get_node_id(self):

compiled_dag = dag.experimental_compile()

# Ray sets the gRPC payload max size to 512 MiB. We choose a size in this test that
# is a bit larger.
size = GRPC_MAX_SIZE + (1024 * 1024 * 2)
val = b"x" * size

for i in range(3):
print(f"{i} iteration")
ref = compiled_dag.execute(val)
result = ray.get(ref)
assert result == val
@@ -249,6 +246,58 @@ def get_node_id(self):
compiled_dag.teardown()


@pytest.mark.parametrize("num_actors", [1, 4])
@pytest.mark.parametrize("num_nodes", [1, 4])
def test_multi_node_multi_reader_large_payload(
ray_start_cluster, num_actors, num_nodes, monkeypatch
):
# Set the max gRPC message size to 5 MiB.
GRPC_MAX_SIZE = 1024 * 1024 * 5
monkeypatch.setenv("RAY_max_grpc_message_size", str(GRPC_MAX_SIZE))
cluster = ray_start_cluster
ACTORS_PER_NODE = num_actors
NUM_REMOTE_NODES = num_nodes
cluster.add_node(num_cpus=ACTORS_PER_NODE)
ray.init(address=cluster.address)
# These nodes are for the remote readers.
for _ in range(NUM_REMOTE_NODES):
cluster.add_node(num_cpus=ACTORS_PER_NODE)
cluster.wait_for_nodes()

wait_for_condition(lambda: len(ray.nodes()) == NUM_REMOTE_NODES + 1)

actors = [
Actor.options(num_cpus=1).remote(0)
for _ in range(ACTORS_PER_NODE * (NUM_REMOTE_NODES + 1))
]

def _get_node_id(self) -> "ray.NodeID":
return ray.get_runtime_context().get_node_id()

node_ids = ray.get([act.__ray_call__.remote(_get_node_id) for act in actors])
assert len(set(node_ids)) == NUM_REMOTE_NODES + 1

with InputNode() as inp:
outputs = []
for actor in actors:
outputs.append(actor.echo.bind(inp))
dag = MultiOutputNode(outputs)

compiled_dag = dag.experimental_compile()

# Set the object size a little bigger than the gRPC size so that
# it triggers chunking and resizing.
size = GRPC_MAX_SIZE + (1024 * 1024 * 2)
val = b"x" * size

for _ in range(3):
ref = compiled_dag.execute(val)
result = ray.get(ref)
assert result == [val for _ in range(ACTORS_PER_NODE * (NUM_REMOTE_NODES + 1))]

compiled_dag.teardown()


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
96 changes: 0 additions & 96 deletions python/ray/experimental/channel/shared_memory_channel.py
@@ -631,102 +631,6 @@ def next_read_index(self):
return self._next_read_index


@DeveloperAPI
class BufferedSharedMemoryChannel(ChannelInterface):
"""A channel that can be read and written by Ray processes.
It creates `num_shm_buffers` buffers and provides buffered read and
write APIs, i.e., reads and writes are non-blocking as long as the next
buffer is available to write to or read from. See the `read` and `write`
APIs for more details.
Args:
writer: The actor that may write to the channel. None signifies the driver.
reader_and_node_list: A list of tuples, where each tuple contains a reader
actor handle and the node ID where the actor is located.
num_shm_buffers: Number of shared memory buffers to read/write.
typ: Type information about the values passed through the channel.
Either an integer representing the max buffer size in bytes
allowed, or a SharedMemoryType.
"""

def __init__(
self,
writer: Optional[ray.actor.ActorHandle],
reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]],
num_shm_buffers: int,
typ: Optional[Union[int, SharedMemoryType]] = None,
):
self._num_shm_buffers = num_shm_buffers
self._buffers = [
# We use Channel directly as a buffer implementation because a
# Channel only supports a single shared memory buffer.
Channel(writer, reader_and_node_list, typ)
for _ in range(num_shm_buffers)
]
# The next index to write from self._buffers.
self._next_write_index = 0
# The next index to read from self._buffers.
self._next_read_index = 0

def ensure_registered_as_writer(self):
"""
Check whether the process is a valid writer. This method must be idempotent.
"""
for buffer in self._buffers:
buffer.ensure_registered_as_writer()

def ensure_registered_as_reader(self):
"""
Check whether the process is a valid reader. This method must be idempotent.
"""
for buffer in self._buffers:
buffer.ensure_registered_as_reader()

def write(self, value: Any, timeout: Optional[float] = None) -> None:
"""Write a value to a channel.
If the next buffer is available, it returns immediately. If the next
buffer is not read by downstream consumers, it blocks until a buffer is
available to write. If a buffer is not available within timeout, it raises
RayChannelTimeoutError.
"""
# A single channel is not supposed to read and write at the same time.
assert self._next_read_index == 0
self._buffers[self._next_write_index].write(value, timeout)
self._next_write_index += 1
self._next_write_index %= self._num_shm_buffers

def read(self, timeout: Optional[float] = None) -> Any:
"""Read a value from a channel.
If the next buffer is available, it returns immediately. If the next
buffer is not written by an upstream producer, it blocks until a buffer is
available to read. If a buffer is not available within timeout, it raises
RayChannelTimeoutError.
"""
# A single channel is not supposed to read and write at the same time.
assert self._next_write_index == 0
output = self._buffers[self._next_read_index].read(timeout)
self._next_read_index += 1
self._next_read_index %= self._num_shm_buffers
return output

def close(self) -> None:
for buffer in self._buffers:
buffer.close()

@property
def next_write_index(self):
# Testing only
return self._next_write_index

@property
def next_read_index(self):
# Testing only
return self._next_read_index


@PublicAPI(stability="alpha")
class CompositeChannel(ChannelInterface):
"""
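For reference, the BufferedSharedMemoryChannel removed above cycles its write and read cursors round-robin over `num_shm_buffers` single-buffer channels. A minimal sketch of that cursor arithmetic, using a plain Python list in place of the shared memory buffers (not the Ray implementation):

class BufferedChannelSketch:
    def __init__(self, num_buffers):
        # Each slot stands in for one single-buffer shared memory channel.
        self._buffers = [None] * num_buffers
        self._next_write_index = 0
        self._next_read_index = 0

    def write(self, value):
        self._buffers[self._next_write_index] = value
        self._next_write_index = (self._next_write_index + 1) % len(self._buffers)

    def read(self):
        value = self._buffers[self._next_read_index]
        self._next_read_index = (self._next_read_index + 1) % len(self._buffers)
        return value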
72 changes: 0 additions & 72 deletions python/ray/tests/test_channel.py
@@ -1394,78 +1394,6 @@ def write(self, i, timeout=None) -> bool:
)


def test_buffered_channel(shutdown_only):
"""Test buffered shared memory channel."""
BUFFER_SIZE = 5

@ray.remote(num_cpus=1)
class Actor:
def __init__(self):
self.write_index = 0

def setup(self, driver_actor):
self._channel = ray_channel.BufferedSharedMemoryChannel(
ray.get_runtime_context().current_actor,
[(driver_actor, get_actor_node_id(driver_actor))],
BUFFER_SIZE,
typ=1000,
)

def get_channel(self):
return self._channel

def write(self, i, timeout=None) -> bool:
"""Write to a channel Return False if channel times out.
Return true otherwise.
"""
self.write_index += 1
try:
self._channel.write(i, timeout)
except ray.exceptions.RayChannelTimeoutError:
return False
assert self._channel.next_write_index == self.write_index % BUFFER_SIZE
return True

a = Actor.remote()
ray.get(a.setup.remote(create_driver_actor()))
chan = ray.get(a.get_channel.remote())

print("Test basic.")
# Iterate more than buffer size to make sure it works over and over again.
read_idx = 0
for i in range(BUFFER_SIZE * 3):
read_idx += 1
assert ray.get(a.write.remote(i))
assert chan.read() == i
assert chan.next_read_index == read_idx % BUFFER_SIZE

print("Test Write timeout.")
# Test write timeout.
for i in range(BUFFER_SIZE):
# Fill the buffer without reading.
ray.get(a.write.remote(i))
# Times out because the whole buffer is exhausted without being consumed.
assert ray.get(a.write.remote(1, timeout=1)) is False

print("Test Read timeout.")
# Test read timeout.
for i in range(BUFFER_SIZE):
# This reads all previous writes.
assert chan.read() == i
# This read times out because there's no new write, and the call blocks.
with pytest.raises(ray.exceptions.RayChannelTimeoutError):
chan.read(timeout=1)

print("Test serialization/deserialization works")
deserialized = pickle.loads(pickle.dumps(chan))
assert len(chan._buffers) == len(deserialized._buffers)
for i in range(len(chan._buffers)):
assert (
deserialized._buffers[i]._writer._actor_id
== chan._buffers[i]._writer._actor_id
)


if __name__ == "__main__":
if os.environ.get("PARALLEL_CI"):
sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
39 changes: 25 additions & 14 deletions src/ray/core_worker/experimental_mutable_object_provider.cc
@@ -234,20 +234,31 @@ void MutableObjectProvider::PollWriterClosure(
RAY_CHECK(object->GetData());
RAY_CHECK(object->GetMetadata());

RAY_LOG(ERROR) << "SANG-TODO Push mutable object! " << object_id;
reader->PushMutableObject(
object_id,
object->GetData()->Size(),
object->GetMetadata()->Size(),
object->GetData()->Data(),
[this, &io_context, object_id, reader](const Status &status,
const rpc::PushMutableObjectReply &reply) {
io_context.post(
[this, &io_context, object_id, reader]() {
PollWriterClosure(io_context, object_id, reader);
},
"experimental::MutableObjectProvider.PollWriter");
});
std::shared_ptr<size_t> num_replied = std::make_shared<size_t>(0);
for (const auto &reader : *remote_readers) {
reader->PushMutableObject(
writer_object_id,
object->GetData()->Size(),
object->GetMetadata()->Size(),
object->GetData()->Data(),
[this, &io_context, writer_object_id, remote_readers, num_replied](
const Status &status, const rpc::PushMutableObjectReply &reply) {
*num_replied += 1;
if (!status.ok()) {
RAY_LOG(ERROR)
<< "Failed to transfer object to a remote node for an object id "
<< writer_object_id << ". This may cause a hang.";
}

if (*num_replied == remote_readers->size()) {
io_context.post(
[this, &io_context, writer_object_id, remote_readers]() {
PollWriterClosure(io_context, writer_object_id, remote_readers);
},
"experimental::MutableObjectProvider.PollWriter");
}
});
}
}

void MutableObjectProvider::RunIOContext(instrumented_io_context &io_context) {
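The rewritten PollWriterClosure above pushes the mutable object to every remote reader and re-arms the poll only after all readers have replied. A rough Python sketch of that shared reply-counter pattern (FakeReader and push_async are made-up stand-ins, not Ray APIs):

class FakeReader:
    """Stand-in for a remote reader client; a real one would issue an RPC."""

    def push_async(self, payload, callback):
        callback(True)  # Reply inline; a real RPC replies asynchronously.


def push_to_all_readers(readers, payload, on_all_replied):
    num_replied = {"count": 0}  # Shared across all reader callbacks.

    def on_reply(ok):
        num_replied["count"] += 1
        if not ok:
            print("push failed; the downstream DAG may hang")
        # Schedule the next poll only once every reader has acknowledged.
        if num_replied["count"] == len(readers):
            on_all_replied()

    for reader in readers:
        reader.push_async(payload, on_reply)


push_to_all_readers([FakeReader(), FakeReader()], b"x",
                    lambda: print("all readers acked; poll the writer again"))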
3 changes: 1 addition & 2 deletions src/ray/raylet_client/raylet_client.cc
@@ -426,8 +426,7 @@ void raylet::RayletClient::PushMutableObject(
const ray::rpc::ClientCallback<ray::rpc::PushMutableObjectReply> &callback) {
// Ray sets the gRPC max payload size to ~512 MiB. We set the max chunk size to a
// slightly lower value to allow extra padding just in case.
uint64_t kMaxGrpcPayloadSize =
RayConfig::instance().max_grpc_message_size() * 0.98; // 500 MiB.
uint64_t kMaxGrpcPayloadSize = RayConfig::instance().max_grpc_message_size() * 0.98;
uint64_t total_size = data_size + metadata_size;
uint64_t total_num_chunks = total_size / kMaxGrpcPayloadSize;
// If `total_size` is not a multiple of `kMaxGrpcPayloadSize`, then we need to send an
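The chunking logic above is a ceiling division, with the chunk size kept slightly under the gRPC limit. A small Python sketch of the same arithmetic (the 512 MiB default and the 0.98 factor come from the code above; everything else is illustrative):

MAX_GRPC_MESSAGE_SIZE = 512 * 1024 * 1024
# Keep each chunk slightly under the gRPC limit to leave room for padding.
MAX_CHUNK_SIZE = int(MAX_GRPC_MESSAGE_SIZE * 0.98)


def num_chunks(data_size, metadata_size):
    total_size = data_size + metadata_size
    chunks = total_size // MAX_CHUNK_SIZE
    # Send one extra chunk when total_size is not an exact multiple.
    if total_size % MAX_CHUNK_SIZE != 0:
        chunks += 1
    return chunks


assert num_chunks(MAX_CHUNK_SIZE, 0) == 1
assert num_chunks(MAX_CHUNK_SIZE + 1, 0) == 2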
