Skip to content

Commit

Permalink
[coll] Reduce the scope of lock in the event loop. (#9784)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 15, 2023
1 parent 36a552a commit ada377c
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 70 deletions.
23 changes: 14 additions & 9 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,19 +412,24 @@ class TCPSocket {
return Success();
}

void SetKeepAlive() {
[[nodiscard]] Result SetKeepAlive() {
std::int32_t keepalive = 1;
xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
0);
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
sizeof(keepalive));
if (rc != 0) {
return system::FailWithCode("Failed to set TCP keeaplive.");
}
return Success();
}

void SetNoDelay() {
[[nodiscard]] Result SetNoDelay() {
std::int32_t tcp_no_delay = 1;
xgboost_CHECK_SYS_CALL(
setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
sizeof(tcp_no_delay)),
0);
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
sizeof(tcp_no_delay));
if (rc != 0) {
return system::FailWithCode("Failed to set TCP no delay.");
}
return Success();
}

/**
Expand Down
4 changes: 2 additions & 2 deletions rabit/src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
CHECK(all_link.sock.NonBlocking(true).OK());
all_link.sock.SetKeepAlive();
CHECK(all_link.sock.SetKeepAlive().OK());
if (rabit_enable_tcp_no_delay) {
all_link.sock.SetNoDelay();
CHECK(all_link.sock.SetNoDelay().OK());
}
if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) {
Expand Down
11 changes: 5 additions & 6 deletions src/collective/allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <algorithm> // for min
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move
#include <vector> // for vector

#include "../data/array_interface.h" // for Type, DispatchDType
Expand Down Expand Up @@ -47,7 +48,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto seg = s_buf.subspan(0, recv_seg.size());

prev_ch->RecvAll(seg);
auto rc = prev_ch->Block();
auto rc = comm.Block();
if (!rc.OK()) {
return rc;
}
Expand Down Expand Up @@ -83,11 +84,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);

rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
if (!rc.OK()) {
return rc;
}
return comm.Block();
return std::move(rc) << [&] {
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
} << [&] { return comm.Block(); };
});
}
} // namespace xgboost::collective::cpu_impl
35 changes: 22 additions & 13 deletions src/collective/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,28 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
std::string const& task_id, TCPSocket* out, std::int32_t rank,
std::int32_t world) {
// get information from tracker
// Get information from the tracker
CHECK(!info.host.empty());
auto rc = Connect(info.host, info.port, retry, timeout, out);
if (!rc.OK()) {
return Fail("Failed to connect to the tracker.", std::move(rc));
}

TCPSocket& tracker = *out;
return std::move(rc)
<< [&] { return tracker.NonBlocking(false); }
<< [&] { return tracker.RecvTimeout(timeout); }
<< [&] { return proto::Magic{}.Verify(&tracker); }
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
return Success() << [&] {
auto rc = Connect(info.host, info.port, retry, timeout, out);
if (rc.OK()) {
return rc;
} else {
return Fail("Failed to connect to the tracker.", std::move(rc));
}
} << [&] {
return tracker.NonBlocking(false);
} << [&] {
return tracker.RecvTimeout(timeout);
} << [&] {
return proto::Magic{}.Verify(&tracker);
} << [&] {
return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id);
} << [&] {
LOG(INFO) << "Task " << task_id << " connected to the tracker";
return Success();
};
}

[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
Expand Down Expand Up @@ -257,8 +266,8 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
CHECK(this->channels_.empty());
for (auto& w : workers) {
if (w) {
w->SetNoDelay();
rc = w->NonBlocking(true);
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
<< [&] { return w->SetKeepAlive(); };
}
if (!rc.OK()) {
return rc;
Expand Down
88 changes: 63 additions & 25 deletions src/collective/loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,26 @@
#include "xgboost/logging.h" // for CHECK

namespace xgboost::collective {
Result Loop::EmptyQueue() {
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
timer_.Start(__func__);
auto error = [this] {
this->stop_ = true;
auto error = [this] { timer_.Stop(__func__); };

if (stop_) {
timer_.Stop(__func__);
};
return Success();
}

while (!queue_.empty() && !stop_) {
std::queue<Op> qcopy;
auto& qcopy = *p_queue;

// clear the copied queue
while (!qcopy.empty()) {
rabit::utils::PollHelper poll;
std::size_t n_ops = qcopy.size();

// watch all ops
while (!queue_.empty()) {
auto op = queue_.front();
queue_.pop();
// Iterate through all the ops for poll
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
qcopy.pop();

switch (op.code) {
case Op::kRead: {
Expand All @@ -40,6 +45,7 @@ Result Loop::EmptyQueue() {
return Fail("Invalid socket operation.");
}
}

qcopy.push(op);
}

Expand All @@ -51,10 +57,12 @@ Result Loop::EmptyQueue() {
error();
return rc;
}

// we wonldn't be here if the queue is empty.
CHECK(!qcopy.empty());

while (!qcopy.empty() && !stop_) {
// Iterate through all the ops for performing the operations
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
qcopy.pop();

Expand All @@ -81,20 +89,21 @@ Result Loop::EmptyQueue() {
}

if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
stop_ = true;
auto rc = system::FailWithCode("Invalid socket output.");
error();
return rc;
}

op.off += n_bytes_done;
CHECK_LE(op.off, op.n);

if (op.off != op.n) {
// not yet finished, push back to queue for next round.
queue_.push(op);
qcopy.push(op);
}
}
}

timer_.Stop(__func__);
return Success();
}
Expand All @@ -107,22 +116,42 @@ void Loop::Process() {
if (stop_) {
break;
}
CHECK(!mu_.try_lock());

this->rc_ = this->EmptyQueue();
if (!rc_.OK()) {
stop_ = true;
auto unlock_notify = [&](bool is_blocking) {
if (!is_blocking) {
return;
}
lock.unlock();
cv_.notify_one();
break;
}
};

CHECK(queue_.empty());
CHECK(!mu_.try_lock());
cv_.notify_one();
}
// move the queue
std::queue<Op> qcopy;
bool is_blocking = false;
while (!queue_.empty()) {
auto op = queue_.front();
queue_.pop();
if (op.code == Op::kBlock) {
is_blocking = true;
} else {
qcopy.push(op);
}
}
// unblock the queue
if (!is_blocking) {
lock.unlock();
}
// clear the queue
auto rc = this->EmptyQueue(&qcopy);
// Handle error
if (!rc.OK()) {
this->rc_ = std::move(rc);
unlock_notify(is_blocking);
return;
}

if (rc_.OK()) {
CHECK(queue_.empty());
CHECK(qcopy.empty());
unlock_notify(is_blocking);
}
}

Expand All @@ -140,6 +169,15 @@ Result Loop::Stop() {
return Success();
}

[[nodiscard]] Result Loop::Block() {
this->Submit(Op{Op::kBlock});
{
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return (this->queue_.empty()) || stop_; });
}
return std::move(rc_);
}

Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
timer_.Init(__func__);
worker_ = std::thread{[this] {
Expand Down
17 changes: 5 additions & 12 deletions src/collective/loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ namespace xgboost::collective {
class Loop {
public:
struct Op {
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
std::int32_t rank{-1};
std::int8_t* ptr{nullptr};
std::size_t n{0};
TCPSocket* sock{nullptr};
std::size_t off{0};

explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default;
Expand All @@ -44,9 +45,9 @@ class Loop {
Result rc_;
bool stop_{false};
std::exception_ptr curr_exce_{nullptr};
common::Monitor timer_;
common::Monitor mutable timer_;

Result EmptyQueue();
Result EmptyQueue(std::queue<Op>* p_queue) const;
void Process();

public:
Expand All @@ -60,15 +61,7 @@ class Loop {
cv_.notify_one();
}

[[nodiscard]] Result Block() {
{
std::unique_lock lock{mu_};
cv_.notify_all();
}
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
return std::move(rc_);
}
[[nodiscard]] Result Block();

explicit Loop(std::chrono::seconds timeout);

Expand Down
9 changes: 6 additions & 3 deletions tests/cpp/collective/test_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,34 @@ class AllreduceWorker : public WorkerForTest {
void Basic() {
{
std::vector<double> data(13, 0.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
}
{
std::vector<double> data(1, 1.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
}
}

void Acc() {
std::vector<double> data(314, 1.5);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
Expand Down

0 comments on commit ada377c

Please sign in to comment.