[Core][Distributed] add shm broadcast #5399
Changes from 19 commits

New test file (@@ -0,0 +1,57 @@):

import multiprocessing

import torch.distributed as dist

from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer
from vllm.utils import update_environment_variables


def distributed_run(fn, world_size):
    number_of_processes = world_size
    processes = []
    for i in range(number_of_processes):
        env = {}
        env['RANK'] = str(i)
        env['LOCAL_RANK'] = str(i)
        env['WORLD_SIZE'] = str(number_of_processes)
        env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
        env['MASTER_ADDR'] = 'localhost'
        env['MASTER_PORT'] = '12345'
        p = multiprocessing.Process(target=fn, args=(env, ))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    for p in processes:
        assert p.exitcode == 0


def worker_fn_wrapper(fn):
    # `multiprocessing.Process` cannot accept environment variables directly,
    # so we pass them as arguments and apply them inside the child process.
    def wrapped_fn(env):
        update_environment_variables(env)
        dist.init_process_group(backend="gloo")
        fn()

    return wrapped_fn


@worker_fn_wrapper
def worker_fn():
    broadcaster = ShmRingBuffer(dist.group.WORLD, 1024, 1)
    if dist.get_rank() == 0:
        broadcaster.broadcast_object(0)
        broadcaster.broadcast_object(1)
    else:
        a = broadcaster.broadcast_object(None)
        b = broadcaster.broadcast_object(None)
        assert a == 0
        assert b == 1


def test_shm_broadcast():
    distributed_run(worker_fn, 4)
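
The wrapper applies the per-rank environment with `vllm.utils.update_environment_variables` before `dist.init_process_group(backend="gloo")` reads `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`. For readers without a vLLM checkout, a minimal stand-in could look like the sketch below; the real helper may behave differently (for example, it may log when overwriting an existing variable):

```python
import os
from typing import Dict


def update_environment_variables(envs: Dict[str, str]) -> None:
    """Minimal stand-in for vllm.utils.update_environment_variables (sketch only)."""
    # The child process applies the variables before torch.distributed's
    # env:// rendezvous reads them.
    for key, value in envs.items():
        os.environ[key] = value
```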

New file `vllm/distributed/device_communicators/shm_broadcast.py` (@@ -0,0 +1,165 @@):

"""
A shared memory ring buffer implementation for broadcast communication.
It is optimized for the case where there is one writer and multiple readers.
This way, we don't need locks to synchronize the access to the buffer.
"""
import pickle
import time
from contextlib import contextmanager
from multiprocessing import shared_memory
from unittest.mock import patch

import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.logger import init_logger

logger = init_logger(__name__)


class ShmRingBuffer:

    # seconds to wait before warning about a potential blocking call
    WARNING_INTERVAL = 60

    def __init__(self, pg: ProcessGroup, max_chunk_bytes, max_chunks):
        self.rank = dist.get_rank(pg)
        self.world_size = dist.get_world_size(pg)
        global_ranks = dist.get_process_group_ranks(pg)
        self.is_writer = self.rank == 0
        self.is_reader = not self.is_writer
        self.current_idx = 0
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        total_bytes = (self.max_chunk_bytes +
                       self.world_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if self.is_writer:
            self.shared_memory = shared_memory.SharedMemory(create=True,
                                                            size=total_bytes)
            # initialize the metadata section to 0
            for i in range(self.metadata_offset, total_bytes):
                self.shared_memory.buf[i] = 0
            dist.broadcast_object_list([self.shared_memory.name],
                                       src=global_ranks[0])
        else:
            recv = [None]
            dist.broadcast_object_list(recv, src=global_ranks[0])
            name = recv[0]
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                self.shared_memory = shared_memory.SharedMemory(name=name)

    @property
    def data(self):
        start = self.data_offset + self.current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        return memoryview(self.shared_memory.buf[start:end])

    @property
    def metadata(self):
        start = self.metadata_offset + self.current_idx * self.world_size
        end = start + self.world_size
        return memoryview(self.shared_memory.buf[start:end])

    @contextmanager
    def acquire_write(self):
        assert self.is_writer, "Only writers can acquire write"
        start_index = self.current_idx
        start_time = time.time()
        while True:
            with self.metadata as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.world_size - 1:
                    # this block is written and not read by all readers
                    # try to write to the next block
                    self.current_idx = (self.current_idx + 1) % self.max_chunks
                    if self.current_idx == start_index:

Inline review comment: this means we iterate over all the data chunks, right? Can you add a comment here?

                        # no empty block found
                        if time.time() - start_time > self.WARNING_INTERVAL:
                            logger.warning(
                                "No available block found in %s seconds.",
                                self.WARNING_INTERVAL)
                        # wait for a while (0.1 us)

Inline review comment: isn't this too small? (we would call sleep 1e7 times per second). Have you measured the CPU overhead in this scenario, by chance?

                        time.sleep(1e-7)
                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers
                # let caller write to the buffer
                with self.data as buf:
                    yield buf

                # caller has written to the buffer
                # mark the block as written and clear all read flags
                metadata_buffer[0] = 1
                for i in range(1, self.world_size):
                    metadata_buffer[i] = 0
                break

    @contextmanager
    def acquire_read(self):
        assert self.is_reader, "Only readers can acquire read"
        start_index = self.current_idx
        start_time = time.time()
        while True:
            with self.metadata as metadata_buffer:
                read_flag = metadata_buffer[self.rank]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader
                    # try to read the next block
                    self.current_idx = (self.current_idx + 1) % self.max_chunks
                    if self.current_idx == start_index:
                        # no block found
                        if time.time() - start_time > self.WARNING_INTERVAL:
                            logger.warning(
                                "No available block found in %s seconds.",
                                self.WARNING_INTERVAL)
                        # wait for a while (0.1 us)
                        time.sleep(1e-7)
                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.data as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.rank] = 1
                break

    def broadcast_object(self, obj=None):
        if self.is_writer:
            serialized_obj = pickle.dumps(obj,
                                          protocol=pickle.HIGHEST_PROTOCOL)
            if len(serialized_obj) > self.max_chunk_bytes:
                raise RuntimeError(
                    f"{len(serialized_obj)=} larger than the allowed value "
                    f"{self.max_chunk_bytes}. "
                    "Please increase the max_chunk_bytes parameter.")
            with self.acquire_write() as buf:
                buf[:len(serialized_obj)] = serialized_obj
            return obj
        else:
            with self.acquire_read() as buf:
                # no need to know the size of the serialized object;
                # the pickle format itself contains the size information
                # see https://docs.python.org/3/library/pickle.html
                obj = pickle.loads(buf)
            return obj

    def __del__(self):
        if self.is_writer:
            self.shared_memory.close()
            self.shared_memory.unlink()
        else:
            self.shared_memory.close()
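
To make the layout of the shared segment concrete, here is a small worked sketch of the offsets the constructor computes; the parameter values below are illustrative, not the ones vLLM uses:

```python
# Illustrative parameters only (not taken from the PR's production settings).
max_chunk_bytes = 1024
max_chunks = 2
world_size = 4

total_bytes = (max_chunk_bytes + world_size) * max_chunks   # (1024 + 4) * 2 = 2056
data_offset = 0                              # data chunk i starts at i * max_chunk_bytes
metadata_offset = max_chunk_bytes * max_chunks               # 2048

# Metadata for chunk i occupies world_size bytes starting at
# metadata_offset + i * world_size:
#   byte 0                -> "written" flag, set by the writer (rank 0)
#   bytes 1..world_size-1 -> per-reader "read" flags, one per reader rank
for i in range(max_chunks):
    start = metadata_offset + i * world_size
    print(f"chunk {i}: data [{i * max_chunk_bytes}, {(i + 1) * max_chunk_bytes}), "
          f"metadata [{start}, {start + world_size})")
```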

Changes to the `GroupCoordinator` class:

@@ -98,6 +98,7 @@ class GroupCoordinator:
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
+    shm_broadcaster: Optional[Any]  # shared memory broadcaster

     def __init__(
         self,

@@ -162,6 +163,12 @@ def __init__(
         else:
             self.ca_comm = None

+        from vllm.distributed.device_communicators.shm_broadcast import (
+            ShmRingBuffer)
+        self.shm_broadcaster: Optional[ShmRingBuffer] = None
+        if self.world_size > 1 and is_in_the_same_node(self.cpu_group):
+            self.shm_broadcaster = ShmRingBuffer(self.cpu_group, 1 << 20, 6)
+
     @property
     def first_rank(self):
         """Return the global rank of the first process in the group"""
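
As a rough sanity check on the parameters chosen above (`1 << 20` bytes per chunk, 6 chunks), the shared segment stays small regardless of group size; the world size below is only an example:

```python
# Size of the segment allocated by ShmRingBuffer(cpu_group, 1 << 20, 6),
# following total_bytes = (max_chunk_bytes + world_size) * max_chunks.
max_chunk_bytes = 1 << 20   # 1 MiB per chunk
max_chunks = 6
world_size = 8              # illustrative tensor-parallel size

total_bytes = (max_chunk_bytes + world_size) * max_chunks
print(total_bytes)          # 6291504 bytes, i.e. roughly 6 MiB
```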

@@ -324,6 +331,30 @@ def broadcast(self, input_: torch.Tensor, src: int = 0):
                                     group=self.device_group)
         return input_

+    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):

Review comment: nit; what about we just support […]?
Reply: I think the benefit here is obvious enough, so we don't need to give the control to users. If […]. This is the same case for […].

+        """Broadcast the input object.
+        NOTE: `src` is the local rank of the source rank.
+        """
+        assert src < self.world_size, f"Invalid src rank ({src})"
+
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return obj
+        if self.shm_broadcaster is not None:
+            assert src == 0, "Shared memory broadcaster only supports src=0"
+            return self.shm_broadcaster.broadcast_object(obj)
+        if self.rank_in_group == src:
+            torch.distributed.broadcast_object_list([obj],
+                                                    src=self.ranks[src],
+                                                    group=self.cpu_group)
+            return obj
+        else:
+            recv = [None]
+            torch.distributed.broadcast_object_list(recv,
+                                                    src=self.ranks[src],
+                                                    group=self.cpu_group)
+            return recv[0]
+
     def broadcast_object_list(self,
                               obj_list: List[Any],
                               src: int = 0,
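
`broadcast_object` is a collective call: every rank in the group must invoke it, and when the shared-memory path is active the source must be local rank 0 (the ring buffer's writer). A usage sketch, assuming `group` is an already-constructed `GroupCoordinator` for the current process:

```python
# Sketch only; how `group` is obtained is outside the scope of this diff.
if group.rank_in_group == 0:
    payload = {"step": 42, "do_sample": True}   # any picklable object that fits in one chunk
    group.broadcast_object(payload)             # the writer publishes via shared memory
else:
    payload = group.broadcast_object(None)      # readers spin until the chunk is written
```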

@@ -371,9 +402,7 @@ def broadcast_tensor_dict(
             # `metadata_list` lives in CPU memory.
             # `broadcast_object_list` has serialization & deserialization,
             # all happening on CPU. Therefore, we can use the CPU group.
-            torch.distributed.broadcast_object_list([metadata_list],
-                                                    src=src,
-                                                    group=metadata_group)
+            self.broadcast_object(metadata_list, src=src)
             async_handles = []
             for tensor in tensor_list:
                 if tensor.numel() == 0:

@@ -396,14 +425,10 @@ def broadcast_tensor_dict(
                     async_handle.wait()

         else:
-            recv_metadata_list = [None]
-            torch.distributed.broadcast_object_list(recv_metadata_list,
-                                                    src=src,
-                                                    group=metadata_group)
-            assert recv_metadata_list[0] is not None
+            metadata_list = self.broadcast_object(None, src=src)
             tensor_dict = {}
             async_handles = []
-            for key, value in recv_metadata_list[0]:
+            for key, value in metadata_list:
                 if isinstance(value, TensorMetadata):
                     tensor = torch.empty(value.size,
                                          dtype=value.dtype,
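
For context on the two hunks above: `broadcast_tensor_dict` separates CPU-side metadata (keys, dtypes, shapes) from the tensor payloads, and only the metadata now goes through `broadcast_object` (and therefore through shared memory when it is available); the tensors themselves are still broadcast on the communication groups. The sketch below illustrates that split with a simplified `TensorMetadata`; the real vLLM type may carry additional fields such as the target device:

```python
from collections import namedtuple
from typing import Any, Dict, List, Tuple

import torch

# Simplified: inferred from the receive path above, which reads `value.size`
# and `value.dtype`. Not the actual vLLM definition.
TensorMetadata = namedtuple("TensorMetadata", ["size", "dtype"])


def split_tensor_dict(
        tensor_dict: Dict[str, Any]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Separate small CPU metadata from the large tensor payloads."""
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            metadata_list.append((key, TensorMetadata(value.size(), value.dtype)))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list
```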

Review comment: Will this create side effects for the rest of the tests?
Reply: No, it is local to the process created in this test.
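
The `__del__` method and the resource-tracker patch in `ShmRingBuffer.__init__` follow the usual ownership convention for `multiprocessing.shared_memory`: the creating process unlinks the segment, while attaching processes only close their mapping. A standalone sketch of that pattern, independent of vLLM:

```python
from multiprocessing import shared_memory
from unittest.mock import patch

# Creator (the ring buffer's writer): allocates and ultimately unlinks the segment.
owner = shared_memory.SharedMemory(create=True, size=1024)
name = owner.name   # in the PR this name is shared via dist.broadcast_object_list

# Attacher (a reader): Python's resource tracker would otherwise try to clean up
# a segment it does not own (https://stackoverflow.com/q/62748654/9191338), so
# registration is patched out, mirroring the workaround in ShmRingBuffer.__init__.
with patch("multiprocessing.resource_tracker.register",
           lambda *args, **kwargs: None):
    attached = shared_memory.SharedMemory(name=name)

attached.close()    # readers only close their mapping
owner.close()
owner.unlink()      # the creator both closes and unlinks
```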