From 633c33dc77a4258dd7b787d3ee112b3e9f6cd8d4 Mon Sep 17 00:00:00 2001 From: Erik Bernhardson Date: Tue, 5 Dec 2017 00:23:12 -0800 Subject: [PATCH] Add option to fail gracefully on Rabit::Init failure When running xgboost4j-spark, the process that rabit is initialized in has other things going on, such as caches of data that is expensive to recompute. On the rare occasion that rabit fails to connect to the tracker, rabit performs an exit(-1), which throws away everything that was going on in the application. It would be nice to somehow provide explicit failure handling everywhere, but changing the external API to that level would be quite a large breaking change. This patch takes a very targeted approach, changing only the Init call to return a boolean indicating success. This is disabled by default and must be requested as part of the initialization parameters (by setting rabit_die_on_init to a value other than 1). --- doc/guide.md | 16 +++++++++++---- guide/basic.cc | 4 +++- guide/broadcast.cc | 4 +++- guide/lazy_allreduce.cc | 4 +++- include/rabit/c_api.h | 3 ++- include/rabit/internal/engine.h | 2 +- include/rabit/internal/rabit-inl.h | 4 ++-- include/rabit/rabit.h | 3 ++- python/rabit.py | 3 ++- src/allreduce_base.cc | 33 +++++++++++++++++++++--------- src/allreduce_base.h | 9 +++++--- src/allreduce_robust.cc | 7 +++++-- src/allreduce_robust.h | 2 +- src/c_api.cc | 4 ++-- src/engine.cc | 4 ++-- src/engine_empty.cc | 3 ++- src/engine_mpi.cc | 3 ++- test/lazy_recover.cc | 4 +++- test/local_recover.cc | 4 +++- test/model_recover.cc | 4 +++- test/speed_test.cc | 4 +++- 21 files changed, 85 insertions(+), 39 deletions(-) diff --git a/doc/guide.md b/doc/guide.md index 39a69e9e..5fc1eb24 100644 --- a/doc/guide.md +++ b/doc/guide.md @@ -25,7 +25,9 @@ using namespace rabit; const int N = 3; int main(int argc, char *argv[]) { int a[N]; - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } for (int i = 0; i < N; ++i) { a[i] = rabit::GetRank() + i; } @@ -92,7 +94,9 @@ node 0 to all other nodes. 
using namespace rabit; const int N = 3; int main(int argc, char *argv[]) { - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } std::string s; if (rabit::GetRank() == 0) s = "hello world"; printf("@node[%d] before-broadcast: s=\"%s\"\n", @@ -153,7 +157,9 @@ you can also refer to [wormhole](https://github.com/dmlc/wormhole/blob/master/le #include int main(int argc, char *argv[]) { ... - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } // load the latest checked model int version = rabit::LoadCheckPoint(&model); // initialize the model if it is the first version @@ -206,7 +212,9 @@ using namespace rabit; const int N = 3; int main(int argc, char *argv[]) { int a[N] = {0}; - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } // lazy preparation function auto prepare = [&]() { printf("@node[%d] run prepare function\n", rabit::GetRank()); diff --git a/guide/basic.cc b/guide/basic.cc index d08397b5..f59cc293 100644 --- a/guide/basic.cc +++ b/guide/basic.cc @@ -16,7 +16,9 @@ int main(int argc, char *argv[]) { N = atoi(argv[1]); } std::vector a(N); - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } for (int i = 0; i < N; ++i) { a[i] = rabit::GetRank() + i; } diff --git a/guide/broadcast.cc b/guide/broadcast.cc index 9e360d8d..c296db7a 100644 --- a/guide/broadcast.cc +++ b/guide/broadcast.cc @@ -2,7 +2,9 @@ using namespace rabit; const int N = 3; int main(int argc, char *argv[]) { - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } std::string s; if (rabit::GetRank() == 0) s = "hello world"; printf("@node[%d] before-broadcast: s=\"%s\"\n", diff --git a/guide/lazy_allreduce.cc b/guide/lazy_allreduce.cc index b4b816fa..58b1fa3d 100644 --- a/guide/lazy_allreduce.cc +++ b/guide/lazy_allreduce.cc @@ -11,7 +11,9 @@ using namespace rabit; const int N = 3; int main(int argc, char *argv[]) { int a[N] = {0}; - rabit::Init(argc, argv); + if 
(!rabit::Init(argc, argv)) { + return -1; + } // lazy preparation function auto prepare = [&]() { printf("@node[%d] run prepare function\n", rabit::GetRank()); diff --git a/include/rabit/c_api.h b/include/rabit/c_api.h index a05ebd4d..0eb2e745 100644 --- a/include/rabit/c_api.h +++ b/include/rabit/c_api.h @@ -32,8 +32,9 @@ typedef unsigned long rbt_ulong; // NOLINT(*) * from environment variables. * \param argc number of arguments in argv * \param argv the array of input arguments + * \return true on successful init. May also exit(-1). */ -RABIT_DLL void RabitInit(int argc, char *argv[]); +RABIT_DLL bool RabitInit(int argc, char *argv[]); /*! * \brief finalize the rabit engine, diff --git a/include/rabit/internal/engine.h b/include/rabit/internal/engine.h index 6a7dfe4a..8c118889 100644 --- a/include/rabit/internal/engine.h +++ b/include/rabit/internal/engine.h @@ -161,7 +161,7 @@ class IEngine { }; /*! \brief initializes the engine module */ -void Init(int argc, char *argv[]); +bool Init(int argc, char *argv[]); /*! \brief finalizes the engine module */ void Finalize(void); /*! 
\brief singleton method to get engine */ diff --git a/include/rabit/internal/rabit-inl.h b/include/rabit/internal/rabit-inl.h index 7536c184..5bb1866a 100644 --- a/include/rabit/internal/rabit-inl.h +++ b/include/rabit/internal/rabit-inl.h @@ -103,8 +103,8 @@ inline void Reducer(const void *src_, void *dst_, int len, const MPI::Datatype & } // namespace op // intialize the rabit engine -inline void Init(int argc, char *argv[]) { - engine::Init(argc, argv); +inline bool Init(int argc, char *argv[]) { + return engine::Init(argc, argv); } // finalize the rabit engine inline void Finalize(void) { diff --git a/include/rabit/rabit.h b/include/rabit/rabit.h index 1eda2ea7..c553cf07 100644 --- a/include/rabit/rabit.h +++ b/include/rabit/rabit.h @@ -73,8 +73,9 @@ struct BitOR; * \brief initializes rabit, call this once at the beginning of your program * \param argc number of arguments in argv * \param argv the array of input arguments + * \return true on successfull init. May also exit(-1). */ -inline void Init(int argc, char *argv[]); +inline bool Init(int argc, char *argv[]); /*! * \brief finalizes the rabit engine, call this function after you finished with all the jobs */ diff --git a/python/rabit.py b/python/rabit.py index d57587ba..4008273d 100644 --- a/python/rabit.py +++ b/python/rabit.py @@ -103,7 +103,8 @@ def init(args=None, lib='standard', lib_dll=None): _loadlib(lib, lib_dll) arr = (ctypes.c_char_p * len(args))() arr[:] = args - _LIB.RabitInit(len(args), arr) + if not _LIB.RabitInit(len(args), arr): + raise Error("Failed to initialize rabit") def finalize(): """Finalize the rabit engine. 
diff --git a/src/allreduce_base.cc b/src/allreduce_base.cc index 862187bc..d9e5963e 100644 --- a/src/allreduce_base.cc +++ b/src/allreduce_base.cc @@ -27,6 +27,7 @@ AllreduceBase::AllreduceBase(void) { connect_retry = 5; hadoop_mode = 0; version_number = 0; + die_on_init = true; // 32 K items reduce_ring_mincount = 32 << 10; // tracker URL @@ -51,7 +52,7 @@ AllreduceBase::AllreduceBase(void) { } // initialization function -void AllreduceBase::Init(int argc, char* argv[]) { +bool AllreduceBase::Init(int argc, char* argv[]) { // setup from enviroment variables // handler to get variables from env for (size_t i = 0; i < env_vars.size(); ++i) { @@ -116,7 +117,7 @@ void AllreduceBase::Init(int argc, char* argv[]) { utils::Assert(all_links.size() == 0, "can only call Init once"); this->host_uri = utils::SockAddr::GetHostName(); // get information from tracker - this->ReConnectLinks(); + return this->ReConnectLinks("start", die_on_init); } void AllreduceBase::Shutdown(void) { @@ -128,7 +129,8 @@ void AllreduceBase::Shutdown(void) { if (tracker_uri == "NULL") return; // notify tracker rank i have shutdown - utils::TCPSocket tracker = this->ConnectTracker(); + std::pair pair = this->ConnectTracker(); + utils::TCPSocket tracker = std::get<0>(pair); tracker.SendStr(std::string("shutdown")); tracker.Close(); utils::TCPSocket::Finalize(); @@ -137,7 +139,8 @@ void AllreduceBase::TrackerPrint(const std::string &msg) { if (tracker_uri == "NULL") { utils::Printf("%s", msg.c_str()); return; } - utils::TCPSocket tracker = this->ConnectTracker(); + std::pair pair = this->ConnectTracker(); + utils::TCPSocket tracker = std::get<0>(pair); tracker.SendStr(std::string("print")); tracker.SendStr(msg); tracker.Close(); @@ -188,12 +191,13 @@ void AllreduceBase::SetParam(const char *name, const char *val) { if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) { connect_retry = atoi(val); } + if (!strcmp(name, "rabit_die_on_init")) die_on_init = !strcmp(val, "1"); } /*! 
* \brief initialize connection to the tracker * \return a socket that initializes the connection */ -utils::TCPSocket AllreduceBase::ConnectTracker(void) const { +std::pair AllreduceBase::ConnectTracker(const bool dieOnError) const { int magic = kMagic; // get information from tracker utils::TCPSocket tracker; @@ -204,7 +208,11 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { if (!tracker.Connect(utils::SockAddr(tracker_uri.c_str(), tracker_port))) { if (++retry >= connect_retry) { fprintf(stderr, "connect to (failed): [%s]\n", tracker_uri.c_str()); - utils::Socket::Error("Connect"); + if (dieOnError) { + utils::Socket::Error("Connect"); + } else { + return std::make_pair(tracker, false); + } } else { fprintf(stderr, "retry connect to ip(retry time %d): [%s]\n", retry, tracker_uri.c_str()); #ifdef _MSC_VER @@ -229,18 +237,22 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const { Assert(tracker.SendAll(&world_size, sizeof(world_size)) == sizeof(world_size), "ReConnectLink failure 3"); tracker.SendStr(task_id); - return tracker; + return std::make_pair(tracker, true); } /*! * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up */ -void AllreduceBase::ReConnectLinks(const char *cmd) { +bool AllreduceBase::ReConnectLinks(const char *cmd, bool dieOnError) { // single node mode if (tracker_uri == "NULL") { - rank = 0; world_size = 1; return; + rank = 0; world_size = 1; return true; + } + std::pair pair = this->ConnectTracker(dieOnError); + if (!std::get<1>(pair)) { + return false; } - utils::TCPSocket tracker = this->ConnectTracker(); + utils::TCPSocket tracker = std::get<0>(pair); tracker.SendStr(std::string(cmd)); // the rank of previous link, next link in ring @@ -382,6 +394,7 @@ void AllreduceBase::ReConnectLinks(const char *cmd) { "cannot find prev ring in the link"); Assert(next_rank == -1 || ring_next != NULL, "cannot find next ring in the link"); + return true; } /*! 
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure diff --git a/src/allreduce_base.h b/src/allreduce_base.h index ef5567af..a8b701b7 100644 --- a/src/allreduce_base.h +++ b/src/allreduce_base.h @@ -13,6 +13,7 @@ #define RABIT_ALLREDUCE_BASE_H_ #include +#include #include #include #include "../include/rabit/internal/utils.h" @@ -38,7 +39,7 @@ class AllreduceBase : public IEngine { AllreduceBase(void); virtual ~AllreduceBase(void) {} // initialize the manager - virtual void Init(int argc, char* argv[]); + virtual bool Init(int argc, char* argv[]); // shutdown the engine virtual void Shutdown(void); /*! @@ -363,13 +364,13 @@ class AllreduceBase : public IEngine { * \brief initialize connection to the tracker * \return a socket that initializes the connection */ - utils::TCPSocket ConnectTracker(void) const; + std::pair ConnectTracker(const bool dieOnError = true) const; /*! * \brief connect to the tracker to fix the the missing links * this function is also used when the engine start up * \param cmd possible command to sent to tracker */ - void ReConnectLinks(const char *cmd = "start"); + bool ReConnectLinks(const char *cmd = "start", bool dieOnError = true); /*! 
* \brief perform in-place allreduce, on sendrecvbuf, this function can fail, and will return the cause of failure * @@ -521,6 +522,8 @@ class AllreduceBase : public IEngine { int world_size; // connect retry time int connect_retry; + // die with exit(-1) if we fail init + bool die_on_init; }; } // namespace engine } // namespace rabit diff --git a/src/allreduce_robust.cc b/src/allreduce_robust.cc index a48a349a..919be842 100644 --- a/src/allreduce_robust.cc +++ b/src/allreduce_robust.cc @@ -31,13 +31,16 @@ AllreduceRobust::AllreduceRobust(void) { env_vars.push_back("rabit_global_replica"); env_vars.push_back("rabit_local_replica"); } -void AllreduceRobust::Init(int argc, char* argv[]) { - AllreduceBase::Init(argc, argv); +bool AllreduceRobust::Init(int argc, char* argv[]) { + if (!AllreduceBase::Init(argc, argv)) { + return false; + } if (num_global_replica == 0) { result_buffer_round = -1; } else { result_buffer_round = std::max(world_size / num_global_replica, 1); } + return true; } /*! \brief shutdown the engine */ void AllreduceRobust::Shutdown(void) { diff --git a/src/allreduce_robust.h b/src/allreduce_robust.h index c8860822..c1f39615 100644 --- a/src/allreduce_robust.h +++ b/src/allreduce_robust.h @@ -24,7 +24,7 @@ class AllreduceRobust : public AllreduceBase { AllreduceRobust(void); virtual ~AllreduceRobust(void) {} // initialize the manager - virtual void Init(int argc, char* argv[]); + virtual bool Init(int argc, char* argv[]); /*! \brief shutdown the engine */ virtual void Shutdown(void); /*! 
diff --git a/src/c_api.cc b/src/c_api.cc index 8c789c08..e6f54175 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -162,8 +162,8 @@ struct WriteWrapper : public Serializable { } // namespace c_api } // namespace rabit -void RabitInit(int argc, char *argv[]) { - rabit::Init(argc, argv); +bool RabitInit(int argc, char *argv[]) { + return rabit::Init(argc, argv); } void RabitFinalize() { diff --git a/src/engine.cc b/src/engine.cc index c958932b..644adc7f 100644 --- a/src/engine.cc +++ b/src/engine.cc @@ -43,13 +43,13 @@ struct ThreadLocalEntry { typedef ThreadLocalStore EngineThreadLocal; /*! \brief intiialize the synchronization module */ -void Init(int argc, char *argv[]) { +bool Init(int argc, char *argv[]) { ThreadLocalEntry* e = EngineThreadLocal::Get(); utils::Check(e->engine.get() == nullptr, "rabit::Init is already called in this thread"); e->initialized = true; e->engine.reset(new Manager()); - e->engine->Init(argc, argv); + return e->engine->Init(argc, argv); } /*! \brief finalize syncrhonization module */ diff --git a/src/engine_empty.cc b/src/engine_empty.cc index 8177410a..78c65ece 100644 --- a/src/engine_empty.cc +++ b/src/engine_empty.cc @@ -77,7 +77,8 @@ class EmptyEngine : public IEngine { EmptyEngine manager; /*! \brief intiialize the synchronization module */ -void Init(int argc, char *argv[]) { +bool Init(int argc, char *argv[]) { + return true; } /*! \brief finalize syncrhonization module */ void Finalize(void) { diff --git a/src/engine_mpi.cc b/src/engine_mpi.cc index 35283ad5..7d97b0b8 100644 --- a/src/engine_mpi.cc +++ b/src/engine_mpi.cc @@ -86,8 +86,9 @@ class MPIEngine : public IEngine { MPIEngine manager; /*! \brief intiialize the synchronization module */ -void Init(int argc, char *argv[]) { +bool Init(int argc, char *argv[]) { MPI::Init(argc, argv); + return true; } /*! 
\brief finalize syncrhonization module */ void Finalize(void) { diff --git a/test/lazy_recover.cc b/test/lazy_recover.cc index dd64294b..f2c95604 100644 --- a/test/lazy_recover.cc +++ b/test/lazy_recover.cc @@ -89,7 +89,9 @@ int main(int argc, char *argv[]) { return 0; } int n = atoi(argv[1]); - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); std::string name = rabit::GetProcessorName(); diff --git a/test/local_recover.cc b/test/local_recover.cc index a63bd2f8..7459de17 100644 --- a/test/local_recover.cc +++ b/test/local_recover.cc @@ -100,7 +100,9 @@ int main(int argc, char *argv[]) { return 0; } int n = atoi(argv[1]); - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); std::string name = rabit::GetProcessorName(); diff --git a/test/model_recover.cc b/test/model_recover.cc index a2709f89..7f5c05c2 100644 --- a/test/model_recover.cc +++ b/test/model_recover.cc @@ -90,7 +90,9 @@ int main(int argc, char *argv[]) { return 0; } int n = atoi(argv[1]); - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); std::string name = rabit::GetProcessorName(); diff --git a/test/speed_test.cc b/test/speed_test.cc index 8eb543de..c4f3b973 100644 --- a/test/speed_test.cc +++ b/test/speed_test.cc @@ -77,7 +77,9 @@ int main(int argc, char *argv[]) { int n = atoi(argv[1]); int nrep = atoi(argv[2]); utils::Check(nrep >= 1, "need to at least repeat running once"); - rabit::Init(argc, argv); + if (!rabit::Init(argc, argv)) { + return -1; + } //int rank = rabit::GetRank(); int nproc = rabit::GetWorldSize(); std::string name = rabit::GetProcessorName();