[coll] Allreduce. #9679

Merged 1 commit on Oct 17, 2023
1 change: 1 addition & 0 deletions R-package/src/Makevars.in
@@ -99,6 +99,7 @@ OBJECTS= \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/tracker.o \
1 change: 1 addition & 0 deletions R-package/src/Makevars.win
@@ -99,6 +99,7 @@ OBJECTS= \
$(PKGROOT)/src/logging.o \
$(PKGROOT)/src/global_config.o \
$(PKGROOT)/src/collective/allgather.o \
$(PKGROOT)/src/collective/allreduce.o \
$(PKGROOT)/src/collective/broadcast.o \
$(PKGROOT)/src/collective/comm.o \
$(PKGROOT)/src/collective/tracker.o \
90 changes: 90 additions & 0 deletions src/collective/allreduce.cc
@@ -0,0 +1,90 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include "allreduce.h"

#include <algorithm> // for min
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int8_t
#include <vector> // for vector

#include "../data/array_interface.h" // for Type, DispatchDType
#include "allgather.h" // for RingAllgather
#include "comm.h" // for Comm
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span

namespace xgboost::collective::cpu_impl {
template <typename T>
Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
std::size_t n_bytes_in_seg, Func const& op) {
auto rank = comm.Rank();
auto world = comm.World();

auto dst_rank = BootstrapNext(rank, world);
auto src_rank = BootstrapPrev(rank, world);
auto next_ch = comm.Chan(dst_rank);
auto prev_ch = comm.Chan(src_rank);

std::vector<std::int8_t> buffer(n_bytes_in_seg, 0);
auto s_buf = common::Span{buffer.data(), buffer.size()};

for (std::int32_t r = 0; r < world - 1; ++r) {
// send to ring next
auto send_off = ((rank + world - r) % world) * n_bytes_in_seg;
send_off = std::min(send_off, data.size_bytes());
auto seg_nbytes = std::min(data.size_bytes() - send_off, n_bytes_in_seg);
auto send_seg = data.subspan(send_off, seg_nbytes);

next_ch->SendAll(send_seg);

// receive from ring prev
auto recv_off = ((rank + world - r - 1) % world) * n_bytes_in_seg;
recv_off = std::min(recv_off, data.size_bytes());
seg_nbytes = std::min(data.size_bytes() - recv_off, n_bytes_in_seg);
CHECK_EQ(seg_nbytes % sizeof(T), 0);
auto recv_seg = data.subspan(recv_off, seg_nbytes);
auto seg = s_buf.subspan(0, recv_seg.size());

prev_ch->RecvAll(seg);
auto rc = prev_ch->Block();
if (!rc.OK()) {
return rc;
}

// accumulate to recv_seg
CHECK_EQ(seg.size(), recv_seg.size());
op(seg, recv_seg);
}

return Success();
}

Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type) {
return DispatchDType(type, [&](auto t) {
using T = decltype(t);
// Divide the data into segments according to the number of workers.
auto n_bytes_elem = sizeof(T);
CHECK_EQ(data.size_bytes() % n_bytes_elem, 0);
auto n = data.size_bytes() / n_bytes_elem;
auto world = comm.World();
auto n_bytes_in_seg = common::DivRoundUp(n, world) * sizeof(T);
auto rc = RingScatterReduceTyped<T>(comm, data, n_bytes_in_seg, op);
if (!rc.OK()) {
return rc;
}

auto prev = BootstrapPrev(comm.Rank(), comm.World());
auto next = BootstrapNext(comm.Rank(), comm.World());
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);

rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
if (!rc.OK()) {
return rc;
}
return comm.Block();
});
}
} // namespace xgboost::collective::cpu_impl
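
The file above implements the classic two-phase ring allreduce: world - 1 rounds of scatter-reduce leave each rank holding one fully reduced segment, and a ring allgather then circulates those segments so every rank ends with the complete result. Below is a minimal single-process sketch of that schedule in plain C++ (no XGBoost types; every name here is illustrative and not part of this PR), including the std::min clamping used when the buffer length does not divide evenly by the world size.

#include <algorithm>  // for fill, min
#include <cstddef>    // for size_t
#include <iostream>
#include <utility>    // for pair
#include <vector>

int main() {
  std::size_t const world = 4;
  std::size_t const n = 10;                         // deliberately not divisible by world
  std::size_t const seg = (n + world - 1) / world;  // same idea as common::DivRoundUp(n, world)

  // Each simulated rank starts with its own copy of the data (rank r holds r + 1).
  std::vector<std::vector<double>> ranks(world, std::vector<double>(n));
  for (std::size_t r = 0; r < world; ++r) {
    std::fill(ranks[r].begin(), ranks[r].end(), static_cast<double>(r + 1));
  }

  // Segment s covers [beg, beg + len); the min clamps keep the last, shorter segment in range.
  auto seg_range = [&](std::size_t s) {
    std::size_t beg = std::min(s * seg, n);
    std::size_t len = std::min(n - beg, seg);
    return std::pair<std::size_t, std::size_t>{beg, len};
  };

  // Phase 1: scatter-reduce. In round t, rank r accumulates segment
  // (r + world - t - 1) % world received from its ring predecessor.
  for (std::size_t t = 0; t + 1 < world; ++t) {
    for (std::size_t r = 0; r < world; ++r) {
      std::size_t prev = (r + world - 1) % world;
      auto [beg, len] = seg_range((r + world - t - 1) % world);
      for (std::size_t i = 0; i < len; ++i) {
        ranks[r][beg + i] += ranks[prev][beg + i];
      }
    }
  }

  // Phase 2: ring allgather. In round t, rank r copies the fully reduced segment
  // (r + world - t) % world from its ring predecessor.
  for (std::size_t t = 0; t + 1 < world; ++t) {
    for (std::size_t r = 0; r < world; ++r) {
      std::size_t prev = (r + world - 1) % world;
      auto [beg, len] = seg_range((r + world - t) % world);
      for (std::size_t i = 0; i < len; ++i) {
        ranks[r][beg + i] = ranks[prev][beg + i];
      }
    }
  }

  // Every rank now holds the elementwise sum 1 + 2 + 3 + 4 = 10.
  for (std::size_t r = 0; r < world; ++r) {
    for (double v : ranks[r]) {
      if (v != 10.0) {
        std::cout << "mismatch on rank " << r << "\n";
        return 1;
      }
    }
  }
  std::cout << "all ranks hold the reduced result\n";
  return 0;
}
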
39 changes: 39 additions & 0 deletions src/collective/allreduce.h
@@ -0,0 +1,39 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#pragma once
#include <cstdint> // for int8_t
#include <functional> // for function
#include <type_traits> // for is_invocable_v

#include "../data/array_interface.h" // for ArrayInterfaceHandler
#include "comm.h" // for Comm, RestoreType
#include "xgboost/collective/result.h" // for Result
#include "xgboost/span.h" // for Span

namespace xgboost::collective {
namespace cpu_impl {
using Func =
std::function<void(common::Span<std::int8_t const> lhs, common::Span<std::int8_t> out)>;

Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func const& op,
ArrayInterfaceHandler::Type type);
} // namespace cpu_impl

template <typename T, typename Fn>
std::enable_if_t<std::is_invocable_v<Fn, common::Span<T const>, common::Span<T>>, Result> Allreduce(
Comm const& comm, common::Span<T> data, Fn redop) {
auto erased = EraseType(data);
auto type = ToDType<T>::kType;

auto erased_fn = [type, redop](common::Span<std::int8_t const> lhs,
common::Span<std::int8_t> out) {
CHECK_EQ(lhs.size(), out.size()) << "Invalid input for reduction.";
auto lhs_t = RestoreType<T const>(lhs);
auto rhs_t = RestoreType<T>(out);
redop(lhs_t, rhs_t);
};

return cpu_impl::RingAllreduce(comm, erased, erased_fn, type);
}
} // namespace xgboost::collective
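
For context, the Allreduce template above erases the element type so that a single non-template RingAllreduce can serve every dtype: the typed span becomes a byte span, and the typed reduction functor is wrapped in a std::function over bytes that restores the element type before reducing. The following is a small standalone sketch of that pattern under assumed, simplified signatures; ByteOp, ReduceBytes, and TypedReduce are hypothetical names, not XGBoost APIs.

#include <cstddef>     // for size_t
#include <cstdint>     // for int8_t
#include <functional>  // for function
#include <iostream>
#include <vector>

// Byte-level reduction signature, playing the role of cpu_impl::Func.
using ByteOp = std::function<void(std::int8_t const* lhs, std::int8_t* out, std::size_t n_bytes)>;

// Untyped core. Here it applies the op once; the real core would run the ring schedule.
void ReduceBytes(std::int8_t const* lhs, std::int8_t* out, std::size_t n_bytes, ByteOp const& op) {
  op(lhs, out, n_bytes);
}

// Typed front end: erase the element type on the way in, restore it inside the wrapper.
template <typename T, typename Fn>
void TypedReduce(std::vector<T> const& lhs, std::vector<T>* out, Fn redop) {
  ByteOp erased = [redop](std::int8_t const* l, std::int8_t* o, std::size_t n_bytes) {
    auto const* lhs_t = reinterpret_cast<T const*>(l);
    auto* out_t = reinterpret_cast<T*>(o);
    redop(lhs_t, out_t, n_bytes / sizeof(T));  // restore the element type
  };
  ReduceBytes(reinterpret_cast<std::int8_t const*>(lhs.data()),
              reinterpret_cast<std::int8_t*>(out->data()), out->size() * sizeof(T), erased);
}

int main() {
  std::vector<double> lhs{1.0, 2.0, 3.0};
  std::vector<double> out{10.0, 20.0, 30.0};
  TypedReduce(lhs, &out, [](double const* l, double* o, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) o[i] += l[i];  // elementwise sum, as in the tests below
  });
  for (double v : out) std::cout << v << " ";  // prints: 11 22 33
  std::cout << "\n";
  return 0;
}
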
97 changes: 54 additions & 43 deletions src/data/array_interface.h
@@ -16,7 +16,7 @@
#include <utility>
#include <vector>

#include "../common/bitfield.h"
#include "../common/bitfield.h" // for RBitField8
#include "../common/common.h"
#include "../common/error_msg.h" // for NoF128
#include "xgboost/base.h"
@@ -104,7 +104,20 @@ struct ArrayInterfaceErrors {
*/
class ArrayInterfaceHandler {
public:
enum Type : std::int8_t { kF2, kF4, kF8, kF16, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
enum Type : std::int8_t {
kF2 = 0,
kF4 = 1,
kF8 = 2,
kF16 = 3,
kI1 = 4,
kI2 = 5,
kI4 = 6,
kI8 = 7,
kU1 = 8,
kU2 = 9,
kU4 = 10,
kU8 = 11,
};

template <typename PtrType>
static PtrType GetPtrFromArrayData(Object::Map const &obj) {
@@ -587,75 +600,73 @@ class ArrayInterface {
ArrayInterfaceHandler::Type type{ArrayInterfaceHandler::kF16};
};

template <std::int32_t D, typename Fn>
void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
// Only used for cuDF at the moment.
CHECK_EQ(array.valid.Capacity(), 0);
auto dispatch = [&](auto t) {
using T = std::remove_const_t<decltype(t)> const;
// Set the data size to max as we don't know the original size of a sliced array:
//
// Slicing an array A with shape (4, 2, 3) and stride (6, 3, 1) by [:, 1, :] results
// in an array B with shape (4, 3) and strides (6, 1). We can't calculate the original
// size 24 based on the slice.
fn(linalg::TensorView<T, D>{common::Span<T const>{static_cast<T *>(array.data),
std::numeric_limits<std::size_t>::max()},
array.shape, array.strides, device});
};
switch (array.type) {
template <typename Fn>
auto DispatchDType(ArrayInterfaceHandler::Type dtype, Fn dispatch) {
switch (dtype) {
case ArrayInterfaceHandler::kF2: {
#if defined(XGBOOST_USE_CUDA)
dispatch(__half{});
#endif
return dispatch(__half{});
#else
LOG(FATAL) << "half type is only supported for CUDA input.";
break;
#endif
}
case ArrayInterfaceHandler::kF4: {
dispatch(float{});
break;
return dispatch(float{});
}
case ArrayInterfaceHandler::kF8: {
dispatch(double{});
break;
return dispatch(double{});
}
case ArrayInterfaceHandler::kF16: {
using T = long double;
CHECK(sizeof(long double) == 16) << error::NoF128();
dispatch(T{});
break;
CHECK(sizeof(T) == 16) << error::NoF128();
return dispatch(T{});
}
case ArrayInterfaceHandler::kI1: {
dispatch(std::int8_t{});
break;
return dispatch(std::int8_t{});
}
case ArrayInterfaceHandler::kI2: {
dispatch(std::int16_t{});
break;
return dispatch(std::int16_t{});
}
case ArrayInterfaceHandler::kI4: {
dispatch(std::int32_t{});
break;
return dispatch(std::int32_t{});
}
case ArrayInterfaceHandler::kI8: {
dispatch(std::int64_t{});
break;
return dispatch(std::int64_t{});
}
case ArrayInterfaceHandler::kU1: {
dispatch(std::uint8_t{});
break;
return dispatch(std::uint8_t{});
}
case ArrayInterfaceHandler::kU2: {
dispatch(std::uint16_t{});
break;
return dispatch(std::uint16_t{});
}
case ArrayInterfaceHandler::kU4: {
dispatch(std::uint32_t{});
break;
return dispatch(std::uint32_t{});
}
case ArrayInterfaceHandler::kU8: {
dispatch(std::uint64_t{});
break;
return dispatch(std::uint64_t{});
}
}

return std::result_of_t<Fn(std::int8_t)>();
}

template <std::int32_t D, typename Fn>
void DispatchDType(ArrayInterface<D> const array, DeviceOrd device, Fn fn) {
// Only used for cuDF at the moment.
CHECK_EQ(array.valid.Capacity(), 0);
auto dispatch = [&](auto t) {
using T = std::remove_const_t<decltype(t)> const;
// Set the data size to max as we don't know the original size of a sliced array:
//
// Slicing an array A with shape (4, 2, 3) and stride (6, 3, 1) by [:, 1, :] results
// in an array B with shape (4, 3) and strides (6, 1). We can't calculate the original
// size 24 based on the slice.
fn(linalg::TensorView<T, D>{common::Span<T const>{static_cast<T *>(array.data),
std::numeric_limits<std::size_t>::max()},
array.shape, array.strides, device});
};
DispatchDType(array.type, dispatch);
}

/**
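
The refactor above splits out a scalar DispatchDType overload that maps a runtime dtype tag to a concrete C++ type, which is what RingAllreduce relies on. Below is a self-contained sketch of the same switch-on-tag pattern with a toy enum and hypothetical names (not the XGBoost types); the trailing default-constructed return plays the same role as the std::result_of_t line in the code above, keeping every control path returning a value.

#include <cstdint>      // for int8_t, int32_t
#include <iostream>
#include <type_traits>  // for invoke_result_t

// Trimmed-down stand-in for ArrayInterfaceHandler::Type.
enum class DType : std::int8_t { kF4, kF8, kI4 };

template <typename Fn>
auto DispatchToy(DType dtype, Fn fn) {
  switch (dtype) {
    case DType::kF4: return fn(float{});
    case DType::kF8: return fn(double{});
    case DType::kI4: return fn(std::int32_t{});
  }
  // Unreachable when every tag is handled; keeps all control paths returning a value.
  return std::invoke_result_t<Fn, float>();
}

int main() {
  auto n_bytes = DispatchToy(DType::kF8, [](auto t) {
    using T = decltype(t);  // the concrete element type selected at runtime
    return sizeof(T);       // e.g. compute the element size, as RingAllreduce does
  });
  std::cout << n_bytes << "\n";  // prints 8
  return 0;
}
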
72 changes: 72 additions & 0 deletions tests/cpp/collective/test_allreduce.cc
@@ -0,0 +1,72 @@
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>

#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for WorkerForTest, TestDistributed

namespace xgboost::collective {

namespace {
class AllreduceWorker : public WorkerForTest {
public:
using WorkerForTest::WorkerForTest;

void Basic() {
{
std::vector<double> data(13, 0.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
}
{
std::vector<double> data(1, 1.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
}
}

void Acc() {
std::vector<double> data(314, 1.5);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
}
}
};

class AllreduceTest : public SocketTest {};
} // namespace

TEST_F(AllreduceTest, Basic) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
AllreduceWorker worker{host, port, timeout, n_workers, r};
worker.Basic();
});
}

TEST_F(AllreduceTest, Sum) {
std::int32_t n_workers = std::min(7u, std::thread::hardware_concurrency());
TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
std::int32_t r) {
AllreduceWorker worker{host, port, timeout, n_workers, r};
worker.Acc();
});
}
} // namespace xgboost::collective
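
Since the reduction functor passed to Allreduce is arbitrary, other associative operations can reuse the same entry point. As a hypothetical extension, not part of this PR, a max-reduction case could sit alongside Acc() in AllreduceWorker; this sketch assumes comm_.Rank() is accessible there in the same way comm_.World() is used above.

  // Hypothetical extra case: elementwise max through the same Allreduce entry point.
  void Max() {
    std::vector<double> data(64, static_cast<double>(comm_.Rank()));
    Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
      for (std::size_t i = 0; i < rhs.size(); ++i) {
        rhs[i] = lhs[i] > rhs[i] ? lhs[i] : rhs[i];  // keep the larger value
      }
    });
    for (auto v : data) {
      ASSERT_EQ(v, static_cast<double>(comm_.World() - 1));  // the largest rank wins
    }
  }
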