Skip to content

Commit

Permalink
use torch ipc
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Nov 5, 2024
1 parent 1c45f4c commit e782d66
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 127 deletions.
32 changes: 16 additions & 16 deletions csrc/custom_all_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,24 @@
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));

fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
fptr_t init_custom_ar(const std::vector<torch::Tensor>& ipc_tensors,
torch::Tensor& rank_data, int64_t rank,
bool full_nvlink) {
int world_size = offsets.size();
int world_size = ipc_tensors.size();
if (world_size > 8)
throw std::invalid_argument("world size > 8 is not supported");
if (world_size % 2 != 0)
throw std::invalid_argument("Odd num gpus is not supported for now");
if (world_size != handles.size())
throw std::invalid_argument(
"handles length should equal to offsets length");
if (rank < 0 || rank >= world_size)
throw std::invalid_argument("invalid rank passed in");

cudaIpcMemHandle_t ipc_handles[8];
vllm::Signal* ipc_ptrs[8];
for (int i = 0; i < world_size; i++) {
std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(ipc_tensors[i].data_ptr());
}
return (fptr_t) new vllm::CustomAllreduce(
reinterpret_cast<vllm::Signal*>(meta.data_ptr()), rank_data.data_ptr(),
rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(),
rank_data.numel(), rank, world_size,
full_nvlink);
}

/**
Expand Down Expand Up @@ -115,11 +111,15 @@ void dispose(fptr_t _fa) {

int64_t meta_size() { return sizeof(vllm::Signal); }

void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets) {
void register_buffer(fptr_t _fa,
const std::vector<torch::Tensor>& ipc_tensors) {
auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
fa->register_buffer(handles, offsets, t.data_ptr());
TORCH_CHECK(ipc_tensors.size() == fa->world_size_);
void* ipc_ptrs[8];
for (int i = 0; i < ipc_tensors.size(); i++) {
ipc_ptrs[i] = reinterpret_cast<vllm::Signal*>(ipc_tensors[i].data_ptr());
}
fa->register_buffer(ipc_ptrs);
}

std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
Expand Down
63 changes: 26 additions & 37 deletions csrc/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -297,34 +297,25 @@ class CustomAllreduce {
std::map<IPC_KEY, char*> ipc_handles_;

/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
* Signals are an array of ipc-enabled buffers from all ranks.
* For each of the buffer, the layout is as follows:
* | -- sizeof(Signal) -- | ------ a few MB ----- |
* The first section is for allreduce synchronization, and the second section
* is for storing the intermediate results required by some allreduce algos.
*
* There's a total of sizeof(Signal) of prefix before the actual data,
* so meta + 1 points to actual temporary buffer.
*
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
* Note: this class does not own any device memory. Any required buffers
* are passed in from the constructor.
*/
CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz,
const cudaIpcMemHandle_t* handles,
const std::vector<int64_t>& offsets, int rank,
bool full_nvlink = true)
CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz,
int rank, int world_size, bool full_nvlink = true)
: rank_(rank),
world_size_(offsets.size()),
world_size_(world_size),
full_nvlink_(full_nvlink),
self_sg_(meta),
self_sg_(signals[rank]),
d_rank_data_base_(reinterpret_cast<RankData*>(rank_data)),
d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
for (int i = 0; i < world_size_; i++) {
Signal* rank_sg;
if (i != rank_) {
char* handle = open_ipc_handle(&handles[i]);
handle += offsets[i];
rank_sg = (Signal*)handle;
} else {
rank_sg = self_sg_;
}
sg_.signals[i] = rank_sg;
sg_.signals[i] = signals[i];
}
}

Expand Down Expand Up @@ -370,26 +361,22 @@ class CustomAllreduce {
std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
}

void register_buffer(const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, void* self) {
/**
* Register already-shared IPC pointers.
*/
void register_buffer(void** ptrs) {
check_rank_data_capacity();
RankData data;
for (int i = 0; i < world_size_; i++) {
if (i != rank_) {
char* handle = open_ipc_handle(handles[i].data());
handle += offsets[i];
data.ptrs[i] = handle;
} else {
data.ptrs[i] = self;
}
data.ptrs[i] = ptrs[i];
}
auto d_data = d_rank_data_base_++;
CUDACHECK(
cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
buffers_[self] = d_data;
buffers_[ptrs[rank_]] = d_data;
}

// note: when registering graph buffers, we intentionally choose to not
// Note: when registering graph buffers, we intentionally choose to not
// deduplicate the addresses. That means if the allocator reuses some
// addresses, they will be registered again. This is to account for the remote
// possibility of different allocation patterns between ranks. For example,
Expand Down Expand Up @@ -424,11 +411,13 @@ class CustomAllreduce {
}

/**
* This is the result after careful grid search. Using 36 blocks give the best
* or close to the best runtime on the devices I tried: A100, A10, A30, T4,
* V100. You'll notice that NCCL kernels also only take a small amount of SMs.
* Not quite sure the underlying reason, but my guess is that too many SMs
* will cause contention on NVLink bus.
* Performs allreduce, assuming input has already been registered.
*
* Block and grid default configs are results after careful grid search. Using
* 36 blocks give the best or close to the best runtime on the devices I
* tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only
* take a small amount of SMs. Not quite sure the underlying reason, but my
* guess is that too many SMs will cause contention on NVLink bus.
*/
template <typename T>
void allreduce(cudaStream_t stream, T* input, T* output, int size,
Expand Down
24 changes: 13 additions & 11 deletions csrc/custom_all_reduce_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,26 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
void* rank_data;
size_t rank_data_sz = 16 * 1024 * 1024;
CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
std::vector<int64_t> offsets(nRanks, 0);
vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
offsets, myRank);
vllm::Signal* ipc_ptrs[8];
for (int i = 0; i < nRanks; i++) {
if (i == myRank)
ipc_ptrs[i] = buffer;
else
CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i],
cudaIpcMemLazyEnablePeerAccess));
}
vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks);
auto* self_data =
reinterpret_cast<T*>(reinterpret_cast<char*>(buffer) +
sizeof(vllm::Signal) + data_size * sizeof(T));
// hack buffer registration
{
std::vector<std::string> handles;
handles.reserve(nRanks);
void* data[8];
for (int i = 0; i < nRanks; i++) {
char* begin = (char*)&data_handles[i];
char* end = (char*)&data_handles[i + 1];
handles.emplace_back(begin, end);
data[i] =
((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T);
}
std::vector<int64_t> offsets(nRanks,
sizeof(vllm::Signal) + data_size * sizeof(T));
fa.register_buffer(handles, offsets, self_data);
fa.register_buffer(data);
}

double* ground_truth;
Expand Down
10 changes: 3 additions & 7 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,18 +199,14 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,

#ifndef USE_ROCM
using fptr_t = int64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets, int64_t rank,
bool full_nvlink);
fptr_t init_custom_ar(const std::vector<torch::Tensor>& ipc_tensors,
torch::Tensor& rank_data, int64_t rank, bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
torch::Tensor& out);
void dispose(fptr_t _fa);
int64_t meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
const std::vector<std::string>& handles,
const std::vector<int64_t>& offsets);
void register_buffer(fptr_t _fa, const std::vector<torch::Tensor>& ipc_tensors);
std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
Expand Down
9 changes: 3 additions & 6 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,9 +411,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def(
"init_custom_ar(Tensor meta, Tensor rank_data, "
"str[] handles, int[] offsets, int rank, "
"bool full_nvlink) -> int");
"init_custom_ar(Tensor[] ipc_tensors, Tensor rank_data, "
"int rank, bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
Expand All @@ -427,9 +426,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size);

custom_ar.def(
"register_buffer(int fa, Tensor t, str[] handles, "
"int[] offsets) -> ()");
custom_ar.def("register_buffer(int fa, Tensor[] ipc_tensors) -> ()");
custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);

custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
Expand Down
14 changes: 6 additions & 8 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,11 +912,10 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:


# custom ar
def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
handles: List[str], offsets: List[int], rank: int,
full_nvlink: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles,
offsets, rank, full_nvlink)
def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor,
rank: int, full_nvlink: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
full_nvlink)


def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
Expand All @@ -936,9 +935,8 @@ def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()


def register_buffer(fa: int, t: torch.Tensor, handles: List[str],
offsets: List[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets)
def register_buffer(fa: int, ipc_tensors: List[torch.Tensor]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]:
Expand Down
Loading

0 comments on commit e782d66

Please sign in to comment.