Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to fail gracefull on Rabit::Init failure #51

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions doc/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {
int a[N];
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
for (int i = 0; i < N; ++i) {
a[i] = rabit::GetRank() + i;
}
Expand Down Expand Up @@ -92,7 +94,9 @@ node 0 to all other nodes.
using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
std::string s;
if (rabit::GetRank() == 0) s = "hello world";
printf("@node[%d] before-broadcast: s=\"%s\"\n",
Expand Down Expand Up @@ -153,7 +157,9 @@ you can also refer to [wormhole](https://github.com/dmlc/wormhole/blob/master/le
#include <rabit.h>
int main(int argc, char *argv[]) {
...
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
// load the latest checked model
int version = rabit::LoadCheckPoint(&model);
// initialize the model if it is the first version
Expand Down Expand Up @@ -206,7 +212,9 @@ using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {
int a[N] = {0};
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
// lazy preparation function
auto prepare = [&]() {
printf("@node[%d] run prepare function\n", rabit::GetRank());
Expand Down
4 changes: 3 additions & 1 deletion guide/basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ int main(int argc, char *argv[]) {
N = atoi(argv[1]);
}
std::vector<int> a(N);
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
for (int i = 0; i < N; ++i) {
a[i] = rabit::GetRank() + i;
}
Expand Down
4 changes: 3 additions & 1 deletion guide/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
std::string s;
if (rabit::GetRank() == 0) s = "hello world";
printf("@node[%d] before-broadcast: s=\"%s\"\n",
Expand Down
4 changes: 3 additions & 1 deletion guide/lazy_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ using namespace rabit;
const int N = 3;
int main(int argc, char *argv[]) {
int a[N] = {0};
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
// lazy preparation function
auto prepare = [&]() {
printf("@node[%d] run prepare function\n", rabit::GetRank());
Expand Down
3 changes: 2 additions & 1 deletion include/rabit/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ typedef unsigned long rbt_ulong; // NOLINT(*)
* from environment variables.
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true on successfull init. May also exit(-1).
*/
RABIT_DLL void RabitInit(int argc, char *argv[]);
RABIT_DLL bool RabitInit(int argc, char *argv[]);

/*!
* \brief finalize the rabit engine,
Expand Down
2 changes: 1 addition & 1 deletion include/rabit/internal/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class IEngine {
};

/*! \brief initializes the engine module */
void Init(int argc, char *argv[]);
bool Init(int argc, char *argv[]);
/*! \brief finalizes the engine module */
void Finalize(void);
/*! \brief singleton method to get engine */
Expand Down
4 changes: 2 additions & 2 deletions include/rabit/internal/rabit-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype &
} // namespace op

// intialize the rabit engine
inline void Init(int argc, char *argv[]) {
engine::Init(argc, argv);
inline bool Init(int argc, char *argv[]) {
return engine::Init(argc, argv);
}
// finalize the rabit engine
inline void Finalize(void) {
Expand Down
3 changes: 2 additions & 1 deletion include/rabit/rabit.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ struct BitOR;
* \brief initializes rabit, call this once at the beginning of your program
* \param argc number of arguments in argv
* \param argv the array of input arguments
* \return true on successfull init. May also exit(-1).
*/
inline void Init(int argc, char *argv[]);
inline bool Init(int argc, char *argv[]);
/*!
* \brief finalizes the rabit engine, call this function after you finished with all the jobs
*/
Expand Down
3 changes: 2 additions & 1 deletion python/rabit.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def init(args=None, lib='standard', lib_dll=None):
_loadlib(lib, lib_dll)
arr = (ctypes.c_char_p * len(args))()
arr[:] = args
_LIB.RabitInit(len(args), arr)
if not _LIB.RabitInit(len(args), arr):
raise Error("Failed to initialize rabit")

def finalize():
"""Finalize the rabit engine.
Expand Down
33 changes: 23 additions & 10 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ AllreduceBase::AllreduceBase(void) {
connect_retry = 5;
hadoop_mode = 0;
version_number = 0;
die_on_init = true;
// 32 K items
reduce_ring_mincount = 32 << 10;
// tracker URL
Expand All @@ -51,7 +52,7 @@ AllreduceBase::AllreduceBase(void) {
}

// initialization function
void AllreduceBase::Init(int argc, char* argv[]) {
bool AllreduceBase::Init(int argc, char* argv[]) {
// setup from enviroment variables
// handler to get variables from env
for (size_t i = 0; i < env_vars.size(); ++i) {
Expand Down Expand Up @@ -116,7 +117,7 @@ void AllreduceBase::Init(int argc, char* argv[]) {
utils::Assert(all_links.size() == 0, "can only call Init once");
this->host_uri = utils::SockAddr::GetHostName();
// get information from tracker
this->ReConnectLinks();
return this->ReConnectLinks("start", die_on_init);
}

void AllreduceBase::Shutdown(void) {
Expand All @@ -128,7 +129,8 @@ void AllreduceBase::Shutdown(void) {

if (tracker_uri == "NULL") return;
// notify tracker rank i have shutdown
utils::TCPSocket tracker = this->ConnectTracker();
std::pair<utils::TCPSocket, bool> pair = this->ConnectTracker();
utils::TCPSocket tracker = std::get<0>(pair);
tracker.SendStr(std::string("shutdown"));
tracker.Close();
utils::TCPSocket::Finalize();
Expand All @@ -137,7 +139,8 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
if (tracker_uri == "NULL") {
utils::Printf("%s", msg.c_str()); return;
}
utils::TCPSocket tracker = this->ConnectTracker();
std::pair<utils::TCPSocket, bool> pair = this->ConnectTracker();
utils::TCPSocket tracker = std::get<0>(pair);
tracker.SendStr(std::string("print"));
tracker.SendStr(msg);
tracker.Close();
Expand Down Expand Up @@ -188,12 +191,13 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
connect_retry = atoi(val);
}
if (!strcmp(name, "rabit_die_on_init")) die_on_init = !strcmp(val, "1");
}
/*!
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
std::pair<utils::TCPSocket, bool> AllreduceBase::ConnectTracker(const bool dieOnError) const {
int magic = kMagic;
// get information from tracker
utils::TCPSocket tracker;
Expand All @@ -204,7 +208,11 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) {
if (++retry >= connect_retry) {
fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str());
utils::Socket::Error("Connect");
if (dieOnError) {
utils::Socket::Error("Connect");
} else {
return std::make_pair(tracker, false);
}
} else {
fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str());
#ifdef _MSC_VER
Expand All @@ -229,18 +237,22 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size),
"ReConnectLink failure 3");
tracker.SendStr(task_id);
return tracker;
return std::make_pair(tracker, true);
}
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
*/
void AllreduceBase::ReConnectLinks(const char *cmd) {
bool AllreduceBase::ReConnectLinks(const char *cmd, bool dieOnError) {
// single node mode
if (tracker_uri == "NULL") {
rank = 0; world_size = 1; return;
rank = 0; world_size = 1; return true;
}
std::pair<utils::TCPSocket, bool> pair = this->ConnectTracker(dieOnError);
if (!std::get<1>(pair)) {
return false;
}
utils::TCPSocket tracker = this->ConnectTracker();
utils::TCPSocket tracker = std::get<0>(pair);
tracker.SendStr(std::string(cmd));

// the rank of previous link, next link in ring
Expand Down Expand Up @@ -382,6 +394,7 @@ void AllreduceBase::ReConnectLinks(const char *cmd) {
"cannot find prev ring in the link");
Assert(next_rank == -1 || ring_next != NULL,
"cannot find next ring in the link");
return true;
}
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
Expand Down
9 changes: 6 additions & 3 deletions src/allreduce_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#define RABIT_ALLREDUCE_BASE_H_

#include <vector>
#include <utility>
#include <string>
#include <algorithm>
#include "../include/rabit/internal/utils.h"
Expand All @@ -38,7 +39,7 @@ class AllreduceBase : public IEngine {
AllreduceBase(void);
virtual ~AllreduceBase(void) {}
// initialize the manager
virtual void Init(int argc, char* argv[]);
virtual bool Init(int argc, char* argv[]);
// shutdown the engine
virtual void Shutdown(void);
/*!
Expand Down Expand Up @@ -363,13 +364,13 @@ class AllreduceBase : public IEngine {
* \brief initialize connection to the tracker
* \return a socket that initializes the connection
*/
utils::TCPSocket ConnectTracker(void) const;
std::pair<utils::TCPSocket, bool> ConnectTracker(const bool dieOnError = true) const;
/*!
* \brief connect to the tracker to fix the the missing links
* this function is also used when the engine start up
* \param cmd possible command to sent to tracker
*/
void ReConnectLinks(const char *cmd = "start");
bool ReConnectLinks(const char *cmd = "start", bool dieOnError = true);
/*!
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure
*
Expand Down Expand Up @@ -521,6 +522,8 @@ class AllreduceBase : public IEngine {
int world_size;
// connect retry time
int connect_retry;
// die with exit(-1) if we fail init
bool die_on_init;
};
} // namespace engine
} // namespace rabit
Expand Down
7 changes: 5 additions & 2 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ AllreduceRobust::AllreduceRobust(void) {
env_vars.push_back("rabit_global_replica");
env_vars.push_back("rabit_local_replica");
}
void AllreduceRobust::Init(int argc, char* argv[]) {
AllreduceBase::Init(argc, argv);
bool AllreduceRobust::Init(int argc, char* argv[]) {
if (!AllreduceBase::Init(argc, argv)) {
return false;
}
if (num_global_replica == 0) {
result_buffer_round = -1;
} else {
result_buffer_round = std::max(world_size / num_global_replica, 1);
}
return true;
}
/*! \brief shutdown the engine */
void AllreduceRobust::Shutdown(void) {
Expand Down
2 changes: 1 addition & 1 deletion src/allreduce_robust.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class AllreduceRobust : public AllreduceBase {
AllreduceRobust(void);
virtual ~AllreduceRobust(void) {}
// initialize the manager
virtual void Init(int argc, char* argv[]);
virtual bool Init(int argc, char* argv[]);
/*! \brief shutdown the engine */
virtual void Shutdown(void);
/*!
Expand Down
4 changes: 2 additions & 2 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ struct WriteWrapper : public Serializable {
} // namespace c_api
} // namespace rabit

void RabitInit(int argc, char *argv[]) {
rabit::Init(argc, argv);
bool RabitInit(int argc, char *argv[]) {
return rabit::Init(argc, argv);
}

void RabitFinalize() {
Expand Down
4 changes: 2 additions & 2 deletions src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ struct ThreadLocalEntry {
typedef ThreadLocalStore<ThreadLocalEntry> EngineThreadLocal;

/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
bool Init(int argc, char *argv[]) {
ThreadLocalEntry* e = EngineThreadLocal::Get();
utils::Check(e->engine.get() == nullptr,
"rabit::Init is already called in this thread");
e->initialized = true;
e->engine.reset(new Manager());
e->engine->Init(argc, argv);
return e->engine->Init(argc, argv);
}

/*! \brief finalize syncrhonization module */
Expand Down
3 changes: 2 additions & 1 deletion src/engine_empty.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class EmptyEngine : public IEngine {
EmptyEngine manager;

/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
bool Init(int argc, char *argv[]) {
return true;
}
/*! \brief finalize syncrhonization module */
void Finalize(void) {
Expand Down
3 changes: 2 additions & 1 deletion src/engine_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ class MPIEngine : public IEngine {
MPIEngine manager;

/*! \brief intiialize the synchronization module */
void Init(int argc, char *argv[]) {
bool Init(int argc, char *argv[]) {
MPI::Init(argc, argv);
return true;
}
/*! \brief finalize syncrhonization module */
void Finalize(void) {
Expand Down
4 changes: 3 additions & 1 deletion test/lazy_recover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ int main(int argc, char *argv[]) {
return 0;
}
int n = atoi(argv[1]);
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
int rank = rabit::GetRank();
int nproc = rabit::GetWorldSize();
std::string name = rabit::GetProcessorName();
Expand Down
4 changes: 3 additions & 1 deletion test/local_recover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ int main(int argc, char *argv[]) {
return 0;
}
int n = atoi(argv[1]);
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
int rank = rabit::GetRank();
int nproc = rabit::GetWorldSize();
std::string name = rabit::GetProcessorName();
Expand Down
4 changes: 3 additions & 1 deletion test/model_recover.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ int main(int argc, char *argv[]) {
return 0;
}
int n = atoi(argv[1]);
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
int rank = rabit::GetRank();
int nproc = rabit::GetWorldSize();
std::string name = rabit::GetProcessorName();
Expand Down
4 changes: 3 additions & 1 deletion test/speed_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ int main(int argc, char *argv[]) {
int n = atoi(argv[1]);
int nrep = atoi(argv[2]);
utils::Check(nrep >= 1, "need to at least repeat running once");
rabit::Init(argc, argv);
if (!rabit::Init(argc, argv)) {
return -1;
}
//int rank = rabit::GetRank();
int nproc = rabit::GetWorldSize();
std::string name = rabit::GetProcessorName();
Expand Down