From e782d66b6109335898efc9bc70a567d48d2ee383 Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Tue, 5 Nov 2024 08:13:15 +0000
Subject: [PATCH] use torch ipc

---
 csrc/custom_all_reduce.cu                     | 32 +++----
 csrc/custom_all_reduce.cuh                    | 63 ++++++-------
 csrc/custom_all_reduce_test.cu                | 24 ++---
 csrc/ops.h                                    | 10 +--
 csrc/torch_bindings.cpp                       |  9 +-
 vllm/_custom_ops.py                           | 14 ++-
 .../device_communicators/custom_all_reduce.py | 90 ++++++++++---------
 7 files changed, 115 insertions(+), 127 deletions(-)

diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu
index 9b82bec44c3c6..83e813d82014b 100644
--- a/csrc/custom_all_reduce.cu
+++ b/csrc/custom_all_reduce.cu
@@ -9,28 +9,24 @@
 using fptr_t = int64_t;
 static_assert(sizeof(void*) == sizeof(fptr_t));
 
-fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
-                      const std::vector<std::string>& handles,
-                      const std::vector<int64_t>& offsets, int64_t rank,
+fptr_t init_custom_ar(const std::vector<torch::Tensor>& ipc_tensors,
+                      torch::Tensor& rank_data, int64_t rank,
                       bool full_nvlink) {
-  int world_size = offsets.size();
+  int world_size = ipc_tensors.size();
   if (world_size > 8)
     throw std::invalid_argument("world size > 8 is not supported");
   if (world_size % 2 != 0)
     throw std::invalid_argument("Odd num gpus is not supported for now");
-  if (world_size != handles.size())
-    throw std::invalid_argument(
-        "handles length should equal to offsets length");
   if (rank < 0 || rank >= world_size)
     throw std::invalid_argument("invalid rank passed in");
-  cudaIpcMemHandle_t ipc_handles[8];
+  vllm::Signal* ipc_ptrs[8];
   for (int i = 0; i < world_size; i++) {
-    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
+    ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(ipc_tensors[i].data_ptr());
   }
-  return (fptr_t) new vllm::CustomAllreduce(
-      reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
-      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
+  return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
+                                            rank_data.numel(), rank, world_size,
+                                            full_nvlink);
 }
 
 /**
@@ -115,11 +111,15 @@ void dispose(fptr_t _fa) {
 
 int64_t meta_size() { return sizeof(vllm::Signal); }
 
-void register_buffer(fptr_t _fa, torch::Tensor& t,
-                     const std::vector<std::string>& handles,
-                     const std::vector<int64_t>& offsets) {
+void register_buffer(fptr_t _fa,
+                     const std::vector<torch::Tensor>& ipc_tensors) {
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
-  fa->register_buffer(handles, offsets, t.data_ptr());
+  TORCH_CHECK(ipc_tensors.size() == fa->world_size_);
+  void* ipc_ptrs[8];
+  for (int i = 0; i < ipc_tensors.size(); i++) {
+    ipc_ptrs[i] = reinterpret_cast<void*>(ipc_tensors[i].data_ptr());
+  }
+  fa->register_buffer(ipc_ptrs);
 }
 
 std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh
index a2f7e43300002..f48969eeddb0d 100644
--- a/csrc/custom_all_reduce.cuh
+++ b/csrc/custom_all_reduce.cuh
@@ -297,34 +297,25 @@ class CustomAllreduce {
   std::map<IPC_KEY, char*> ipc_handles_;
 
   /**
-   * meta is a pointer to device metadata and temporary buffer for allreduce.
+   * Signals are an array of IPC-enabled buffers from all ranks.
+   * For each buffer, the layout is as follows:
+   * | -- sizeof(Signal) -- | ------ a few MB ----- |
+   * The first section is for allreduce synchronization, and the second section
+   * is for storing the intermediate results required by some allreduce algos.
    *
-   * There's a total of sizeof(Signal) of prefix before the actual data,
-   * so meta + 1 points to actual temporary buffer.
-   *
-   * note: this class does not own any device memory. Any required buffers
-   * are passed in from the constructor
+   * Note: this class does not own any device memory. Any required buffers
+   * are passed in from the constructor.
    */
-  CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
-                  const cudaIpcMemHandle_t* handles,
-                  const std::vector<int64_t>& offsets, int rank,
-                  bool full_nvlink = true)
+  CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
+                  int rank, int world_size, bool full_nvlink = true)
       : rank_(rank),
-        world_size_(offsets.size()),
+        world_size_(world_size),
         full_nvlink_(full_nvlink),
-        self_sg_(meta),
+        self_sg_(signals[rank]),
         d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
         d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
     for (int i = 0; i < world_size_; i++) {
-      Signal* rank_sg;
-      if (i != rank_) {
-        char* handle = open_ipc_handle(&handles[i]);
-        handle += offsets[i];
-        rank_sg = (Signal*)handle;
-      } else {
-        rank_sg = self_sg_;
-      }
-      sg_.signals[i] = rank_sg;
+      sg_.signals[i] = signals[i];
     }
   }
 
@@ -370,26 +361,22 @@ class CustomAllreduce {
         std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
   }
 
-  void register_buffer(const std::vector<std::string>& handles,
-                       const std::vector<int64_t>& offsets, void* self) {
+  /**
+   * Register already-shared IPC pointers, one per rank, indexed by rank.
+   */
+  void register_buffer(void** ptrs) {
     check_rank_data_capacity();
    RankData data;
     for (int i = 0; i < world_size_; i++) {
-      if (i != rank_) {
-        char* handle = open_ipc_handle(handles[i].data());
-        handle += offsets[i];
-        data.ptrs[i] = handle;
-      } else {
-        data.ptrs[i] = self;
-      }
+      data.ptrs[i] = ptrs[i];
     }
     auto d_data = d_rank_data_base_++;
     CUDACHECK(
         cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
-    buffers_[self] = d_data;
+    buffers_[ptrs[rank_]] = d_data;
  }
 
-  // note: when registering graph buffers, we intentionally choose to not
+  // Note: when registering graph buffers, we intentionally choose not to
   // deduplicate the addresses. That means if the allocator reuses some
   // addresses, they will be registered again. This is to account for the remote
   // possibility of different allocation patterns between ranks. For example,
@@ -424,11 +411,13 @@
   }
 
   /**
-   * This is the result after careful grid search. Using 36 blocks give the best
-   * or close to the best runtime on the devices I tried: A100, A10, A30, T4,
-   * V100. You'll notice that NCCL kernels also only take a small amount of SMs.
-   * Not quite sure the underlying reason, but my guess is that too many SMs
-   * will cause contention on NVLink bus.
+   * Performs allreduce, assuming input has already been registered.
+   *
+   * The default block and grid configs are the result of a careful grid
+   * search: using 36 blocks gives the best or close-to-the-best runtime on
+   * the devices I tried (A100, A10, A30, T4, V100). NCCL kernels also use
+   * only a small number of SMs. I'm not sure of the underlying reason, but
+   * my guess is that too many SMs would cause contention on the NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T* input, T* output, int size,
diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu
index 376687e91cfda..b59ea40d980f4 100644
--- a/csrc/custom_all_reduce_test.cu
+++ b/csrc/custom_all_reduce_test.cu
@@ -135,24 +135,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
   void* rank_data;
   size_t rank_data_sz = 16 * 1024 * 1024;
   CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
-  std::vector<int64_t> offsets(nRanks, 0);
-  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
-                           offsets, myRank);
+  vllm::Signal* ipc_ptrs[8];
+  for (int i = 0; i < nRanks; i++) {
+    if (i == myRank)
+      ipc_ptrs[i] = buffer;
+    else
+      CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i],
+                                     cudaIpcMemLazyEnablePeerAccess));
+  }
+  vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks);
   auto* self_data =
       reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
                            sizeof(vllm::Signal) + data_size * sizeof(T));
   // hack buffer registration
   {
-    std::vector<std::string> handles;
-    handles.reserve(nRanks);
+    void* data[8];
     for (int i = 0; i < nRanks; i++) {
-      char* begin = (char*)&data_handles[i];
-      char* end = (char*)&data_handles[i + 1];
-      handles.emplace_back(begin, end);
+      data[i] =
+          ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T);
     }
-    std::vector<int64_t> offsets(nRanks,
-                                 sizeof(vllm::Signal) + data_size * sizeof(T));
-    fa.register_buffer(handles, offsets, self_data);
+    fa.register_buffer(data);
   }
 
   double* ground_truth;
diff --git a/csrc/ops.h b/csrc/ops.h
index c50eb39a3dacc..1930c77bcb805 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -199,18 +199,14 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
 
 #ifndef USE_ROCM
 using fptr_t = int64_t;
-fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
-                      const std::vector<std::string>& handles,
-                      const std::vector<int64_t>& offsets, int64_t rank,
-                      bool full_nvlink);
+fptr_t init_custom_ar(const std::vector<torch::Tensor>& ipc_tensors,
+                      torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
 void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
 void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                       torch::Tensor& out);
 void dispose(fptr_t _fa);
 int64_t meta_size();
-void register_buffer(fptr_t _fa, torch::Tensor& t,
-                     const std::vector<std::string>& handles,
-                     const std::vector<int64_t>& offsets);
+void register_buffer(fptr_t _fa, const std::vector<torch::Tensor>& ipc_tensors);
 std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index b8185c24d5628..953c89aceb049 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -411,9 +411,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   // Custom all-reduce kernels
   custom_ar.def(
-      "init_custom_ar(Tensor meta, Tensor rank_data, "
-      "str[] handles, int[] offsets, int rank, "
-      "bool full_nvlink) -> int");
+      "init_custom_ar(Tensor[] ipc_tensors, Tensor rank_data, "
+      "int rank, bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
 
   custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
@@ -427,9 +426,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   custom_ar.def("dispose", &dispose);
   custom_ar.def("meta_size", &meta_size);
 
-  custom_ar.def(
-      "register_buffer(int fa, Tensor t, str[] handles, "
-      "int[] offsets) -> ()");
+  custom_ar.def("register_buffer(int fa, Tensor[] ipc_tensors) -> ()");
   custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
 
   custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 46a2fb8bc80a2..b823c846fdd95 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -912,11 +912,10 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
 
 
 # custom ar
-def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
-                   handles: List[str], offsets: List[int], rank: int,
-                   full_nvlink: bool) -> int:
-    return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
-                                                 offsets, rank, full_nvlink)
+def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor,
+                   rank: int, full_nvlink: bool) -> int:
+    return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
+                                                 full_nvlink)
 
 
 def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
@@ -936,9 +935,8 @@ def meta_size() -> int:
     return torch.ops._C_custom_ar.meta_size()
 
 
-def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
-                    offsets: List[int]) -> None:
-    return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
+def register_buffer(fa: int, ipc_tensors: List[torch.Tensor]) -> None:
+    return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
 
 
 def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index c3632aee6d11a..dd6f200962351 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -4,6 +4,7 @@
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
+from torch.multiprocessing.reductions import reduce_tensor
 
 import vllm.envs as envs
 from vllm import _custom_ops as ops
@@ -145,7 +146,7 @@ def __init__(self,
             return
 
         self.disabled = False
-        # buffers memory are owned by this Python class and passed to C++
+        # Buffer memory is owned by this Python class and passed to C++.
         # meta data composes of two parts: meta data for synchronization
         # (256 bytes) and a temporary buffer for storing intermediate
         # allreduce results.
@@ -168,10 +169,11 @@ def __init__(self,
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
-        handles, offsets = self._get_ipc_meta(self.meta)
+        # Retain a reference to the IPC tensors to prevent garbage collection.
+        self._ipc_tensors: List[torch.Tensor] = []
         self.full_nvlink = full_nvlink
-        self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles,
-                                       offsets, rank, self.full_nvlink)
+        self._ptr = ops.init_custom_ar(self._get_ipc_tensors(self.meta),
+                                       self.rank_data, rank, self.full_nvlink)
         self.register_buffer(self.buffer)
 
     @contextmanager
@@ -189,33 +191,32 @@ def capture(self):
         if not self.disabled:
             self.register_graph_buffers()
 
-    def _get_ipc_meta(self, inp: torch.Tensor):
-        data = inp.untyped_storage()._share_cuda_()
-        handle = data[1]
-        # https://github.com/pytorch/pytorch/pull/130890 changes
-        # the binary format of the ipc handle
-        # it starts from pytorch 2.5
-        if len(handle) > 64:
-            assert len(handle) == 66
-            # only support SHAREABLE_HANDLE_VERSION = 1
-            assert int(handle[0]) == 1
-            # only support SHAREABLE_CUDA_MALLOC = 'c'
-            assert handle[1] == ord("c")
-            handle = handle[2:]
-            # TODO: support expandable segment
-        shard_data = (
-            handle,  # ipc handle to base ptr
-            data[3],  # offset of base ptr
-        )
-        return self._gather_ipc_meta(shard_data)
-
-    def _gather_ipc_meta(self, shard_data):
-        # Note: don't use `[[None]] * self.world_size` here
-        # because it will create a list of the same reference
-        all_data: List[Optional[Any]] = [[None]
-                                         for i in range(self.world_size)]
-        all_data[self.rank][0] = shard_data
-
+    def _get_ipc_tensors(self, inp: torch.Tensor) -> List[torch.Tensor]:
+        """Gather IPC-enabled views of `inp` from all ranks."""
+        all_meta = self._all_gather_object(reduce_tensor(inp))
+        all_tensors = []
+        for i, obj in enumerate(all_meta):
+            func = obj[0][0]
+            args = list(obj[0][1])
+            # Fragile: the layout of `args` may change across torch versions;
+            # args[6] (the storage device index) must be the local device.
+            args[6] = inp.device.index
+            if i != self.rank:
+                all_tensors.append(func(*args))
+            else:
+                all_tensors.append(inp)
+        self._ipc_tensors.extend(all_tensors)
+        return all_tensors
+
+    def _all_gather_object(self, data: Any) -> List[List[Any]]:
+        """All-gather picklable objects from all ranks."""
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it would create a list of references to the same list.
+        all_data: List[List[Any]] = [[None] for i in range(self.world_size)]
+        all_data[self.rank][0] = data
+        # We cannot directly use `dist.all_gather_object` here because it is
+        # incompatible with the `gloo` backend under inference mode.
+        # See https://github.com/pytorch/pytorch/issues/126032 for details.
         ranks = dist.get_process_group_ranks(group=self.group)
         ranks.sort()
         for i, rank in enumerate(ranks):
@@ -224,9 +225,10 @@
             group=self.group,
             device="cpu")
 
-        # we cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        return all_data
+
+    def _gather_ipc_meta(self, shard_data):
+ all_data = self._all_gather_object(shard_data) handles = [] offsets = [] @@ -236,8 +238,7 @@ def _gather_ipc_meta(self, shard_data): return handles, offsets def register_buffer(self, inp: torch.Tensor): - handles, offsets = self._get_ipc_meta(inp) - ops.register_buffer(self._ptr, inp, handles, offsets) + ops.register_buffer(self._ptr, self._get_ipc_tensors(inp)) def register_graph_buffers(self): handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) @@ -260,23 +261,30 @@ def should_custom_ar(self, inp: torch.Tensor): return inp_size < self.max_size return False - # all reduce, assuming inp tensor is IPC registered with register_buffer, - # or, in the context of cuda graphs, register_graph_buffers def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): + """Performs all reduce. + + This method assumes inp tensor is IPC registered with register_buffer, + or, in the context of cuda graphs, register_graph_buffers. + """ if out is None: out = torch.empty_like(inp) ops.all_reduce_reg(self._ptr, inp, out) return out - # all reduce, assuming inp tensor is NOT IPC registered def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + """Performs all reduce, assuming inp tensor is not IPC registered.""" if out is None: out = torch.empty_like(inp) ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: - # when custom allreduce is disabled, this will be None + """Conditionally performs custom allreduce. + + Returns the allreduced result, or None if custom allreduce is not + suitable for the given Tensor under the current context. + """ if self.disabled or not self.should_custom_ar(input): return None if self._IS_CAPTURING: @@ -293,8 +301,6 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: # gains of using custom kernels return self.all_reduce_unreg(input) - return None - def close(self): if not self.disabled and self._ptr: ops.dispose(self._ptr)
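
Editor's note, appended for context (not part of the commit): the patch replaces
the hand-rolled exchange of raw cudaIpcMemHandle_t bytes plus base-pointer
offsets with torch's own IPC machinery. Each rank serializes its buffer with
torch.multiprocessing.reductions.reduce_tensor, which returns a picklable
(rebuild_fn, rebuild_args) pair; that metadata is all-gathered over the CPU/gloo
group, and every peer tensor is recreated locally via rebuild_fn(*rebuild_args),
so PyTorch takes care of opening and caching the underlying CUDA IPC handles.
Below is a minimal standalone sketch of that round-trip between two processes.
It assumes a Linux host with at least one CUDA GPU; the names (child, queue,
done) are illustrative and do not come from the patch.

    import torch
    import torch.multiprocessing as mp
    from torch.multiprocessing.reductions import reduce_tensor

    def child(queue, done):
        # Receive the picklable (rebuild_fn, rebuild_args) metadata and
        # rebuild a tensor that aliases the parent's CUDA allocation.
        rebuild_fn, rebuild_args = queue.get()
        peer = rebuild_fn(*rebuild_args)
        peer.fill_(42.0)               # write through the shared mapping
        torch.cuda.synchronize()
        done.set()

    if __name__ == "__main__":
        mp.set_start_method("spawn")   # required when CUDA is involved
        t = torch.zeros(8, device="cuda:0")
        queue, done = mp.Queue(), mp.Event()
        p = mp.Process(target=child, args=(queue, done))
        p.start()
        queue.put(reduce_tensor(t))    # ship IPC metadata, not the data
        done.wait()
        p.join()
        print(t)                       # all 42s: one buffer, two processes

Putting a CUDA tensor directly on a torch.multiprocessing queue would trigger
the same reduction implicitly; the patch calls reduce_tensor explicitly so the
metadata can travel over the existing gloo all-gather in _all_gather_object.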