Skip to content

Commit

Permalink
[core][distributed] improve shared memory broadcast (vllm-project#5754)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored and prashantgupta24 committed Jul 1, 2024
1 parent 06baabc commit 7cd7a7a
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,26 @@ def __init__(self,
| written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
+--------------+--------------+--------------+-----+--------------+
The state of metadata is as follows:
(case 1) 0???...???: the block is not written yet, cannot read, can write
(case 2) 1000...000: the block is just written, can read, cannot write
(case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
(case 4) 1111...111: the block is written and read by all readers, cannot read, can write
State transition for readers:
When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
Only after the caller finishes reading the block, the reader can mark the block as read.
Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).
State transition for writer:
When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
can reset the reader flags to 0, and mark the block as written (from 0 to 1).
NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.
During creation, `name` is None and the buffer is created. We can pass the
created object to other processes by pickling it. The other processes will
get the name of the shared memory and open it, so that they can access the
Expand Down Expand Up @@ -81,10 +101,6 @@ def __init__(self,
lambda *args, **kwargs: None):
self.shared_memory = shared_memory.SharedMemory(name=name)
assert self.shared_memory.size == self.total_bytes_of_buffer
with memoryview(self.shared_memory.buf[self.metadata_offset:]
) as metadata_buffer:
tensor = torch.frombuffer(metadata_buffer, dtype=torch.uint8)
assert torch.all(tensor == 0)

def __reduce__(self):
return (
Expand Down Expand Up @@ -163,11 +179,15 @@ def acquire_write(self):
yield buf

# caller has written to the buffer
# mark the block as written
metadata_buffer[0] = 1
# NOTE: order is important here
# first set the read flags to 0
# then set the written flag to 1
# otherwise, the readers may think they already read the block
for i in range(1, self.buffer.n_reader + 1):
# set read flag to 0, meaning it is not read yet
metadata_buffer[i] = 0
# mark the block as written
metadata_buffer[0] = 1
break

@contextmanager
Expand Down Expand Up @@ -247,13 +267,15 @@ def create_from_process_group(pg: ProcessGroup,
buffer: ShmRingBuffer
if group_rank == writer_rank:
buffer = ShmRingBuffer(n_reader, max_chunk_bytes, max_chunks)
dist.broadcast_object_list([buffer], src=global_ranks[writer_rank])
dist.barrier(pg)
dist.broadcast_object_list([buffer],
src=global_ranks[writer_rank],
group=pg)
return ShmRingBufferIO(buffer, -1)
else:
recv = [None]
dist.broadcast_object_list(recv, src=global_ranks[writer_rank])
dist.barrier(pg)
dist.broadcast_object_list(recv,
src=global_ranks[writer_rank],
group=pg)
buffer = recv[0] # type: ignore
rest_ranks = [r for r in ranks_inside_group if r != writer_rank]
return ShmRingBufferIO(buffer, rest_ranks.index(group_rank))

0 comments on commit 7cd7a7a

Please sign in to comment.