Skip to content

Commit

Permalink
remote StatWrapper
Browse files Browse the repository at this point in the history
test=develop
  • Loading branch information
hutuxian committed May 14, 2020
1 parent 2d77026 commit 2a4865d
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 36 deletions.
17 changes: 16 additions & 1 deletion paddle/fluid/platform/monitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,20 @@
#include <utility>

namespace paddle {
namespace platform {} // namespace platform
namespace platform {
#define DEFINE_FLOAT_STATUS(item) StatValue<float> _##item(#item);

#define DEFINE_INT_STATUS(item) StatValue<int64_t> _##item(#item);

DEFINE_INT_STATUS(STAT_total_feasign_num_in_mem)
DEFINE_INT_STATUS(STAT_gpu0_mem_size)
DEFINE_INT_STATUS(STAT_gpu1_mem_size)
DEFINE_INT_STATUS(STAT_gpu2_mem_size)
DEFINE_INT_STATUS(STAT_gpu3_mem_size)
DEFINE_INT_STATUS(STAT_gpu4_mem_size)
DEFINE_INT_STATUS(STAT_gpu5_mem_size)
DEFINE_INT_STATUS(STAT_gpu6_mem_size)
DEFINE_INT_STATUS(STAT_gpu7_mem_size)

} // namespace platform
} // namespace paddle
60 changes: 28 additions & 32 deletions paddle/fluid/platform/monitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class StatValue {
std::mutex mu_;
// We use lock rather than atomic for generic values
public:
explicit StatValue(const std::string& n) {
StatRegistry<T>::Instance().add(n, this);
}
T increase(T inc) {
std::lock_guard<std::mutex> lock(mu_);
return v_ += inc;
Expand Down Expand Up @@ -65,22 +68,31 @@ class StatRegistry {
public:
~StatRegistry<T>() {}

static StatRegistry<T>& get() {
static StatRegistry<T>& Instance() {
static StatRegistry<T> r;
return r;
}

StatValue<T>* add(const std::string& name) {
StatValue<T>* get(const std::string& name) {
std::lock_guard<std::mutex> lg(mutex_);
auto it = stats_.find(name);
if (it != stats_.end()) {
return it->second;
} else {
printf("not register\n");
return nullptr;
}
}
int add(const std::string& name, StatValue<T>* stat) {
std::lock_guard<std::mutex> lg(mutex_);
auto it = stats_.find(name);
if (it != stats_.end()) {
return it->second.get();
// LOG(WARNING) << name << " has been registerd before, please check it.";
return -1;
}
auto v = std::unique_ptr<StatValue<T>>(new StatValue<T>);
VLOG(0) << "Register Stats: " << name;
auto value = v.get();
stats_.insert(std::make_pair(name, std::move(v)));
return value;
stats_.insert(std::make_pair(name, stat));
// How to print log before Init VLOG?
// VLOG(4) << "STAT Register: " << name;
return 0;
}

void publish(std::vector<ExportedStatValue<T>>& exported, // NOLINT
Expand All @@ -103,43 +115,27 @@ class StatRegistry {

private:
std::mutex mutex_;
std::unordered_map<std::string, std::unique_ptr<StatValue<T>>> stats_;
};

template <typename T>
class Stat {
public:
explicit Stat(const std::string& n)
: name(n), value_(StatRegistry<T>::get().add(n)) {}

T increase(T inc) { return value_->increase(inc); }
T decrease(T inc) { return value_->decrease(inc); }
T reset(T value) { return value_->reset(value); }
T get() const { return value_->get(); }

private:
std::string name;
StatValue<T>* value_;
std::unordered_map<std::string, StatValue<T>*> stats_;
};

// Because we only support these two types in pybind
#define REGISTER_FLOAT_STATUS(item) static Stat<float> _##item(#item);
#define REGISTER_FLOAT_STATUS(item) extern StatValue<float> _##item;

#define REGISTER_INT_STATUS(item) static Stat<int64_t> _##item(#item);
#define REGISTER_INT_STATUS(item) extern StatValue<int64_t> _##item;

#define STAT_ADD(item, t) paddle::platform::_##item.increase(t)
#define STAT_SUB(item, t) paddle::platform::_##item.decrease(t)

// Support add stat value by string
#define STAT_INT_ADD(item, t) \
paddle::platform::StatRegistry<int64_t>::get().add(item)->increase(t)
paddle::platform::StatRegistry<int64_t>::Instance().get(item)->increase(t)
#define STAT_INT_SUB(item, t) \
paddle::platform::StatRegistry<int64_t>::get().add(item)->decrease(t)
paddle::platform::StatRegistry<int64_t>::Instance().get(item)->decrease(t)

#define STAT_FLOAT_ADD(item, t) \
paddle::platform::StatRegistry<float>::get().add(item)->increase(t)
paddle::platform::StatRegistry<float>::Instance().get(item)->increase(t)
#define STAT_FLOAT_SUB(item, t) \
paddle::platform::StatRegistry<float>::get().add(item)->decrease(t)
paddle::platform::StatRegistry<float>::Instance().get(item)->decrease(t)

#define STAT_RESET(item, t) paddle::platform::_##item.reset(t)

Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1501,7 +1501,7 @@ All parameter, weight, gradient are variables in Paddle.

m.def("get_float_stats", []() {
std::vector<paddle::platform::ExportedStatValue<float>> float_stats;
paddle::platform::StatRegistry<float>::get().publish(float_stats);
paddle::platform::StatRegistry<float>::Instance().publish(float_stats);
std::unordered_map<std::string, float> stats_map;
for (const auto &stat : float_stats) {
stats_map[stat.key] = stat.value;
Expand All @@ -1510,8 +1510,8 @@ All parameter, weight, gradient are variables in Paddle.
});
m.def("get_int_stats", []() {
std::vector<paddle::platform::ExportedStatValue<int64_t>> int_stats;
paddle::platform::StatRegistry<int64_t>::get().publish(int_stats);
std::unordered_map<std::string, int> stats_map;
paddle::platform::StatRegistry<int64_t>::Instance().publish(int_stats);
std::unordered_map<std::string, int64_t> stats_map;
for (const auto &stat : int_stats) {
stats_map[stat.key] = stat.value;
}
Expand Down

0 comments on commit 2a4865d

Please sign in to comment.