Refactor device communicator to make allreduce more flexible #9295

Merged · 2 commits · Jun 13, 2023
81 changes: 81 additions & 0 deletions src/collective/communicator-inl.cuh
@@ -0,0 +1,81 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#pragma once
#include <string>
#include <vector>

#include "communicator.h"
#include "device_communicator.cuh"

namespace xgboost {
namespace collective {

/**
* @brief Reduce values from all processes and distribute the result back to all processes.
* @param device ID of the device.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
template <Operation op>
inline void AllReduce(int device, std::int8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt8, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint8_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt8, op);
}

template <Operation op>
inline void AllReduce(int device, std::int32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt32, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint32_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt32, op);
}

template <Operation op>
inline void AllReduce(int device, std::int64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kInt64, op);
}

template <Operation op>
inline void AllReduce(int device, std::uint64_t *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kUInt64, op);
}

template <Operation op>
inline void AllReduce(int device, float *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kFloat, op);
}

template <Operation op>
inline void AllReduce(int device, double *send_receive_buffer, size_t count) {
Communicator::GetDevice(device)->AllReduce(send_receive_buffer, count, DataType::kDouble, op);
}

/**
* @brief Gather variable-length values from all processes.
* @param device ID of the device.
* @param send_buffer Buffer storing the input data.
* @param length_bytes Length in bytes of the input data.
* @param segments Size of each segment.
* @param receive_buffer Buffer storing the output data.
*/
inline void AllGatherV(int device, void const *send_buffer, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *receive_buffer) {
Communicator::GetDevice(device)->AllGatherV(send_buffer, length_bytes, segments, receive_buffer);
}

/**
* @brief Synchronize device operations.
* @param device ID of the device.
*/
inline void Synchronize(int device) { Communicator::GetDevice(device)->Synchronize(); }

} // namespace collective
} // namespace xgboost
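For orientation, a minimal usage sketch of the new header (not part of the diff; ExampleReduceScores is a hypothetical caller and assumes the collective communicator and the CUDA device are already initialized, with `values` pointing to device memory):

#include "../collective/communicator-inl.cuh"

namespace xgboost {
// Hypothetical caller: sum `count` doubles across all workers, in place on the device buffer.
void ExampleReduceScores(int device, double *values, std::size_t count) {
  // The overload is selected by the pointer type; the reduction by the template argument.
  collective::AllReduce<collective::Operation::kSum>(device, values, count);
  collective::Synchronize(device);
}
}  // namespace xgboost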
29 changes: 6 additions & 23 deletions src/collective/device_communicator.cuh
@@ -17,32 +17,15 @@ class DeviceCommunicator {
virtual ~DeviceCommunicator() = default;

/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @brief Combines values from all processes and distributes the result back to all processes.
*
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
* @param data_type Data type stored in the buffer.
* @param op The operation to perform.
*/
virtual void AllReduceSum(float *send_receive_buffer, size_t count) = 0;

/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(double *send_receive_buffer, size_t count) = 0;

/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(int64_t *send_receive_buffer, size_t count) = 0;

/**
* @brief Sum values from all processes and distribute the result back to all processes.
* @param send_receive_buffer Buffer storing the data.
* @param count Number of elements in the buffer.
*/
virtual void AllReduceSum(uint64_t *send_receive_buffer, size_t count) = 0;
virtual void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) = 0;

/**
* @brief Gather variable-length values from all processes.
38 changes: 11 additions & 27 deletions src/collective/device_communicator_adapter.cuh
@@ -23,20 +23,18 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {

~DeviceCommunicatorAdapter() override = default;

void AllReduceSum(float *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kFloat>(send_receive_buffer, count);
}

void AllReduceSum(double *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kDouble>(send_receive_buffer, count);
}

void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kInt64>(send_receive_buffer, count);
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (communicator_->GetWorldSize() == 1) {
return;
}

void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<collective::DataType::kUInt64>(send_receive_buffer, count);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * GetTypeSize(data_type);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, op);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
@@ -77,20 +75,6 @@ class DeviceCommunicatorAdapter : public DeviceCommunicator {
}

private:
template <collective::DataType data_type, typename T>
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
if (communicator_->GetWorldSize() == 1) {
return;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
auto size = count * sizeof(T);
host_buffer_.reserve(size);
dh::safe_cuda(cudaMemcpy(host_buffer_.data(), send_receive_buffer, size, cudaMemcpyDefault));
communicator_->AllReduce(host_buffer_.data(), count, data_type, collective::Operation::kSum);
dh::safe_cuda(cudaMemcpy(send_receive_buffer, host_buffer_.data(), size, cudaMemcpyDefault));
}

int const device_ordinal_;
Communicator *communicator_;
/// Host buffer used to call communicator functions.
84 changes: 62 additions & 22 deletions src/collective/nccl_device_communicator.cuh
@@ -72,20 +72,18 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
}
}

void AllReduceSum(float *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclFloat>(send_receive_buffer, count);
}

void AllReduceSum(double *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclDouble>(send_receive_buffer, count);
}

void AllReduceSum(int64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclInt64>(send_receive_buffer, count);
}
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {
if (communicator_->GetWorldSize() == 1) {
return;
}

void AllReduceSum(uint64_t *send_receive_buffer, size_t count) override {
DoAllReduceSum<ncclUint64>(send_receive_buffer, count);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count,
GetNcclDataType(data_type), GetNcclRedOp(op), nccl_comm_,
cuda_stream_));
allreduce_bytes_ += count * GetTypeSize(data_type);
allreduce_calls_ += 1;
}

void AllGatherV(void const *send_buffer, size_t length_bytes, std::vector<std::size_t> *segments,
@@ -162,17 +160,59 @@ class NcclDeviceCommunicator : public DeviceCommunicator {
return id;
}

template <ncclDataType_t data_type, typename T>
void DoAllReduceSum(T *send_receive_buffer, size_t count) {
if (communicator_->GetWorldSize() == 1) {
return;
static ncclDataType_t GetNcclDataType(DataType const &data_type) {
ncclDataType_t result;
switch (data_type) {
case DataType::kInt8:
result = ncclInt8;
break;
case DataType::kUInt8:
result = ncclUint8;
break;
case DataType::kInt32:
result = ncclInt32;
break;
case DataType::kUInt32:
result = ncclUint32;
break;
case DataType::kInt64:
result = ncclInt64;
break;
case DataType::kUInt64:
result = ncclUint64;
break;
case DataType::kFloat:
result = ncclFloat;
break;
case DataType::kDouble:
result = ncclDouble;
break;
default:
LOG(FATAL) << "Unknown data type.";
}
return result;
}

dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(send_receive_buffer, send_receive_buffer, count, data_type, ncclSum,
nccl_comm_, cuda_stream_));
allreduce_bytes_ += count * sizeof(T);
allreduce_calls_ += 1;
static ncclRedOp_t GetNcclRedOp(Operation const &op) {
ncclRedOp_t result;
switch (op) {
case Operation::kMax:
result = ncclMax;
break;
case Operation::kMin:
result = ncclMin;
break;
case Operation::kSum:
result = ncclSum;
break;
case Operation::kBitwiseAND:
case Operation::kBitwiseOR:
case Operation::kBitwiseXOR:
LOG(FATAL) << "Not implemented yet.";
default:
LOG(FATAL) << "Unknown reduce operation.";
}
return result;
}

int const device_ordinal_;
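Because the data type and reduce operation are now runtime arguments, reductions other than a floating-point sum flow through the same entry point. A hedged sketch of what this enables (ExampleMaxCounters is a hypothetical helper; assumes an NCCL build with the device communicator initialized; note the bitwise operations would currently abort with LOG(FATAL)):

#include <cstdint>

#include "../collective/communicator-inl.cuh"

namespace xgboost {
// Hypothetical helper: take the element-wise maximum of unsigned 32-bit counters
// across all workers; before this refactor only a sum over four types was exposed.
void ExampleMaxCounters(int device, std::uint32_t *counters, std::size_t count) {
  collective::AllReduce<collective::Operation::kMax>(device, counters, count);
}
}  // namespace xgboost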
13 changes: 6 additions & 7 deletions src/common/quantile.cu
@@ -12,8 +12,7 @@
#include <memory>
#include <utility>

#include "../collective/communicator.h"
#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "categorical.h"
#include "common.h"
#include "device_helpers.cuh"
@@ -510,7 +509,6 @@ void SketchContainer::AllReduce() {
}

timer_.Start(__func__);
auto* communicator = collective::Communicator::GetDevice(device_);
// Reduce the overhead on syncing.
size_t global_sum_rows = num_rows_;
collective::Allreduce<collective::Operation::kSum>(&global_sum_rows, 1);
@@ -531,14 +529,15 @@
auto offset = rank * d_columns_ptr.size();
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
gathered_ptrs.begin() + offset);
communicator->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.size());
collective::AllReduce<collective::Operation::kSum>(device_, gathered_ptrs.data().get(),
gathered_ptrs.size());

// Get the data from all workers.
std::vector<size_t> recv_lengths;
dh::caching_device_vector<char> recvbuf;
communicator->AllGatherV(this->Current().data().get(), dh::ToSpan(this->Current()).size_bytes(),
&recv_lengths, &recvbuf);
communicator->Synchronize();
collective::AllGatherV(device_, this->Current().data().get(),
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths, &recvbuf);
collective::Synchronize(device_);

// Segment the received data.
auto s_recvbuf = dh::ToSpan(recvbuf);
5 changes: 2 additions & 3 deletions src/metric/auc.cu
@@ -11,7 +11,7 @@
#include <tuple>
#include <utility>

#include "../collective/device_communicator.cuh"
#include "../collective/communicator-inl.cuh"
#include "../common/algorithm.cuh" // SegmentedArgSort
#include "../common/optional_weight.h" // OptionalWeights
#include "../common/threading_utils.cuh" // UnravelTrapeziodIdx,SegmentedTrapezoidThreads
@@ -205,8 +205,7 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
if (collective::IsDistributed()) {
int32_t device = dh::CurrentDevice();
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
auto* communicator = collective::Communicator::GetDevice(device);
communicator->AllReduceSum(results.data(), results.size());
collective::AllReduce<collective::Operation::kSum>(device, results.data(), results.size());
}
auto reduce_in = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0), [=] XGBOOST_DEVICE(size_t i) {
6 changes: 3 additions & 3 deletions src/tree/fit_stump.cu
@@ -11,7 +11,7 @@

#include <cstddef> // std::size_t

#include "../collective/device_communicator.cuh" // DeviceCommunicator
#include "../collective/communicator-inl.cuh"
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#include "fit_stump.h"
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
@@ -49,8 +49,8 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));

collective::DeviceCommunicator* communicator = collective::Communicator::GetDevice(ctx->gpu_id);
communicator->AllReduceSum(reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);
collective::AllReduce<collective::Operation::kSum>(
ctx->gpu_id, reinterpret_cast<double*>(d_sum.Values().data()), d_sum.Size() * 2);

thrust::for_each_n(policy, thrust::make_counting_iterator(0ul), n_targets,
[=] XGBOOST_DEVICE(std::size_t i) mutable {
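The fit_stump.cu call above reinterprets the GradientPairPrecise buffer as a flat array of doubles and doubles the element count. A small sketch of the layout assumption behind that cast (illustrative only, not part of the PR):

#include "xgboost/base.h"  // xgboost::GradientPairPrecise

// Each GradientPairPrecise packs a gradient and a hessian as two contiguous doubles,
// so a buffer of d_sum.Size() pairs can be summed across workers as d_sum.Size() * 2 doubles.
static_assert(sizeof(xgboost::GradientPairPrecise) == 2 * sizeof(double),
              "AllReduce over reinterpret_cast<double*> assumes a packed {grad, hess} layout");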