Skip to content

Commit

Permalink
tracker support binary config format
Browse files Browse the repository at this point in the history
  • Loading branch information
chenqin committed Jun 19, 2019
1 parent a9d7331 commit 2a28e5e
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 33 deletions.
4 changes: 2 additions & 2 deletions include/rabit/internal/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ class IEngine {
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg) = 0;
virtual void TrackerSetConfig(const std::string &key, const std::string &value) = 0;
virtual void TrackerGetConfig(const std::string& key, std::string* value) = 0;
virtual void TrackerSetConfig(const std::string &key, const int size, const void* value) = 0;
virtual void TrackerGetConfig(const std::string& key, const int size, void* value) = 0;
};

/*! \brief initializes the engine module */
Expand Down
27 changes: 11 additions & 16 deletions include/rabit/internal/rabit-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ inline void Broadcast(std::vector<DType> *sendrecv_data, int root, const char* c
Broadcast(&(*sendrecv_data)[0], size * sizeof(DType), root, caller);
}
}
inline void Broadcast(std::string *sendrecv_data, int root, const char* caller) {
inline void Broadcast(std::string *sendrecv_data, int root,
const char* caller) {
size_t size = sendrecv_data->length();
Broadcast(&size, sizeof(size), root, caller);
if (sendrecv_data->length() != size) {
Expand Down Expand Up @@ -182,12 +183,12 @@ inline void TrackerPrint(const std::string &msg) {
engine::GetEngine()->TrackerPrint(msg);
}

/*!
 * \brief forward a config value for the given key to the current engine
 * \param key configuration key
 * \param bsize byte count passed through to the engine together with value
 * \param value pointer passed through to the engine unchanged
 */
// signature and body taking the value as std::string (shown first in this diff)
inline void TrackerSetConfig(const std::string &key, const std::string &value) {
engine::GetEngine()->TrackerSetConfig(key, value);
// signature and body taking the value as a (bsize, pointer) pair
inline void TrackerSetConfig(const std::string &key, const int bsize, const void *value) {
engine::GetEngine()->TrackerSetConfig(key, bsize, value);
}

/*!
 * \brief ask the current engine for the config value stored under key
 * \param key configuration key
 * \param bsize byte count passed through to the engine together with value
 * \param value output pointer passed through to the engine unchanged
 */
// signature and body returning the value through a std::string (shown first in this diff)
inline void TrackerGetConfig(const std::string &key, std::string* value) {
engine::GetEngine()->TrackerGetConfig(key, value);
// signature and body returning the value through a (bsize, pointer) pair
inline void TrackerGetConfig(const std::string &key, const int bsize, void *value) {
engine::GetEngine()->TrackerGetConfig(key, bsize, value);
}

#ifndef RABIT_STRICT_CXX98_
Expand All @@ -202,36 +203,30 @@ inline void TrackerPrintf(const char *fmt, ...) {
TrackerPrint(msg);
}

inline void TrackerSetConfig(const char *key, const char *value, ...) {
inline void TrackerSetConfig(const char *key, const int bsize, const void *value, ...) {
const int kPrintBuffer = 1 << 10;
std::string k(kPrintBuffer, '\0'), v(kPrintBuffer, '\0');

va_list args1, args2;
va_start(args1, key);
va_start(args2, value);
vsnprintf(&k[0], kPrintBuffer, key, args1);
vsnprintf(&v[0], kPrintBuffer, value, args2);
va_end(args1);
va_end(args2);
k.resize(strlen(k.c_str()));
v.resize(strlen(v.c_str()));
engine::GetEngine()->TrackerSetConfig(k, v);
engine::GetEngine()->TrackerSetConfig(k, bsize, value);
}

inline void TrackerGetConfig(const char *key, char* value, ...) {
inline void TrackerGetConfig(const char *key, const int bsize, void* value, ...) {
const int kPrintBuffer = 1 << 10;
std::string k(kPrintBuffer, '\0'), v(kPrintBuffer, '\0');

va_list args1, args2;
va_list args1;
va_start(args1, key);
va_start(args2, value);
vsnprintf(&k[0], kPrintBuffer, key, args1);
vsnprintf(&v[0], kPrintBuffer, value, args2);
va_end(args1);
va_end(args2);
k.resize(strlen(k.c_str()));
v.resize(strlen(v.c_str()));
engine::GetEngine()->TrackerGetConfig(k, &v);
engine::GetEngine()->TrackerGetConfig(k, bsize, value);
}
#endif // RABIT_STRICT_CXX98_
// load latest check point
Expand Down
23 changes: 14 additions & 9 deletions include/rabit/rabit.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ inline void TrackerPrint(const std::string &msg);
* \param key configuration key
* \param bsize number of bytes pointed to by value
* \param value pointer to the raw bytes of the config value
*/
inline void TrackerSetConfig(const std::string &key, const std::string &value);
inline void TrackerSetConfig(const std::string &key, const int bsize, const void* value);
/*!
 * \brief get config from tracker
 * \param key configuration key
 * \param bsize number of bytes to read into value
 * \param value caller-provided buffer that receives the config bytes
 */
inline void TrackerGetConfig(const std::string &key, std::string* value);
inline void TrackerGetConfig(const std::string &key, const int bsize, void* value);

#ifndef RABIT_STRICT_CXX98_
/*!
Expand All @@ -127,13 +127,13 @@ inline void TrackerPrintf(const char *fmt, ...);
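* \param key printf-style format string for the configuration key, expanded with the trailing variadic arguments
* \param bsize number of bytes pointed to by value
* \param value pointer to the raw bytes of the config value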
*/
inline void TrackerSetConfig(const char *key, const char *value, ...);
inline void TrackerSetConfig(const char *key, const int bsize, const void* value, ...);
/*!
 * \brief get config from tracker
 * \param key printf-style format string for the configuration key,
 *        expanded with the trailing variadic arguments
 * \param bsize number of bytes to read into value
 * \param value caller-provided buffer that receives the config bytes
 */
inline void TrackerGetConfig(const char *key, char* value, ...);
inline void TrackerGetConfig(const char *key, const int bsize, void* value, ...);
#endif // RABIT_STRICT_CXX98_
/*!
* \brief broadcasts a memory region to every node from the root
Expand All @@ -143,7 +143,8 @@ inline void TrackerGetConfig(const char *key, char* value, ...);
* \param size the data size
* \param root the process root
*/
inline void Broadcast(void *sendrecv_data, size_t size, int root, const char* caller = __builtin_FUNCTION());
inline void Broadcast(void *sendrecv_data, size_t size, int root,
const char* caller = __builtin_FUNCTION());
/*!
* \brief broadcasts an std::vector<DType> to every node from root
* \param sendrecv_data the pointer to send/receive vector,
Expand All @@ -153,14 +154,16 @@ inline void Broadcast(void *sendrecv_data, size_t size, int root, const char* ca
* that can be directly transmitted by sending the sizeof(DType)
*/
template<typename DType>
inline void Broadcast(std::vector<DType> *sendrecv_data, int root, const char* caller = __builtin_FUNCTION());
inline void Broadcast(std::vector<DType> *sendrecv_data, int root,
const char* caller = __builtin_FUNCTION());
/*!
* \brief broadcasts a std::string to every node from the root
* \param sendrecv_data the pointer to the send/receive buffer,
* for the receiver, the vector does not need to be pre-allocated
* \param root the process root
*/
inline void Broadcast(std::string *sendrecv_data, int root, const char* caller = __builtin_FUNCTION());
inline void Broadcast(std::string *sendrecv_data, int root,
const char* caller = __builtin_FUNCTION());
/*!
* \brief performs in-place Allreduce on sendrecvbuf
* this function is NOT thread-safe
Expand All @@ -185,7 +188,8 @@ inline void Broadcast(std::string *sendrecv_data, int root, const char* caller =
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
void (*prepare_fun)(void *) = NULL,
void *prepare_arg = NULL, const char* caller = __builtin_FUNCTION());
void *prepare_arg = NULL,
const char* caller = __builtin_FUNCTION());
// C++11 support for lambda prepare function
#if DMLC_USE_CXX11
/*!
Expand Down Expand Up @@ -214,7 +218,8 @@ inline void Allreduce(DType *sendrecvbuf, size_t count,
*/
template<typename OP, typename DType>
inline void Allreduce(DType *sendrecvbuf, size_t count,
std::function<void()> prepare_fun, const char* caller = __builtin_FUNCTION());
std::function<void()> prepare_fun,
const char* caller = __builtin_FUNCTION());
#endif // C++11
/*!
* \brief loads the latest check point
Expand Down
9 changes: 5 additions & 4 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,19 +146,20 @@ void AllreduceBase::TrackerPrint(const std::string &msg) {
tracker.Close();
}

void AllreduceBase::TrackerSetConfig(const std::string &key, const std::string &value) {
void AllreduceBase::TrackerSetConfig(const std::string &key, const int bytesize, const void* value) {
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("set"));
tracker.SendStr(key);
tracker.SendStr(value);
tracker.Send(&bytesize, sizeof(int));
tracker.SendAll(value, bytesize);
tracker.Close();
}

/*!
 * \brief send a "get" request to the tracker and read the reply into value
 *  Opens a tracker connection, sends the command string "get" and the key,
 *  reads the reply, and closes the connection.
 * \param key configuration key
 * \param bytesize number of bytes read into value
 * \param value caller-provided buffer that receives the reply bytes
 */
// signature that returns the value through a std::string (shown first in this diff)
void AllreduceBase::TrackerGetConfig(const std::string &key, std::string* value) {
// signature that returns the value through a (bytesize, pointer) pair
void AllreduceBase::TrackerGetConfig(const std::string &key, const int bytesize, void* value) {
utils::TCPSocket tracker = this->ConnectTracker();
tracker.SendStr(std::string("get"));
tracker.SendStr(key);
// reply read that belongs to the std::string signature
tracker.RecvStr(value);
// reply read that belongs to the (bytesize, pointer) signature
// NOTE(review): no length is exchanged on this path, unlike the "set" request;
// presumably the caller must already know the stored size -- confirm against the tracker.
tracker.RecvAll(value, bytesize);
tracker.Close();
}

Expand Down
4 changes: 2 additions & 2 deletions src/allreduce_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class AllreduceBase : public IEngine {
* \param msg message to be printed in the tracker
*/
virtual void TrackerPrint(const std::string &msg);
virtual void TrackerSetConfig(const std::string &key, const std::string &value);
virtual void TrackerGetConfig(const std::string& key, std::string* value);
virtual void TrackerSetConfig(const std::string &key, const int bytesize, const void* value);
virtual void TrackerGetConfig(const std::string &key, const int bytesize, void* value);

/*! \brief get rank */
virtual int GetRank(void) const {
Expand Down
1 change: 1 addition & 0 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
if (prepare_fun != NULL) prepare_fun(prepare_arg);
return;
}

bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
// now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 &&
Expand Down

0 comments on commit 2a28e5e

Please sign in to comment.