diff --git a/.gitmodules b/.gitmodules index 133ceb3889da..a78662c0e7ea 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "include/boost/compute"] path = compute url = https://github.com/boostorg/compute +[submodule "external_libs/fmt"] + path = external_libs/fmt + url = https://github.com/fmtlib/fmt.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 78c6c0d18efb..1f26b7428a42 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -315,6 +315,7 @@ if(WIN32 AND (MINGW OR CYGWIN)) TARGET_LINK_LIBRARIES(_lightgbm IPHLPAPI) endif() + if(BUILD_FOR_R) if(MSVC) TARGET_LINK_LIBRARIES(_lightgbm ${LIBR_MSVC_CORE_LIBRARY}) @@ -323,6 +324,11 @@ if(BUILD_FOR_R) endif(MSVC) endif(BUILD_FOR_R) +# fmtlib/fmt +add_subdirectory(external_libs/fmt) +TARGET_LINK_LIBRARIES(lightgbm PUBLIC fmt::fmt) +TARGET_LINK_LIBRARIES(_lightgbm PUBLIC fmt::fmt) + install(TARGETS lightgbm _lightgbm RUNTIME DESTINATION ${CMAKE_INSTALL_PREFIX}/bin LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/lib diff --git a/external_libs/fmt b/external_libs/fmt new file mode 160000 index 000000000000..f674434a6700 --- /dev/null +++ b/external_libs/fmt @@ -0,0 +1 @@ +Subproject commit f674434a6700c14b938b4c8507abc104154a92d6 diff --git a/include/LightGBM/utils/common.h b/include/LightGBM/utils/common.h index 07b8484b5577..c963afa2172d 100644 --- a/include/LightGBM/utils/common.h +++ b/include/LightGBM/utils/common.h @@ -5,6 +5,7 @@ #ifndef LIGHTGBM_UTILS_COMMON_FUN_H_ #define LIGHTGBM_UTILS_COMMON_FUN_H_ +#include "../../../fmt/include/fmt/core.h" #include #include @@ -408,13 +409,11 @@ inline static void Int32ToStr(int32_t value, char* buffer) { Uint32ToStr(u, buffer); } + + inline static void DoubleToStr(double value, char* buffer, size_t buffer_len) { - #ifdef _MSC_VER - int num_chars = sprintf_s(buffer, buffer_len, "%.17g", value); - #else - int num_chars = snprintf(buffer, buffer_len, "%.17g", value); - #endif - CHECK_GE(num_chars, 0); + const std::string s = fmt::format("{:.17g}", value); + s.copy(buffer, buffer_len); } inline static const char* SkipSpaceAndTab(const char* p) { @@ -1137,6 +1136,324 @@ class FunctionTimer { extern Common::Timer global_timer; + + +namespace Common2 { + + + template + inline static std::string Join(const std::vector& strs, const char* delimiter) { + if (strs.empty()) { + return std::string(""); + } + std::stringstream str_buf; + str_buf.imbue(std::locale("C")); + str_buf << std::setprecision(std::numeric_limits::digits10 + 2); + str_buf << strs[0]; + for (size_t i = 1; i < strs.size(); ++i) { + str_buf << delimiter; + str_buf << strs[i]; + } + return str_buf.str(); + } + + template<> + inline std::string Join(const std::vector& strs, const char* delimiter) { + if (strs.empty()) { + return std::string(""); + } + std::stringstream str_buf; + str_buf.imbue(std::locale("C")); + str_buf << std::setprecision(std::numeric_limits::digits10 + 2); + str_buf << static_cast(strs[0]); + for (size_t i = 1; i < strs.size(); ++i) { + str_buf << delimiter; + str_buf << static_cast(strs[i]); + } + return str_buf.str(); + } + + template + inline static std::string Join(const std::vector& strs, size_t start, size_t end, const char* delimiter) { + if (end - start <= 0) { + return std::string(""); + } + start = std::min(start, static_cast(strs.size()) - 1); + end = std::min(end, static_cast(strs.size())); + std::stringstream str_buf; + str_buf.imbue(std::locale("C")); + str_buf << std::setprecision(std::numeric_limits::digits10 + 2); + str_buf << strs[start]; + for (size_t i = start + 1; i < end; ++i) { + str_buf << delimiter; + str_buf << strs[i]; + } + return str_buf.str(); + } + +inline static const char* Atof(const char* p, double* out) { + int frac; + double sign, value, scale; + *out = NAN; + // Skip leading white space, if any. + while (*p == ' ') { + ++p; + } + // Get sign, if any. + sign = 1.0; + if (*p == '-') { + sign = -1.0; + ++p; + } else if (*p == '+') { + ++p; + } + + // is a number + if ((*p >= '0' && *p <= '9') || *p == '.' || *p == 'e' || *p == 'E') { + // Get digits before decimal point or exponent, if any. + for (value = 0.0; *p >= '0' && *p <= '9'; ++p) { + value = value * 10.0 + (*p - '0'); + } + + // Get digits after decimal point, if any. + if (*p == '.') { + double right = 0.0; + int nn = 0; + ++p; + while (*p >= '0' && *p <= '9') { + right = (*p - '0') + right * 10.0; + ++nn; + ++p; + } + value += right / Common::Pow(10.0, nn); + } + + // Handle exponent, if any. + frac = 0; + scale = 1.0; + if ((*p == 'e') || (*p == 'E')) { + uint32_t expon; + // Get sign of exponent, if any. + ++p; + if (*p == '-') { + frac = 1; + ++p; + } else if (*p == '+') { + ++p; + } + // Get digits of exponent, if any. + for (expon = 0; *p >= '0' && *p <= '9'; ++p) { + expon = expon * 10 + (*p - '0'); + } + if (expon > 308) expon = 308; + // Calculate scaling factor. + while (expon >= 50) { scale *= 1E50; expon -= 50; } + while (expon >= 8) { scale *= 1E8; expon -= 8; } + while (expon > 0) { scale *= 10.0; expon -= 1; } + } + // Return signed and scaled floating point result. + *out = sign * (frac ? (value / scale) : (value * scale)); + } else { + size_t cnt = 0; + while (*(p + cnt) != '\0' && *(p + cnt) != ' ' + && *(p + cnt) != '\t' && *(p + cnt) != ',' + && *(p + cnt) != '\n' && *(p + cnt) != '\r' + && *(p + cnt) != ':') { + ++cnt; + } + if (cnt > 0) { + std::string tmp_str(p, cnt); + std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), Common::tolower); + if (tmp_str == std::string("na") || tmp_str == std::string("nan") || + tmp_str == std::string("null")) { + *out = NAN; + } else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) { + *out = sign * 1e308; + } else { + Log::Fatal("Unknown token %s in data file", tmp_str.c_str()); + } + p += cnt; + } + } + + while (*p == ' ') { + ++p; + } + + return p; +} + +template +struct __StringToTHelperFast { + const char* operator()(const char*p, T* out) const { + return LightGBM::Common::Atoi(p, out); + } +}; + +template +struct __StringToTHelperFast { + const char* operator()(const char*p, T* out) const { + double tmp = 0.0f; + auto ret = Atof(p, &tmp); + *out = static_cast(tmp); + return ret; + } +}; + +template +struct __StringToTHelper { + T operator()(const std::string& str) const { + T ret = 0; + Atoi(str.c_str(), &ret); + return ret; + } +}; + +template +struct __StringToTHelper { + T operator()(const std::string& str) const { + std::stringstream ss; + ss.imbue(std::locale("C")); + ss << str; + T tmp; + ss >> tmp; + return static_cast(tmp); + //return static_cast(std::stod(str)); + } +}; + +template +struct __TToStringHelperFast { + void operator()(T value, char* buffer, size_t) const { + LightGBM::Common::Int32ToStr(value, buffer); + } +}; + +template +struct __TToStringHelperFast { + void operator()(T value, char* buffer, size_t buf_len) const { + #ifdef _MSC_VER + int num_chars = sprintf_s(buffer, buf_len, "%g", value); + #else + int num_chars = snprintf(buffer, buf_len, "%g", value); + #endif + CHECK_GE(num_chars, 0); + } +}; + +template +struct __TToStringHelperFast { + void operator()(T value, char* buffer, size_t) const { + LightGBM::Common::Uint32ToStr(value, buffer); + } +}; + +inline static void DoubleToStr(double value, char* buffer, size_t buffer_len) { + #ifdef _MSC_VER + int num_chars = sprintf_s(buffer, buffer_len, "%.17g", value); + #else + int num_chars = snprintf(buffer, buffer_len, "%.17g", value); + #endif + CHECK_GE(num_chars, 0); +} + +template +inline static std::vector StringToArrayFast(const std::string& str, int n) { + if (n == 0) { + return std::vector(); + } + auto p_str = str.c_str(); + __StringToTHelperFast::value> helper; + std::vector ret(n); + for (int i = 0; i < n; ++i) { + p_str = helper(p_str, &ret[i]); + } + return ret; +} + +template +inline static std::vector StringToArray(const std::string& str, char delimiter) { + std::vector strs = LightGBM::Common::Split(str.c_str(), delimiter); + std::vector ret; + ret.reserve(strs.size()); + __StringToTHelper::value> helper; + for (const auto& s : strs) { + ret.push_back(helper(s)); + } + return ret; +} + +template +inline static std::vector StringToArray(const std::string& str, int n) { + if (n == 0) { + return std::vector(); + } + std::vector strs = LightGBM::Common::Split(str.c_str(), ' '); + CHECK_EQ(strs.size(), static_cast(n)); + std::vector ret; + ret.reserve(strs.size()); + __StringToTHelper::value> helper; + for (const auto& s : strs) { + ret.push_back(helper(s)); + } + return ret; +} + +template +inline static std::string ArrayToStringFast(const std::vector& arr, size_t n) { + if (arr.empty() || n == 0) { + return std::string(""); + } + __TToStringHelperFast::value, std::is_unsigned::value> helper; + const size_t buf_len = 16; + std::vector buffer(buf_len); + std::stringstream str_buf; + str_buf.imbue(std::locale("C")); + helper(arr[0], buffer.data(), buf_len); + str_buf << buffer.data(); + for (size_t i = 1; i < std::min(n, arr.size()); ++i) { + helper(arr[i], buffer.data(), buf_len); + str_buf << ' ' << buffer.data(); + } + return str_buf.str(); +} + +inline static std::string ArrayToString(const std::vector& arr, size_t n) { + if (arr.empty() || n == 0) { + return std::string(""); + } + const size_t buf_len = 32; + std::vector buffer(buf_len); + std::stringstream str_buf; + str_buf.imbue(std::locale("C")); + DoubleToStr(arr[0], buffer.data(), buf_len); + str_buf << buffer.data(); + for (size_t i = 1; i < std::min(n, arr.size()); ++i) { + DoubleToStr(arr[i], buffer.data(), buf_len); + str_buf << ' ' << buffer.data(); + } + return str_buf.str(); +} + + + + + +#include + +template +void cmp(std::vector a, std::vector b) { + if (a.size() != b.size()) { + Log::Fatal("Different array sizes! %d (expected=%d)", a.size(), b.size()); + } + + if (!std::equal(a.begin(), a.end(), b.begin())) + Log::Fatal("Different array contents!"); +} + +} // Namespace Common2 + + } // namespace LightGBM #endif // LightGBM_UTILS_COMMON_FUN_H_ diff --git a/include/LightGBM/utils/text_reader.h b/include/LightGBM/utils/text_reader.h index 638bb2683627..45b71e9adc36 100644 --- a/include/LightGBM/utils/text_reader.h +++ b/include/LightGBM/utils/text_reader.h @@ -38,6 +38,7 @@ class TextReader { Log::Fatal("Could not open %s", filename); } std::stringstream str_buf; + // Imbue C locale??? - Parameter? char read_c; size_t nread = reader->Read(&read_c, 1); while (nread == 1) { diff --git a/src/boosting/gbdt_model_text.cpp b/src/boosting/gbdt_model_text.cpp index 4eeb731f587f..a5f6247dbe26 100644 --- a/src/boosting/gbdt_model_text.cpp +++ b/src/boosting/gbdt_model_text.cpp @@ -18,6 +18,16 @@ namespace LightGBM { const char* kModelVersion = "v3"; + +//////////////////////////////////////////////////////////////////////////////////// +#include + +//////////////////////////////////////////////////////////////////////////////////// + + + + + std::string GBDT::DumpModel(int start_iteration, int num_iteration, int feature_importance_type) const { std::stringstream str_buf; @@ -34,11 +44,11 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration, int feature_ str_buf << "\"average_output\":" << (average_output_ ? "true" : "false") << ",\n"; - str_buf << "\"feature_names\":[\"" << Common::Join(feature_names_, "\",\"") + str_buf << "\"feature_names\":[\"" << Common2::Join(feature_names_, "\",\"") << "\"]," << '\n'; str_buf << "\"monotone_constraints\":[" - << Common::Join(monotone_constraints_, ",") << "]," << '\n'; + << Common2::Join(monotone_constraints_, ",") << "]," << '\n'; str_buf << "\"feature_infos\":" << "{"; bool first_obj = true; @@ -61,7 +71,7 @@ std::string GBDT::DumpModel(int start_iteration, int num_iteration, int feature_ auto min_idx = ArrayArgs::ArgMin(vals); json_str_buf << "{\"min_value\":" << vals[min_idx] << ","; json_str_buf << "\"max_value\":" << vals[max_idx] << ","; - json_str_buf << "\"values\":[" << Common::Join(vals, ",") << "]}"; + json_str_buf << "\"values\":[" << Common2::Join(vals, ",") << "]}"; } else { // unused feature continue; } @@ -325,14 +335,14 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int ss << "average_output" << '\n'; } - ss << "feature_names=" << Common::Join(feature_names_, " ") << '\n'; + ss << "feature_names=" << Common2::Join(feature_names_, " ") << '\n'; if (monotone_constraints_.size() != 0) { - ss << "monotone_constraints=" << Common::Join(monotone_constraints_, " ") + ss << "monotone_constraints=" << Common2::Join(monotone_constraints_, " ") << '\n'; } - ss << "feature_infos=" << Common::Join(feature_infos_, " ") << '\n'; + ss << "feature_infos=" << Common2::Join(feature_infos_, " ") << '\n'; int num_used_model = static_cast(models_.size()); int total_iteration = num_used_model / num_tree_per_iteration_; @@ -356,7 +366,7 @@ std::string GBDT::SaveModelToString(int start_iteration, int num_iteration, int tree_sizes[idx] = tree_strs[idx].size(); } - ss << "tree_sizes=" << Common::Join(tree_sizes, " ") << '\n'; + ss << "tree_sizes=" << Common2::Join(tree_sizes, " ") << '\n'; ss << '\n'; for (int i = 0; i < num_used_model - start_model; ++i) { diff --git a/src/io/tree.cpp b/src/io/tree.cpp index 8e5104f168eb..df0d41ab4682 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -225,34 +225,34 @@ std::string Tree::ToString() const { str_buf << "num_leaves=" << num_leaves_ << '\n'; str_buf << "num_cat=" << num_cat_ << '\n'; str_buf << "split_feature=" - << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n'; + << Common2::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n'; str_buf << "split_gain=" - << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n'; + << Common2::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n'; str_buf << "threshold=" - << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n'; + << Common2::ArrayToString(threshold_, num_leaves_ - 1) << '\n'; str_buf << "decision_type=" - << Common::ArrayToStringFast(Common::ArrayCast(decision_type_), num_leaves_ - 1) << '\n'; + << Common2::ArrayToStringFast(Common::ArrayCast(decision_type_), num_leaves_ - 1) << '\n'; str_buf << "left_child=" - << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n'; + << Common2::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n'; str_buf << "right_child=" - << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n'; + << Common2::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n'; str_buf << "leaf_value=" - << Common::ArrayToString(leaf_value_, num_leaves_) << '\n'; + << Common2::ArrayToString(leaf_value_, num_leaves_) << '\n'; str_buf << "leaf_weight=" - << Common::ArrayToString(leaf_weight_, num_leaves_) << '\n'; + << Common2::ArrayToString(leaf_weight_, num_leaves_) << '\n'; str_buf << "leaf_count=" - << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n'; + << Common2::ArrayToStringFast(leaf_count_, num_leaves_) << '\n'; str_buf << "internal_value=" - << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n'; + << Common2::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n'; str_buf << "internal_weight=" - << Common::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n'; + << Common2::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n'; str_buf << "internal_count=" - << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n'; + << Common2::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n'; if (num_cat_ > 0) { str_buf << "cat_boundaries=" - << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n'; + << Common2::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n'; str_buf << "cat_threshold=" - << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n'; + << Common2::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n'; } str_buf << "shrinkage=" << shrinkage_ << '\n'; str_buf << '\n'; @@ -493,6 +493,8 @@ std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const { return str_buf.str(); } +using Common2::cmp; + Tree::Tree(const char* str, size_t* used_len) { auto p = str; std::unordered_map key_vals; @@ -513,7 +515,7 @@ Tree::Tree(const char* str, size_t* used_len) { } *used_len = p - str; - if (key_vals.count("num_leaves") <= 0) { + if (key_vals.count("num_leaves") <= 0) { Log::Info("num_leaves"); Log::Fatal("Tree model should contain num_leaves field"); } @@ -525,95 +527,109 @@ Tree::Tree(const char* str, size_t* used_len) { Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_); - if (key_vals.count("leaf_value")) { - leaf_value_ = Common::StringToArray(key_vals["leaf_value"], num_leaves_); + if (key_vals.count("leaf_value")) { Log::Info("leaf_value"); + leaf_value_ = Common2::StringToArray(key_vals["leaf_value"], num_leaves_); + cmp(leaf_value_, Common::StringToArray(key_vals["leaf_value"], num_leaves_)); } else { Log::Fatal("Tree model string format error, should contain leaf_value field"); } - if (key_vals.count("shrinkage")) { - Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_); + if (key_vals.count("shrinkage")) { Log::Info("shrinkage"); + Common2::Atof(key_vals["shrinkage"].c_str(), &shrinkage_); } else { shrinkage_ = 1.0f; } if (num_leaves_ <= 1) { return; } - if (key_vals.count("left_child")) { - left_child_ = Common::StringToArrayFast(key_vals["left_child"], num_leaves_ - 1); + if (key_vals.count("left_child")) { Log::Info("left_child"); + left_child_ = Common2::StringToArrayFast(key_vals["left_child"], num_leaves_ - 1); + cmp(left_child_, Common::StringToArrayFast(key_vals["left_child"], num_leaves_ - 1)); } else { Log::Fatal("Tree model string format error, should contain left_child field"); } - if (key_vals.count("right_child")) { - right_child_ = Common::StringToArrayFast(key_vals["right_child"], num_leaves_ - 1); + if (key_vals.count("right_child")) { Log::Info("right_child"); + right_child_ = Common2::StringToArrayFast(key_vals["right_child"], num_leaves_ - 1); + cmp(right_child_, Common::StringToArrayFast(key_vals["right_child"], num_leaves_ - 1)); } else { Log::Fatal("Tree model string format error, should contain right_child field"); } - if (key_vals.count("split_feature")) { - split_feature_ = Common::StringToArrayFast(key_vals["split_feature"], num_leaves_ - 1); + if (key_vals.count("split_feature")) { Log::Info("split_feature"); + split_feature_ = Common2::StringToArrayFast(key_vals["split_feature"], num_leaves_ - 1); + cmp(split_feature_, Common::StringToArrayFast(key_vals["split_feature"], num_leaves_ - 1)); } else { Log::Fatal("Tree model string format error, should contain split_feature field"); } - if (key_vals.count("threshold")) { - threshold_ = Common::StringToArray(key_vals["threshold"], num_leaves_ - 1); + if (key_vals.count("threshold")) { Log::Info("threshold"); + threshold_ = Common2::StringToArray(key_vals["threshold"], num_leaves_ - 1); + cmp(threshold_, Common::StringToArray(key_vals["threshold"], num_leaves_ - 1)); } else { Log::Fatal("Tree model string format error, should contain threshold field"); } - if (key_vals.count("split_gain")) { - split_gain_ = Common::StringToArrayFast(key_vals["split_gain"], num_leaves_ - 1); + if (key_vals.count("split_gain")) { Log::Info("split_gain"); + split_gain_ = Common2::StringToArrayFast(key_vals["split_gain"], num_leaves_ - 1); + cmp(split_gain_, Common::StringToArrayFast(key_vals["split_gain"], num_leaves_ - 1)); } else { split_gain_.resize(num_leaves_ - 1); } - if (key_vals.count("internal_count")) { - internal_count_ = Common::StringToArrayFast(key_vals["internal_count"], num_leaves_ - 1); + if (key_vals.count("internal_count")) { Log::Info("internal_count"); + internal_count_ = Common2::StringToArrayFast(key_vals["internal_count"], num_leaves_ - 1); + cmp(internal_count_, Common::StringToArrayFast(key_vals["internal_count"], num_leaves_ - 1)); } else { internal_count_.resize(num_leaves_ - 1); } - if (key_vals.count("internal_value")) { - internal_value_ = Common::StringToArrayFast(key_vals["internal_value"], num_leaves_ - 1); + if (key_vals.count("internal_value")) { Log::Info("internal_value"); + internal_value_ = Common2::StringToArrayFast(key_vals["internal_value"], num_leaves_ - 1); + cmp(internal_value_, Common::StringToArrayFast(key_vals["internal_value"], num_leaves_ - 1)); } else { internal_value_.resize(num_leaves_ - 1); } - if (key_vals.count("internal_weight")) { - internal_weight_ = Common::StringToArrayFast(key_vals["internal_weight"], num_leaves_ - 1); + if (key_vals.count("internal_weight")) { Log::Info("internal_weight"); + internal_weight_ = Common2::StringToArrayFast(key_vals["internal_weight"], num_leaves_ - 1); + cmp(internal_weight_, Common::StringToArrayFast(key_vals["internal_weight"], num_leaves_ - 1)); } else { internal_weight_.resize(num_leaves_ - 1); } - if (key_vals.count("leaf_weight")) { - leaf_weight_ = Common::StringToArray(key_vals["leaf_weight"], num_leaves_); + if (key_vals.count("leaf_weight")) { Log::Info("leaf_weight"); + leaf_weight_ = Common2::StringToArray(key_vals["leaf_weight"], num_leaves_); + cmp(leaf_weight_, Common::StringToArray(key_vals["leaf_weight"], num_leaves_)); } else { leaf_weight_.resize(num_leaves_); } - if (key_vals.count("leaf_count")) { - leaf_count_ = Common::StringToArrayFast(key_vals["leaf_count"], num_leaves_); + if (key_vals.count("leaf_count")) { Log::Info("leaf_count"); + leaf_count_ = Common2::StringToArrayFast(key_vals["leaf_count"], num_leaves_); + cmp(leaf_count_, Common::StringToArrayFast(key_vals["leaf_count"], num_leaves_)); } else { leaf_count_.resize(num_leaves_); } - if (key_vals.count("decision_type")) { - decision_type_ = Common::StringToArrayFast(key_vals["decision_type"], num_leaves_ - 1); + if (key_vals.count("decision_type")) { Log::Info("decision_type"); + decision_type_ = Common2::StringToArrayFast(key_vals["decision_type"], num_leaves_ - 1); + cmp(decision_type_, Common::StringToArrayFast(key_vals["decision_type"], num_leaves_ - 1)); } else { decision_type_ = std::vector(num_leaves_ - 1, 0); } if (num_cat_ > 0) { - if (key_vals.count("cat_boundaries")) { - cat_boundaries_ = Common::StringToArrayFast(key_vals["cat_boundaries"], num_cat_ + 1); + if (key_vals.count("cat_boundaries")) { Log::Info("cat_boundaries"); + cat_boundaries_ = Common2::StringToArrayFast(key_vals["cat_boundaries"], num_cat_ + 1); + cmp(cat_boundaries_, Common::StringToArrayFast(key_vals["cat_boundaries"], num_cat_ + 1)); } else { Log::Fatal("Tree model should contain cat_boundaries field."); } - if (key_vals.count("cat_threshold")) { - cat_threshold_ = Common::StringToArrayFast(key_vals["cat_threshold"], cat_boundaries_.back()); + if (key_vals.count("cat_threshold")) { Log::Info("cat_threshold"); + cat_threshold_ = Common2::StringToArrayFast(key_vals["cat_threshold"], cat_boundaries_.back()); + cmp(cat_threshold_, Common::StringToArrayFast(key_vals["cat_threshold"], cat_boundaries_.back())); } else { Log::Fatal("Tree model should contain cat_threshold field"); }