[Core][Distributed] add shm broadcast #5399

Merged: 43 commits, merged on Jun 21, 2024. The diff below shows the changes from all commits.

Commits
c82dc08
add shm broadcast
youkaichao Jun 10, 2024
aa15dc2
fix name
youkaichao Jun 10, 2024
7d4f81e
Merge branch 'main' into shm_broadcast
youkaichao Jun 14, 2024
b4e5d47
enable shm broadcast
youkaichao Jun 14, 2024
aa606d1
use HIGHEST_PROTOCOL
youkaichao Jun 14, 2024
943147c
add modulus
youkaichao Jun 14, 2024
2f2e7b8
error on large object
youkaichao Jun 14, 2024
399f8c1
name written flag
youkaichao Jun 14, 2024
a7db4d5
rename to data and metadata
youkaichao Jun 14, 2024
a105688
add sleep if all blocks are empty
youkaichao Jun 14, 2024
d09c8b6
bump up slots
youkaichao Jun 14, 2024
ba6839d
only memset for metadata section
youkaichao Jun 14, 2024
a298ae9
add comments
youkaichao Jun 14, 2024
2c775d0
remove initialization in world size 1
youkaichao Jun 15, 2024
b8105bb
add comments
youkaichao Jun 15, 2024
681919a
add warning if waiting for too long
youkaichao Jun 15, 2024
468bf93
add shm broadcast tests
youkaichao Jun 15, 2024
c5e47b3
lint
youkaichao Jun 15, 2024
98188ce
add tests
youkaichao Jun 15, 2024
8e755f2
Update vllm/distributed/device_communicators/shm_broadcast.py
youkaichao Jun 16, 2024
5197920
Merge branch 'main' into shm_broadcast
youkaichao Jun 16, 2024
398c6e2
Merge branch 'main' into shm_broadcast
youkaichao Jun 16, 2024
57a1839
Merge branch 'main' into shm_broadcast
youkaichao Jun 18, 2024
95b8a87
use underscore for private attributes
youkaichao Jun 18, 2024
c0cc37f
rename
youkaichao Jun 18, 2024
fc49f86
add mem layout docstring
youkaichao Jun 18, 2024
34475a0
stash
youkaichao Jun 18, 2024
82792f1
Merge branch 'main' into shm_broadcast
youkaichao Jun 20, 2024
4b70d6f
refactor
youkaichao Jun 20, 2024
bb851d4
use queue
youkaichao Jun 20, 2024
f7680f4
fix lint
youkaichao Jun 20, 2024
9af386c
add single process test
youkaichao Jun 20, 2024
729a592
fix warning
youkaichao Jun 20, 2024
0a61a69
add barrier
youkaichao Jun 20, 2024
d8d9a0f
add test for complicated cases
youkaichao Jun 20, 2024
608d57f
fix tests
youkaichao Jun 20, 2024
e5137cb
fix tests
youkaichao Jun 20, 2024
d0aa190
add comments
youkaichao Jun 21, 2024
d0522b0
fix race condition
youkaichao Jun 21, 2024
cd39b81
add random delay in test
youkaichao Jun 21, 2024
d0f77e9
add comments
youkaichao Jun 21, 2024
5fd104e
use env var
youkaichao Jun 21, 2024
0e3a810
fix env
youkaichao Jun 21, 2024

.buildkite/test-pipeline.yaml (4 changes: 3 additions & 1 deletion)

@@ -28,9 +28,11 @@ steps:

 - label: Distributed Comm Ops Test
   #mirror_hardwares: [amd]
-  command: pytest -v -s distributed/test_comm_ops.py
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
+  commands:
+  - pytest -v -s distributed/test_comm_ops.py
+  - pytest -v -s distributed/test_shm_broadcast.py

 - label: Distributed Tests (2 GPUs)
   mirror_hardwares: [amd]

tests/distributed/test_shm_broadcast.py (82 additions & 0 deletions, new file)

@@ -0,0 +1,82 @@
import multiprocessing
import random
import time

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import (
    ShmRingBuffer, ShmRingBufferIO)
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for p in processes:
        assert p.exitcode == 0


def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly
    # so we need to pass the environment variables as arguments
    # and update the environment variables in the function
    def wrapped_fn(env):
        update_environment_variables(env)
Collaborator: Will this create side effects to the rest of the tests?

youkaichao (Member, Author): No, it is local to the process created in this test.

        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn
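
Regarding the side-effect question above: a minimal standalone sketch (not part of the PR) showing that environment changes made inside a multiprocessing.Process stay local to the child; DEMO_VAR is a made-up name used only for illustration.

import multiprocessing
import os


def child():
    # hypothetical variable, set only in the child process
    os.environ['DEMO_VAR'] = 'set-in-child'
    assert os.environ['DEMO_VAR'] == 'set-in-child'


if __name__ == '__main__':
    p = multiprocessing.Process(target=child)
    p.start()
    p.join()
    assert p.exitcode == 0
    # the parent process never sees the child's modification
    assert 'DEMO_VAR' not in os.environ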


@worker_fn_wrapper
def worker_fn():
    writer_rank = 2
    broadcaster = ShmRingBufferIO.create_from_process_group(
        dist.group.WORLD, 1024, 2, writer_rank)
    if dist.get_rank() == writer_rank:
        time.sleep(random.random())
        broadcaster.broadcast_object(0)
        time.sleep(random.random())
        broadcaster.broadcast_object({})
        time.sleep(random.random())
        broadcaster.broadcast_object([])
    else:
        time.sleep(random.random())
        a = broadcaster.broadcast_object(None)
        time.sleep(random.random())
        b = broadcaster.broadcast_object(None)
        time.sleep(random.random())
        c = broadcaster.broadcast_object(None)
        assert a == 0
        assert b == {}
        assert c == []
    dist.barrier()


def test_shm_broadcast():
    distributed_run(worker_fn, 4)


def test_singe_process():
    buffer = ShmRingBuffer(1, 1024, 4)
    reader = ShmRingBufferIO(buffer, reader_rank=0)
    writer = ShmRingBufferIO(buffer, reader_rank=-1)
    writer.enqueue([0])
    writer.enqueue([1])
    assert reader.dequeue() == [0]
    assert reader.dequeue() == [1]

vllm/distributed/device_communicators/shm_broadcast.py (259 additions & 0 deletions, new file)

@@ -0,0 +1,259 @@
import pickle
import time
from contextlib import contextmanager
from multiprocessing import shared_memory
from typing import Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger

VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

logger = init_logger(__name__)


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer, are known in advance.
        In this case, we don't need to synchronize the access to the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  |   max_chunks x (1 + n_reader) bytes    |

        metadata memory layout: each byte is a flag, the first byte is the
        written flag, and the rest are reader flags. The flags are set to 0
        by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        During creation, `name` is None and the buffer is created. We can pass
        the created object to other processes by pickling it. The other
        processes will get the name of the shared memory and open it, so that
        they can access the same shared memory buffer.
        """  # noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                self.shared_memory = shared_memory.SharedMemory(name=name)
            assert self.shared_memory.size == self.total_bytes_of_buffer
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
                assert torch.all(tensor == 0)

    def __reduce__(self):
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        self.shared_memory.close()
        if self.is_creator:
            self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf
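
As an aside (not part of the PR), the layout arithmetic described in the docstring above can be sanity-checked for hypothetical parameters; all numbers below are illustrative only.

# hypothetical parameters, chosen only for illustration
n_reader = 3            # three reader processes
max_chunk_bytes = 1024  # payload capacity of each chunk
max_chunks = 10         # number of chunks in the ring

metadata_size = 1 + n_reader                    # written flag + one flag per reader
data_section = max_chunk_bytes * max_chunks     # 10240 bytes, starting at offset 0
metadata_section = metadata_size * max_chunks   # 40 bytes, right after the data section
total_bytes = (max_chunk_bytes + metadata_size) * max_chunks

assert data_section + metadata_section == total_bytes  # 10280 bytes of shared memory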


class ShmRingBufferIO:

    def __init__(self, buffer: ShmRingBuffer, reader_rank: int):
        self.buffer = buffer
        self.reader_rank = reader_rank
        self._is_writer = self.reader_rank == -1
        self._is_reader = not self._is_writer
        if self._is_reader:
            assert 0 <= self.reader_rank < buffer.n_reader, \
                (f"Invalid reader rank {self.reader_rank} for buffer"
                 f" created with {buffer.n_reader} readers")
        self.current_idx = 0

    @contextmanager
    def acquire_write(self):
        assert self._is_writer, "Only writers can acquire write"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # try to write to the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
Collaborator: This means we iterate over all the data chunks, right? Can you comment here?

                        # no empty block found
                        if time.time(
                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                            logger.warning(
                                "No available block found in %s seconds. ",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
Collaborator: Isn't this too small? (We would be calling sleep 1e7 times per second.) Have you measured the CPU overhead in this scenario, by chance?
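
As an aside, a quick standalone sketch (not from the PR) that measures how many time.sleep(1e-7) calls actually fit into one second; in practice the loop rate is bounded by the overhead of the sleep call itself rather than the nominal 0.1 us.

import time

iterations = 0
deadline = time.monotonic() + 1.0
while time.monotonic() < deadline:
    time.sleep(1e-7)
    iterations += 1
# prints how many sleep(1e-7) calls fit into one second on this machine
print(f"~{iterations} iterations per second")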

                        time.sleep(1e-7)
                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
Collaborator: QQ: It looks like this function can write to the next chunk even if the previous chunk was not fully read yet. Doesn't this mean the ordering can get screwed up? For example, say you have 3 chunks, A, B, and C. You write to all 3 chunks and enqueue the next item. B is read by all readers before A is read. In this case, the new write starts from B. And then say A is read before C is read. If so, it hangs forever.

Is this possible, or is the read order somehow guaranteed by acquire_read?

                    yield buf

                # caller has written to the buffer
                # mark the block as written
                metadata_buffer[0] = 1
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                break

    @contextmanager
    def acquire_read(self):
        assert self._is_reader, "Only readers can acquire read"
        start_index = self.current_idx
        start_time = time.time()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # try to read the next block
                    self.current_idx = (self.current_idx +
                                        1) % self.buffer.max_chunks
                    if self.current_idx == start_index:
                        # no block found
                        if time.time(
                        ) - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:  # noqa
                            logger.warning(
                                "No available block found in %s seconds. ",
                                VLLM_RINGBUFFER_WARNING_INTERVAL)
                            n_warning += 1
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.reader_rank + 1] = 1
                break

    def enqueue(self, obj):

Collaborator: Can you add a docstring noting that this is blocking if there's already an item in the queue? (That's a slightly different semantic from a normal "enqueue", which returns even though there's an active item in the queue.)

        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if len(serialized_obj) > self.buffer.max_chunk_bytes:
            raise RuntimeError(
                f"{len(serialized_obj)=} larger than the allowed value "
                f"{self.buffer.max_chunk_bytes}. "
                "Please increase the max_chunk_bytes parameter.")
        with self.acquire_write() as buf:
            buf[:len(serialized_obj)] = serialized_obj

    def dequeue(self):
        assert self._is_reader, "Only readers can dequeue"
        with self.acquire_read() as buf:
            # no need to know the size of the serialized object
            # the pickle format itself contains the size information internally
            # see https://docs.python.org/3/library/pickle.html
            obj = pickle.loads(buf)
        return obj
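
The comment in dequeue relies on pickle's framing: pickle.loads stops at the STOP opcode, so any stale bytes left in the chunk after the payload are ignored. A standalone check (not part of the PR):

import pickle

payload = pickle.dumps({"answer": 42}, protocol=pickle.HIGHEST_PROTOCOL)
chunk = bytearray(1024)          # stand-in for one ring-buffer chunk
chunk[:len(payload)] = payload   # the rest of the chunk stays as stale zero bytes

# deserializing the whole chunk still yields the original object
assert pickle.loads(bytes(chunk)) == {"answer": 42}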

    def broadcast_object(self, obj=None):
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "ShmRingBufferIO":
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        ranks_inside_group = list(range(group_world_size))
        global_ranks = dist.get_process_group_ranks(pg)
        n_reader = group_world_size - 1
        buffer: ShmRingBuffer
        if group_rank == writer_rank:
            buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
            dist.broadcast_object_list([buffer],
                                       src=global_ranks[writer_rank])
            dist.barrier(pg)
            return ShmRingBufferIO(buffer, -1)
        else:
            recv = [None]
            dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
            dist.barrier(pg)
            buffer = recv[0]  # type: ignore
            rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
            return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))