Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[coll] Reduce the scope of lock in the event loop. #9784

Merged
merged 7 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -412,19 +412,24 @@ class TCPSocket {
return Success();
}

void SetKeepAlive() {
[[nodiscard]] Result SetKeepAlive() {
std::int32_t keepalive = 1;
xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
0);
auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
sizeof(keepalive));
if (rc != 0) {
return system::FailWithCode("Failed to set TCP keeaplive.");
}
return Success();
}

void SetNoDelay() {
[[nodiscard]] Result SetNoDelay() {
std::int32_t tcp_no_delay = 1;
xgboost_CHECK_SYS_CALL(
setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
sizeof(tcp_no_delay)),
0);
auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
sizeof(tcp_no_delay));
if (rc != 0) {
return system::FailWithCode("Failed to set TCP no delay.");
}
return Success();
}

/**
Expand Down
4 changes: 2 additions & 2 deletions rabit/src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,9 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
utils::Assert(!all_link.sock.BadSocket(), "ReConnectLink: bad socket");
// set the socket to non-blocking mode, enable TCP keepalive
CHECK(all_link.sock.NonBlocking(true).OK());
all_link.sock.SetKeepAlive();
CHECK(all_link.sock.SetKeepAlive().OK());
if (rabit_enable_tcp_no_delay) {
all_link.sock.SetNoDelay();
CHECK(all_link.sock.SetNoDelay().OK());
}
if (tree_neighbors.count(all_link.rank) != 0) {
if (all_link.rank == parent_rank) {
Expand Down
11 changes: 5 additions & 6 deletions src/collective/allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <algorithm> // for min
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, int8_t
#include <utility> // for move
#include <vector> // for vector

#include "../data/array_interface.h" // for Type, DispatchDType
Expand Down Expand Up @@ -47,7 +48,7 @@ Result RingScatterReduceTyped(Comm const& comm, common::Span<std::int8_t> data,
auto seg = s_buf.subspan(0, recv_seg.size());

prev_ch->RecvAll(seg);
auto rc = prev_ch->Block();
auto rc = comm.Block();
if (!rc.OK()) {
return rc;
}
Expand Down Expand Up @@ -83,11 +84,9 @@ Result RingAllreduce(Comm const& comm, common::Span<std::int8_t> data, Func cons
auto prev_ch = comm.Chan(prev);
auto next_ch = comm.Chan(next);

rc = RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
if (!rc.OK()) {
return rc;
}
return comm.Block();
return std::move(rc) << [&] {
return RingAllgather(comm, data, n_bytes_in_seg, 1, prev_ch, next_ch);
} << [&] { return comm.Block(); };
});
}
} // namespace xgboost::collective::cpu_impl
35 changes: 22 additions & 13 deletions src/collective/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,28 @@ Comm::Comm(std::string const& host, std::int32_t port, std::chrono::seconds time
Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, std::int32_t retry,
std::string const& task_id, TCPSocket* out, std::int32_t rank,
std::int32_t world) {
// get information from tracker
// Get information from the tracker
CHECK(!info.host.empty());
auto rc = Connect(info.host, info.port, retry, timeout, out);
if (!rc.OK()) {
return Fail("Failed to connect to the tracker.", std::move(rc));
}

TCPSocket& tracker = *out;
return std::move(rc)
<< [&] { return tracker.NonBlocking(false); }
<< [&] { return tracker.RecvTimeout(timeout); }
<< [&] { return proto::Magic{}.Verify(&tracker); }
<< [&] { return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id); };
return Success() << [&] {
auto rc = Connect(info.host, info.port, retry, timeout, out);
if (rc.OK()) {
return rc;
} else {
return Fail("Failed to connect to the tracker.", std::move(rc));
}
} << [&] {
return tracker.NonBlocking(false);
} << [&] {
return tracker.RecvTimeout(timeout);
} << [&] {
return proto::Magic{}.Verify(&tracker);
} << [&] {
return proto::Connect{}.WorkerSend(&tracker, world, rank, task_id);
} << [&] {
LOG(INFO) << "Task " << task_id << " connected to the tracker";
return Success();
};
}

[[nodiscard]] Result Comm::ConnectTracker(TCPSocket* out) const {
Expand Down Expand Up @@ -257,8 +266,8 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
CHECK(this->channels_.empty());
for (auto& w : workers) {
if (w) {
w->SetNoDelay();
rc = w->NonBlocking(true);
rc = std::move(rc) << [&] { return w->SetNoDelay(); } << [&] { return w->NonBlocking(true); }
<< [&] { return w->SetKeepAlive(); };
}
if (!rc.OK()) {
return rc;
Expand Down
88 changes: 63 additions & 25 deletions src/collective/loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,26 @@
#include "xgboost/logging.h" // for CHECK

namespace xgboost::collective {
Result Loop::EmptyQueue() {
Result Loop::EmptyQueue(std::queue<Op>* p_queue) const {
timer_.Start(__func__);
auto error = [this] {
this->stop_ = true;
auto error = [this] { timer_.Stop(__func__); };

if (stop_) {
timer_.Stop(__func__);
};
return Success();
}

while (!queue_.empty() && !stop_) {
std::queue<Op> qcopy;
auto& qcopy = *p_queue;

// clear the copied queue
while (!qcopy.empty()) {
rabit::utils::PollHelper poll;
std::size_t n_ops = qcopy.size();

// watch all ops
while (!queue_.empty()) {
auto op = queue_.front();
queue_.pop();
// Iterate through all the ops for poll
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
qcopy.pop();

switch (op.code) {
case Op::kRead: {
Expand All @@ -40,6 +45,7 @@ Result Loop::EmptyQueue() {
return Fail("Invalid socket operation.");
}
}

qcopy.push(op);
}

Expand All @@ -51,10 +57,12 @@ Result Loop::EmptyQueue() {
error();
return rc;
}

// we wouldn't be here if the queue is empty.
CHECK(!qcopy.empty());

while (!qcopy.empty() && !stop_) {
// Iterate through all the ops for performing the operations
for (std::size_t i = 0; i < n_ops; ++i) {
auto op = qcopy.front();
qcopy.pop();

Expand All @@ -81,20 +89,21 @@ Result Loop::EmptyQueue() {
}

if (n_bytes_done == -1 && !system::LastErrorWouldBlock()) {
stop_ = true;
auto rc = system::FailWithCode("Invalid socket output.");
error();
return rc;
}

op.off += n_bytes_done;
CHECK_LE(op.off, op.n);

if (op.off != op.n) {
// not yet finished, push back to queue for next round.
queue_.push(op);
qcopy.push(op);
}
}
}

timer_.Stop(__func__);
return Success();
}
Expand All @@ -107,22 +116,42 @@ void Loop::Process() {
if (stop_) {
break;
}
CHECK(!mu_.try_lock());

this->rc_ = this->EmptyQueue();
if (!rc_.OK()) {
stop_ = true;
auto unlock_notify = [&](bool is_blocking) {
if (!is_blocking) {
return;
}
lock.unlock();
cv_.notify_one();
break;
}
};

CHECK(queue_.empty());
CHECK(!mu_.try_lock());
cv_.notify_one();
}
// move the queue
std::queue<Op> qcopy;
bool is_blocking = false;
while (!queue_.empty()) {
auto op = queue_.front();
queue_.pop();
if (op.code == Op::kBlock) {
is_blocking = true;
} else {
qcopy.push(op);
}
}
// unblock the queue
if (!is_blocking) {
lock.unlock();
}
// clear the queue
auto rc = this->EmptyQueue(&qcopy);
// Handle error
if (!rc.OK()) {
this->rc_ = std::move(rc);
unlock_notify(is_blocking);
return;
}

if (rc_.OK()) {
CHECK(queue_.empty());
CHECK(qcopy.empty());
unlock_notify(is_blocking);
}
}

Expand All @@ -140,6 +169,15 @@ Result Loop::Stop() {
return Success();
}

/**
 * @brief Wait until the event loop has drained its queue.
 *
 * Submits a kBlock marker op, then sleeps on the condition variable until the
 * queue is observed empty or the loop has been stopped.
 *
 * @return The result stored in rc_, moved out to the caller.
 */
[[nodiscard]] Result Loop::Block() {
  this->Submit(Op{Op::kBlock});

  std::unique_lock guard{mu_};
  while (!(stop_ || this->queue_.empty())) {
    cv_.wait(guard);
  }
  // Release the mutex before handing the stored result back.
  guard.unlock();

  return std::move(rc_);
}

Loop::Loop(std::chrono::seconds timeout) : timeout_{timeout} {
timer_.Init(__func__);
worker_ = std::thread{[this] {
Expand Down
17 changes: 5 additions & 12 deletions src/collective/loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ namespace xgboost::collective {
class Loop {
public:
struct Op {
enum Code : std::int8_t { kRead = 0, kWrite = 1 } code;
enum Code : std::int8_t { kRead = 0, kWrite = 1, kBlock = 2 } code;
std::int32_t rank{-1};
std::int8_t* ptr{nullptr};
std::size_t n{0};
TCPSocket* sock{nullptr};
std::size_t off{0};

explicit Op(Code c) : code{c} { CHECK(c == kBlock); }
Op(Code c, std::int32_t rank, std::int8_t* ptr, std::size_t n, TCPSocket* sock, std::size_t off)
: code{c}, rank{rank}, ptr{ptr}, n{n}, sock{sock}, off{off} {}
Op(Op const&) = default;
Expand All @@ -44,9 +45,9 @@ class Loop {
Result rc_;
bool stop_{false};
std::exception_ptr curr_exce_{nullptr};
common::Monitor timer_;
common::Monitor mutable timer_;

Result EmptyQueue();
Result EmptyQueue(std::queue<Op>* p_queue) const;
void Process();

public:
Expand All @@ -60,15 +61,7 @@ class Loop {
cv_.notify_one();
}

[[nodiscard]] Result Block() {
{
std::unique_lock lock{mu_};
cv_.notify_all();
}
std::unique_lock lock{mu_};
cv_.wait(lock, [this] { return this->queue_.empty() || stop_; });
return std::move(rc_);
}
[[nodiscard]] Result Block();

explicit Loop(std::chrono::seconds timeout);

Expand Down
9 changes: 6 additions & 3 deletions tests/cpp/collective/test_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,34 @@ class AllreduceWorker : public WorkerForTest {
void Basic() {
{
std::vector<double> data(13, 0.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
}
{
std::vector<double> data(1, 1.0);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
}
}

void Acc() {
std::vector<double> data(314, 1.5);
Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
ASSERT_TRUE(rc.OK());
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
Expand Down
Loading