Skip to content

Commit

Permalink
support larger cluster (#73)
Browse files Browse the repository at this point in the history
* fix error in #57, clean up comments and naming

* include missing packages, disable recovery tests for now

* disable local_recover tests until we have a bug fix

* support larger cluster

* fix lint, merge with master
  • Loading branch information
chenqin authored and CodingCat committed Oct 22, 2018
1 parent 69cdfae commit 3a35dab
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 123 deletions.
64 changes: 32 additions & 32 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
} else {
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
#ifdef _MSC_VER
Sleep(1);
Sleep(retry << 1);
#else
sleep(1);
sleep(retry << 1);
#endif
continue;
}
Expand Down Expand Up @@ -454,47 +454,47 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
if (size_up_out != total_size && size_up_out < size_up_reduce) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
// size_write <= size_read
if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
}
}
// finish running allreduce
if (finished) break;
// select must return
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
Expand Down Expand Up @@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock) &&
if (watcher.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
Expand Down Expand Up @@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
while (true) {
bool finished = true;
// select helper
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (in_link == -2) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// select
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
Expand All @@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
}
} else {
// read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[in_link], ret);
Expand Down Expand Up @@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
Expand Down Expand Up @@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);
Expand Down
20 changes: 10 additions & 10 deletions src/allreduce_robust-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,30 +69,30 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
if (parent_index == -1) {
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
}
// select helper
utils::SelectHelper selecter;
// poll helper
utils::PollHelper watcher;
bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) {
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
switch (stage) {
case 0:
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 1:
if (i == parent_index) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
break;
case 2:
if (i == parent_index) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
done = false;
}
break;
Expand All @@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
}
// finish all the stages, and write out message
if (done) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
Expand All @@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}
Expand Down
44 changes: 22 additions & 22 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
if (len == sizeof(sig)) all_links[i].size_write = 2;
}
}
utils::SelectHelper rsel;
utils::PollHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
Expand All @@ -343,23 +343,23 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
}
if (finished) break;
// wait to read from the channels to discard data
rsel.Select();
rsel.Poll();
}
for (int i = 0; i < nlink; ++i) {
if (!all_links[i].sock.BadSocket()) {
utils::SelectHelper::WaitExcept(all_links[i].sock);
utils::PollHelper::WaitExcept(all_links[i].sock);
}
}
while (true) {
utils::SelectHelper rsel;
utils::PollHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
rsel.WatchRead(all_links[i].sock); finished = false;
}
}
if (finished) break;
rsel.Select();
rsel.Poll();
for (int i = 0; i < nlink; ++i) {
if (all_links[i].sock.BadSocket()) continue;
if (all_links[i].size_read == 0) {
Expand Down Expand Up @@ -624,32 +624,32 @@ AllreduceRobust::TryRecoverData(RecoverType role,
}
while (true) {
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == recv_link && links[i].size_read != size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
finished = false;
}
if (req_in[i] && links[i].size_write != size) {
if (role == kHaveData ||
(links[recv_link].size_read != links[i].size_write)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
if (finished) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (role == kRequestData) {
const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) {
if (watcher.CheckRead(links[pid].sock)) {
ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
if (ret != kSuccess) {
return ReportError(&links[pid], ret);
Expand Down Expand Up @@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (role == kPassData) {
const int pid = recv_link;
const size_t buffer_size = links[pid].buffer_size;
if (selecter.CheckRead(links[pid].sock)) {
if (watcher.CheckRead(links[pid].sock)) {
size_t min_write = size;
for (int i = 0; i < nlink; ++i) {
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
Expand Down Expand Up @@ -1144,22 +1144,22 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
while (true) {
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != read_end) {
selecter.WatchRead(prev.sock);
watcher.WatchRead(prev.sock);
finished = false;
}
if (write_ptr < read_ptr && write_ptr != write_end) {
selecter.WatchWrite(next.sock);
watcher.WatchWrite(next.sock);
finished = false;
}
selecter.WatchException(prev.sock);
selecter.WatchException(next.sock);
watcher.WatchException(prev.sock);
watcher.WatchException(next.sock);
if (finished) break;
selecter.Select();
if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
watcher.Poll();
if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && watcher.CheckRead(prev.sock)) {
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
if (len == 0) {
prev.sock.Close(); return ReportError(&prev, kRecvZeroLen);
Expand Down
Loading

0 comments on commit 3a35dab

Please sign in to comment.