From 2e823eb2f207d7c850429ef8a947a7c2a587673b Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 29 Aug 2020 00:48:11 -0700 Subject: [PATCH 01/38] Make Model and Tree classes into template classes --- CMakeLists.txt | 2 + include/treelite/annotator.h | 3 + include/treelite/base.h | 3 +- include/treelite/tree.h | 81 +++++-- include/treelite/tree_impl.h | 280 ++++++++++++++++++----- src/CMakeLists.txt | 5 +- src/annotator.cc | 42 ++-- src/c_api/c_api.cc | 22 +- src/compiler/ast/ast.h | 32 +-- src/compiler/ast/build.cc | 53 +++-- src/compiler/ast/builder.cc | 21 ++ src/compiler/ast/builder.h | 17 +- src/compiler/ast/dump.cc | 9 +- src/compiler/ast/fold_code.cc | 23 +- src/compiler/ast/is_categorical_array.cc | 9 +- src/compiler/ast/load_data_counts.cc | 12 +- src/compiler/ast/quantize.cc | 40 ++-- src/compiler/ast/split.cc | 11 +- src/compiler/ast_native.cc | 83 ++++--- src/compiler/common/code_folding_util.h | 21 +- src/compiler/failsafe.cc | 34 +-- src/compiler/native/pred_transform.h | 22 +- src/compiler/pred_transform.cc | 7 +- src/compiler/pred_transform.h | 3 +- src/frontend/lightgbm.cc | 51 ++--- src/frontend/xgboost.cc | 43 ++-- src/reference_serializer.cc | 8 +- tests/cpp/CMakeLists.txt | 4 +- 28 files changed, 640 insertions(+), 301 deletions(-) create mode 100644 src/compiler/ast/builder.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index e44a113e..7c766b2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,8 @@ set (CMAKE_FIND_NO_INSTALL_PREFIX TRUE FORCE) cmake_minimum_required (VERSION 3.14) project(treelite LANGUAGES CXX C VERSION 0.93) +set(CMAKE_CXX_STANDARD 14) + # check MSVC version if(MSVC) if(MSVC_VERSION LESS 1910) diff --git a/include/treelite/annotator.h b/include/treelite/annotator.h index 235c4d21..accdc46a 100644 --- a/include/treelite/annotator.h +++ b/include/treelite/annotator.h @@ -16,6 +16,9 @@ namespace treelite { /*! \brief branch annotator class */ class BranchAnnotator { public: + template + void AnnotateImpl(const treelite::ModelImpl& model, + const treelite::DMatrix* dmat, int nthread, int verbose); /*! * \brief annotate branches in a given model using frequency patterns in the * training data. The annotation can be accessed through Get() method. diff --git a/include/treelite/base.h b/include/treelite/base.h index f9f352a0..b9a134f3 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -56,7 +56,8 @@ inline std::string OpName(Operator op) { * \param rhs float on the right hand side * \return whether [lhs] [op] [rhs] is true or not */ -inline bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs) { +template +inline bool CompareWithOp(ThresholdType lhs, Operator op, ThresholdType rhs) { switch (op) { case Operator::kEQ: return lhs == rhs; case Operator::kLT: return lhs < rhs; diff --git a/include/treelite/tree.h b/include/treelite/tree.h index 6d9a085f..65be0f2c 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -77,6 +77,7 @@ class ContiguousArray { }; /*! \brief in-memory representation of a decision tree */ +template class Tree { public: /*! \brief tree node */ @@ -86,8 +87,8 @@ class Tree { inline void Init(); /*! \brief store either leaf value or decision threshold */ union Info { - tl_float leaf_value; // for leaf nodes - tl_float threshold; // for non-leaf nodes + LeafOutputType leaf_value; // for leaf nodes + ThresholdType threshold; // for non-leaf nodes }; /*! 
\brief pointer to left and right children */ int32_t cleft_, cright_; @@ -138,14 +139,15 @@ class Tree { }; static_assert(std::is_pod::value, "Node must be a POD type"); - static_assert(sizeof(Node) == 48, "Node must be 48 bytes"); + // TODO(hcho3): Add back size check + //static_assert(sizeof(Node) == 48, "Node must be 48 bytes"); Tree() = default; ~Tree() = default; Tree(const Tree&) = delete; Tree& operator=(const Tree&) = delete; - Tree(Tree&&) = default; - Tree& operator=(Tree&&) = default; + Tree(Tree&&) noexcept = default; + Tree& operator=(Tree&&) noexcept = default; inline Tree Clone() const; inline std::vector GetPyBuffer(); @@ -154,7 +156,7 @@ class Tree { private: // vector of nodes ContiguousArray nodes_; - ContiguousArray leaf_vector_; + ContiguousArray leaf_vector_; ContiguousArray leaf_vector_offset_; ContiguousArray left_categories_; ContiguousArray left_categories_offset_; @@ -214,12 +216,12 @@ class Tree { * \brief get leaf value of the leaf node * \param nid ID of node being queried */ - inline tl_float LeafValue(int nid) const; + inline LeafOutputType LeafValue(int nid) const; /*! * \brief get leaf vector of the leaf node; useful for multi-class random forest classifier * \param nid ID of node being queried */ - inline std::vector LeafVector(int nid) const; + inline std::vector LeafVector(int nid) const; /*! * \brief tests whether the leaf node has a non-empty leaf vector * \param nid ID of node being queried @@ -229,7 +231,7 @@ class Tree { * \brief get threshold of the node * \param nid ID of node being queried */ - inline tl_float Threshold(int nid) const; + inline ThresholdType Threshold(int nid) const; /*! * \brief get comparison operator * \param nid ID of node being queried @@ -295,7 +297,7 @@ class Tree { * \param cmp comparison operator to compare between feature value and * threshold */ - inline void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, + inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp); /*! * \brief create a categorical split @@ -314,13 +316,13 @@ class Tree { * \param nid ID of node being updated * \param value leaf value */ - inline void SetLeaf(int nid, tl_float value); + inline void SetLeaf(int nid, LeafOutputType value); /*! * \brief set the leaf vector of the node; useful for multi-class random forest classifier * \param nid ID of node being updated * \param leaf_vector leaf vector */ - inline void SetLeafVector(int nid, const std::vector& leaf_vector); + inline void SetLeafVector(int nid, const std::vector& leaf_vector); /*! * \brief set the hessian sum of the node * \param nid ID of node being updated @@ -406,9 +408,10 @@ inline void InitParamAndCheck(ModelParam* param, const std::vector>& cfg); /*! \brief thin wrapper for tree ensemble model */ -struct Model { +template +struct ModelImpl { /*! \brief member trees */ - std::vector trees; + std::vector> trees; /*! * \brief number of features used for the model. * It is assumed that all feature indices are between 0 and [num_feature]-1. @@ -424,18 +427,54 @@ struct Model { ModelParam param; /*! 
\brief disable copy; use default move */ - Model() = default; - ~Model() = default; - Model(const Model&) = delete; - Model& operator=(const Model&) = delete; - Model(Model&&) = default; - Model& operator=(Model&&) = default; + ModelImpl() = default; + ~ModelImpl() = default; + ModelImpl(const ModelImpl&) = delete; + ModelImpl& operator=(const ModelImpl&) = delete; + ModelImpl(ModelImpl&&) noexcept = default; + ModelImpl& operator=(ModelImpl&&) noexcept = default; void ReferenceSerialize(dmlc::Stream* fo) const; inline std::vector GetPyBuffer(); inline void InitFromPyBuffer(std::vector frames); - inline Model Clone() const; + inline ModelImpl Clone() const; +}; + +enum class ModelType : uint16_t { + // Threshold type, + kInvalid = 0, + kFloat32ThresholdUInt32LeafOutput = 1, + kFloat32ThresholdFloat32LeafOutput = 2, + kFloat64ThresholdUint32LeafOutput = 3, + kFloat64ThresholdFloat64LeafOutput = 4 +}; + +struct Model { + private: + std::shared_ptr handle_; + ModelType type_; + public: + template + inline static Model Create(); + template + inline ModelImpl& GetImpl(); + template + inline const ModelImpl& GetImpl() const; + inline ModelType GetModelType() const; + template + inline auto Dispatch(Func func) const; + template + inline auto Dispatch(Func func); + inline ModelParam GetParam() const; + inline int GetNumFeature() const; + inline int GetNumOutputGroup() const; + inline bool GetRandomForestFlag() const; + inline size_t GetNumTree() const; + inline void SetTreeLimit(size_t limit); + inline void ReferenceSerialize(dmlc::Stream* fo) const; + inline std::vector GetPyBuffer(); + inline void InitFromPyBuffer(std::vector frames); }; } // namespace treelite diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 374523fe..b81fe7e7 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -362,8 +362,9 @@ inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) { constexpr size_t kNumFramePerTree = 6; +template inline std::vector -Tree::GetPyBuffer() { +Tree::GetPyBuffer() { return { GetPyBufferFromScalar(&num_nodes), GetPyBufferFromArray(&nodes_, "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=H}"), @@ -374,8 +375,9 @@ Tree::GetPyBuffer() { }; } +template inline void -Tree::InitFromPyBuffer(std::vector frames) { +Tree::InitFromPyBuffer(std::vector frames) { size_t frame_id = 0; InitScalarFromPyBuffer(&num_nodes, frames[frame_id++]); InitArrayFromPyBuffer(&nodes_, frames[frame_id++]); @@ -391,8 +393,9 @@ Tree::InitFromPyBuffer(std::vector frames) { } } +template inline std::vector -Model::GetPyBuffer() { +ModelImpl::GetPyBuffer() { /* Header */ std::vector frames{ GetPyBufferFromScalar(&num_feature), @@ -409,8 +412,9 @@ Model::GetPyBuffer() { return frames; } +template inline void -Model::InitFromPyBuffer(std::vector frames) { +ModelImpl::InitFromPyBuffer(std::vector frames) { /* Header */ size_t frame_id = 0; InitScalarFromPyBuffer(&num_feature, frames[frame_id++]); @@ -431,11 +435,12 @@ Model::InitFromPyBuffer(std::vector frames) { } } -inline void Tree::Node::Init() { +template +inline void Tree::Node::Init() { cleft_ = cright_ = -1; sindex_ = 0; - info_.leaf_value = 0.0f; - info_.threshold = 0.0f; + info_.leaf_value = static_cast(0); + info_.threshold = static_cast(0); data_count_ = 0; sum_hess_ = gain_ = 0.0; missing_category_to_zero_ = false; @@ -445,8 +450,9 @@ inline void Tree::Node::Init() { pad_ = 0; } +template inline int -Tree::AllocNode() { +Tree::AllocNode() { int nd = num_nodes++; if (nodes_.Size() != static_cast(nd)) { throw 
std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes"); @@ -460,9 +466,10 @@ Tree::AllocNode() { return nd; } -inline Tree -Tree::Clone() const { - Tree tree; +template +inline Tree +Tree::Clone() const { + Tree tree; tree.num_nodes = num_nodes; tree.nodes_ = nodes_.Clone(); tree.leaf_vector_ = leaf_vector_.Clone(); @@ -472,8 +479,9 @@ Tree::Clone() const { return tree; } +template inline void -Tree::Init() { +Tree::Init() { num_nodes = 1; leaf_vector_.Clear(); leaf_vector_offset_.Resize(2, 0); @@ -484,16 +492,18 @@ Tree::Init() { SetLeaf(0, 0.0f); } +template inline void -Tree::AddChilds(int nid) { +Tree::AddChilds(int nid) { const int cleft = this->AllocNode(); const int cright = this->AllocNode(); nodes_[nid].cleft_ = cleft; nodes_[nid].cright_ = cright; } +template inline std::vector -Tree::GetCategoricalFeatures() const { +Tree::GetCategoricalFeatures() const { std::unordered_map tmp; for (int nid = 0; nid < num_nodes; ++nid) { const SplitFeatureType type = SplitType(nid); @@ -520,70 +530,82 @@ Tree::GetCategoricalFeatures() const { return result; } +template inline int -Tree::LeftChild(int nid) const { +Tree::LeftChild(int nid) const { return nodes_[nid].cleft_; } +template inline int -Tree::RightChild(int nid) const { +Tree::RightChild(int nid) const { return nodes_[nid].cright_; } +template inline int -Tree::DefaultChild(int nid) const { +Tree::DefaultChild(int nid) const { return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid); } +template inline uint32_t -Tree::SplitIndex(int nid) const { +Tree::SplitIndex(int nid) const { return (nodes_[nid].sindex_ & ((1U << 31U) - 1U)); } +template inline bool -Tree::DefaultLeft(int nid) const { +Tree::DefaultLeft(int nid) const { return (nodes_[nid].sindex_ >> 31U) != 0; } +template inline bool -Tree::IsLeaf(int nid) const { +Tree::IsLeaf(int nid) const { return nodes_[nid].cleft_ == -1; } -inline tl_float -Tree::LeafValue(int nid) const { +template +inline LeafOutputType +Tree::LeafValue(int nid) const { return (nodes_[nid].info_).leaf_value; } -inline std::vector -Tree::LeafVector(int nid) const { +template +inline std::vector +Tree::LeafVector(int nid) const { if (nid > leaf_vector_offset_.Size()) { throw std::runtime_error("nid too large"); } - return std::vector(&leaf_vector_[leaf_vector_offset_[nid]], - &leaf_vector_[leaf_vector_offset_[nid + 1]]); + return std::vector(&leaf_vector_[leaf_vector_offset_[nid]], + &leaf_vector_[leaf_vector_offset_[nid + 1]]); } +template inline bool -Tree::HasLeafVector(int nid) const { +Tree::HasLeafVector(int nid) const { if (nid > leaf_vector_offset_.Size()) { throw std::runtime_error("nid too large"); } return leaf_vector_offset_[nid] != leaf_vector_offset_[nid + 1]; } -inline tl_float -Tree::Threshold(int nid) const { +template +inline ThresholdType +Tree::Threshold(int nid) const { return (nodes_[nid].info_).threshold; } +template inline Operator -Tree::ComparisonOp(int nid) const { +Tree::ComparisonOp(int nid) const { return nodes_[nid].cmp_; } +template inline std::vector -Tree::LeftCategories(int nid) const { +Tree::LeftCategories(int nid) const { if (nid > left_categories_offset_.Size()) { throw std::runtime_error("nid too large"); } @@ -591,49 +613,58 @@ Tree::LeftCategories(int nid) const { &left_categories_[left_categories_offset_[nid + 1]]); } +template inline SplitFeatureType -Tree::SplitType(int nid) const { +Tree::SplitType(int nid) const { return nodes_[nid].split_type_; } +template inline bool -Tree::HasDataCount(int nid) const { +Tree::HasDataCount(int 
nid) const { return nodes_[nid].data_count_present_; } +template inline uint64_t -Tree::DataCount(int nid) const { +Tree::DataCount(int nid) const { return nodes_[nid].data_count_; } +template inline bool -Tree::HasSumHess(int nid) const { +Tree::HasSumHess(int nid) const { return nodes_[nid].sum_hess_present_; } +template inline double -Tree::SumHess(int nid) const { +Tree::SumHess(int nid) const { return nodes_[nid].sum_hess_; } +template inline bool -Tree::HasGain(int nid) const { +Tree::HasGain(int nid) const { return nodes_[nid].gain_present_; } +template inline double -Tree::Gain(int nid) const { +Tree::Gain(int nid) const { return nodes_[nid].gain_; } +template inline bool -Tree::MissingCategoryToZero(int nid) const { +Tree::MissingCategoryToZero(int nid) const { return nodes_[nid].missing_category_to_zero_; } +template inline void -Tree::SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, - bool default_left, Operator cmp) { +Tree::SetNumericalSplit( + int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) { Node& node = nodes_[nid]; if (split_index >= ((1U << 31U) - 1)) { throw std::runtime_error("split_index too big"); @@ -645,10 +676,11 @@ Tree::SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, node.split_type_ = SplitFeatureType::kNumerical; } +template inline void -Tree::SetCategoricalSplit(int nid, unsigned split_index, bool default_left, - bool missing_category_to_zero, - const std::vector& node_left_categories) { +Tree::SetCategoricalSplit( + int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, + const std::vector& node_left_categories) { if (split_index >= ((1U << 31U) - 1)) { throw std::runtime_error("split_index too big"); } @@ -678,8 +710,9 @@ Tree::SetCategoricalSplit(int nid, unsigned split_index, bool default_left, node.missing_category_to_zero_ = missing_category_to_zero; } +template inline void -Tree::SetLeaf(int nid, tl_float value) { +Tree::SetLeaf(int nid, LeafOutputType value) { Node& node = nodes_[nid]; (node.info_).leaf_value = value; node.cleft_ = -1; @@ -687,8 +720,10 @@ Tree::SetLeaf(int nid, tl_float value) { node.split_type_ = SplitFeatureType::kNone; } +template inline void -Tree::SetLeafVector(int nid, const std::vector& node_leaf_vector) { +Tree::SetLeafVector( + int nid, const std::vector& node_leaf_vector) { const size_t end_oft = leaf_vector_offset_.Back(); const size_t new_end_oft = end_oft + node_leaf_vector.size(); if (end_oft != leaf_vector_.Size()) { @@ -712,31 +747,35 @@ Tree::SetLeafVector(int nid, const std::vector& node_leaf_vector) { node.split_type_ = SplitFeatureType::kNone; } +template inline void -Tree::SetSumHess(int nid, double sum_hess) { +Tree::SetSumHess(int nid, double sum_hess) { Node& node = nodes_[nid]; node.sum_hess_ = sum_hess; node.sum_hess_present_ = true; } +template inline void -Tree::SetDataCount(int nid, uint64_t data_count) { +Tree::SetDataCount(int nid, uint64_t data_count) { Node& node = nodes_[nid]; node.data_count_ = data_count; node.data_count_present_ = true; } +template inline void -Tree::SetGain(int nid, double gain) { +Tree::SetGain(int nid, double gain) { Node& node = nodes_[nid]; node.gain_ = gain; node.gain_present_ = true; } -inline Model -Model::Clone() const { - Model model; - for (const Tree& t : trees) { +template +inline ModelImpl +ModelImpl::Clone() const { + ModelImpl model; + for (const Tree& t : trees) { model.trees.push_back(t.Clone()); } model.num_feature = num_feature; @@ -746,6 +785,141 @@ 
Model::Clone() const { return model; } +template +inline ModelImpl& +Model::GetImpl() { + return *static_cast*>(handle_.get()); +} + +template +inline const ModelImpl& +Model::GetImpl() const { + return *static_cast*>(handle_.get()); +} + +inline ModelType +Model::GetModelType() const { + return type_; +} + +template +inline Model +Model::Create() { + Model model; + model.handle_.reset(new ModelImpl()); + model.type_ = ModelType::kInvalid; + + const char* error_msg = "Unsupported combination of ThresholdType and LeafOutputType"; + static_assert(std::is_same::value + || std::is_same::value, + "ThresholdType should be either float32 or float64"); + static_assert(std::is_same::value + || std::is_same::value + || std::is_same::value, + "LeafOutputType should be uint32, float32 or float64"); + if (std::is_same::value) { + if (std::is_same::value) { + model.type_ = ModelType::kFloat32ThresholdUInt32LeafOutput; + } else if (std::is_same::value) { + model.type_ = ModelType::kFloat32ThresholdFloat32LeafOutput; + } else { + throw std::runtime_error(error_msg); + } + } else if (std::is_same::value) { + if (std::is_same::value) { + model.type_ = ModelType::kFloat64ThresholdUint32LeafOutput; + } else if (std::is_same::value) { + model.type_ = ModelType::kFloat64ThresholdFloat64LeafOutput; + } else { + throw std::runtime_error(error_msg); + } + } else { + throw std::runtime_error(error_msg); + } + return model; +} + +template +inline auto +Model::Dispatch(Func func) const { + switch(type_) { + case ModelType::kFloat32ThresholdUInt32LeafOutput: + return func(GetImpl()); + case ModelType::kFloat32ThresholdFloat32LeafOutput: + return func(GetImpl()); + case ModelType::kFloat64ThresholdUint32LeafOutput: + return func(GetImpl()); + case ModelType::kFloat64ThresholdFloat64LeafOutput: + return func(GetImpl()); + default: + throw std::runtime_error("Unknown type name"); + return func(GetImpl()); // avoid "missing return" warning + } +} + +template +inline auto +Model::Dispatch(Func func) { + switch(type_) { + case ModelType::kFloat32ThresholdUInt32LeafOutput: + return func(GetImpl()); + case ModelType::kFloat32ThresholdFloat32LeafOutput: + return func(GetImpl()); + case ModelType::kFloat64ThresholdUint32LeafOutput: + return func(GetImpl()); + case ModelType::kFloat64ThresholdFloat64LeafOutput: + return func(GetImpl()); + default: + throw std::runtime_error("Unknown type name"); + return func(GetImpl()); // avoid "missing return" warning + } +} + +inline ModelParam +Model::GetParam() const { + return Dispatch([](const auto& handle) { return handle.param; }); +} + +inline int +Model::GetNumFeature() const { + return Dispatch([](const auto& handle) { return handle.num_feature; }); +} + +inline int +Model::GetNumOutputGroup() const { + return Dispatch([](const auto& handle) { return handle.num_output_group; }); +} + +inline bool +Model::GetRandomForestFlag() const { + return Dispatch([](const auto& handle) { return handle.random_forest_flag; }); +} + +inline size_t +Model::GetNumTree() const { + return Dispatch([](const auto& handle) { return handle.trees.size(); }); +} + +inline void +Model::SetTreeLimit(size_t limit) { + Dispatch([limit](auto& handle) { handle.trees.resize(limit); }); +} + +inline void +Model::ReferenceSerialize(dmlc::Stream* fo) const { + Dispatch([fo](const auto& handle) { handle.ReferenceSerialize(fo); }); +} + +inline std::vector +Model::GetPyBuffer() { + return Dispatch([](auto& handle) { return handle.GetPyBuffer(); }); +} + +inline void +Model::InitFromPyBuffer(std::vector frames) { + 
Dispatch([&frames](auto& handle) { handle.InitFromPyBuffer(frames); }); +} + inline void InitParamAndCheck(ModelParam* param, const std::vector>& cfg) { auto unknown = param->InitAllowUnknown(cfg); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ea2e77b5..c1fb1efa 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -50,7 +50,7 @@ endforeach() set_target_properties(objtreelite objtreelite_runtime objtreelite_common PROPERTIES POSITION_INDEPENDENT_CODE ON - CXX_STANDARD 11 + CXX_STANDARD 14 CXX_STANDARD_REQUIRED ON) target_sources(objtreelite @@ -59,6 +59,7 @@ target_sources(objtreelite compiler/ast/ast.h compiler/ast/build.cc compiler/ast/builder.h + compiler/ast/builder.cc compiler/ast/dump.cc compiler/ast/fold_code.cc compiler/ast/is_categorical_array.cc @@ -80,7 +81,7 @@ target_sources(objtreelite compiler/failsafe.cc compiler/pred_transform.cc compiler/pred_transform.h - frontend/builder.cc + #frontend/builder.cc frontend/lightgbm.cc frontend/xgboost.cc annotator.cc diff --git a/src/annotator.cc b/src/annotator.cc index 36b90d22..d8a9c1b3 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -18,7 +18,8 @@ union Entry { float fvalue; }; -void Traverse_(const treelite::Tree& tree, const Entry* data, +template +void Traverse_(const treelite::Tree& tree, const Entry* data, int nid, size_t* out_counts) { ++out_counts[nid]; if (!tree.IsLeaf(nid)) { @@ -29,9 +30,9 @@ void Traverse_(const treelite::Tree& tree, const Entry* data, } else { bool result = true; if (tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical) { - const treelite::tl_float threshold = tree.Threshold(nid); + const ThresholdType threshold = tree.Threshold(nid); const treelite::Operator op = tree.ComparisonOp(nid); - const auto fvalue = static_cast(data[split_index].fvalue); + const auto fvalue = static_cast(data[split_index].fvalue); result = treelite::CompareWithOp(fvalue, op, threshold); } else { const auto fvalue = data[split_index].fvalue; @@ -49,16 +50,17 @@ void Traverse_(const treelite::Tree& tree, const Entry* data, } } -void Traverse(const treelite::Tree& tree, const Entry* data, +template +void Traverse(const treelite::Tree& tree, const Entry* data, size_t* out_counts) { Traverse_(tree, data, 0, out_counts); } -inline void ComputeBranchLoop(const treelite::Model& model, - const treelite::DMatrix* dmat, - size_t rbegin, size_t rend, int nthread, - const size_t* count_row_ptr, - size_t* counts_tloc, Entry* inst) { +template +inline void ComputeBranchLoop(const treelite::ModelImpl& model, + const treelite::DMatrix* dmat, size_t rbegin, size_t rend, + int nthread, const size_t* count_row_ptr, size_t* counts_tloc, + Entry* inst) { const size_t ntree = model.trees.size(); CHECK_LE(rbegin, rend); CHECK_LT(static_cast(rend), std::numeric_limits::max()); @@ -75,8 +77,7 @@ inline void ComputeBranchLoop(const treelite::Model& model, inst[off + dmat->col_ind[i]].fvalue = dmat->data[i]; } for (size_t tree_id = 0; tree_id < ntree; ++tree_id) { - Traverse(model.trees[tree_id], &inst[off], - &counts_tloc[off2 + count_row_ptr[tree_id]]); + Traverse(model.trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]); } for (size_t i = ibegin; i < iend; ++i) { inst[off + dmat->col_ind[i]].missing = -1; @@ -88,9 +89,10 @@ inline void ComputeBranchLoop(const treelite::Model& model, namespace treelite { +template void -BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, - int nthread, int verbose) { +BranchAnnotator::AnnotateImpl(const treelite::ModelImpl& model, + const 
treelite::DMatrix* dmat, int nthread, int verbose) { std::vector new_counts; std::vector counts_tloc; std::vector count_row_ptr; @@ -98,7 +100,7 @@ BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, const size_t ntree = model.trees.size(); const int max_thread = omp_get_max_threads(); nthread = (nthread == 0) ? max_thread : std::min(nthread, max_thread); - for (const Tree& tree : model.trees) { + for (const treelite::Tree& tree : model.trees) { count_row_ptr.push_back(count_row_ptr.back() + tree.num_nodes); } new_counts.resize(count_row_ptr[ntree], 0); @@ -106,7 +108,7 @@ BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, std::vector inst(nthread * dmat->num_col, {-1}); const size_t pstep = (dmat->num_row + 19) / 20; - // interval to display progress + // interval to display progress for (size_t rbegin = 0; rbegin < dmat->num_row; rbegin += pstep) { const size_t rend = std::min(rbegin + pstep, dmat->num_row); ComputeBranchLoop(model, dmat, rbegin, rend, nthread, @@ -126,11 +128,17 @@ BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, // change layout of counts for (size_t i = 0; i < ntree; ++i) { - this->counts.emplace_back(&new_counts[count_row_ptr[i]], - &new_counts[count_row_ptr[i + 1]]); + this->counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]); } } +void +BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) { + model.Dispatch([this, dmat, nthread, verbose](auto& handle) { + AnnotateImpl(handle, dmat, nthread, verbose); + }); +} + void BranchAnnotator::Load(dmlc::Stream* fi) { dmlc::istream is(fi); diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 1850ef2e..5aba3bc3 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -342,35 +342,36 @@ int TreeliteFreeModel(ModelHandle handle) { int TreeliteQueryNumTree(ModelHandle handle, size_t* out) { API_BEGIN(); - auto model_ = static_cast(handle); - *out = model_->trees.size(); + const auto* model_ = static_cast(handle); + *out = model_->GetNumTree(); API_END(); } int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) { API_BEGIN(); - auto model_ = static_cast(handle); - *out = static_cast(model_->num_feature); + const auto* model_ = static_cast(handle); + *out = static_cast(model_->GetNumFeature()); API_END(); } int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t* out) { API_BEGIN(); - auto model_ = static_cast(handle); - *out = static_cast(model_->num_output_group); + const auto* model_ = static_cast(handle); + *out = static_cast(model_->GetNumOutputGroup()); API_END(); } int TreeliteSetTreeLimit(ModelHandle handle, size_t limit) { API_BEGIN(); CHECK_GT(limit, 0) << "limit should be greater than 0!"; - auto model_ = static_cast(handle); - CHECK_GE(model_->trees.size(), limit) - << "Model contains less trees(" << model_->trees.size() << ") than limit"; - model_->trees.resize(limit); + auto* model_ = static_cast(handle); + const size_t num_tree = model_->GetNumTree(); + CHECK_GE(num_tree, limit) << "Model contains less trees(" << num_tree << ") than limit"; + model_->SetTreeLimit(limit); API_END(); } +#if 0 int TreeliteCreateTreeBuilder(TreeBuilderHandle* out) { API_BEGIN(); std::unique_ptr builder{new frontend::TreeBuilder()}; @@ -536,3 +537,4 @@ int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, *out = static_cast(model.release()); API_END(); } +#endif diff --git a/src/compiler/ast/ast.h b/src/compiler/ast/ast.h index 6139a631..ce14606b 100644 --- a/src/compiler/ast/ast.h +++ 
b/src/compiler/ast/ast.h @@ -60,13 +60,14 @@ class TranslationUnitNode : public ASTNode { } }; +template class QuantizerNode : public ASTNode { public: - explicit QuantizerNode(const std::vector>& cut_pts) + explicit QuantizerNode(const std::vector>& cut_pts) : cut_pts(cut_pts) {} - explicit QuantizerNode(std::vector>&& cut_pts) + explicit QuantizerNode(std::vector>&& cut_pts) : cut_pts(std::move(cut_pts)) {} - std::vector> cut_pts; + std::vector> cut_pts; std::string GetDump() const override { std::ostringstream oss; @@ -118,29 +119,31 @@ class ConditionNode : public ASTNode { } }; +template union ThresholdVariant { - tl_float float_val; + ThresholdType float_val; int int_val; - ThresholdVariant(tl_float val) : float_val(val) {} - ThresholdVariant(int val) : int_val(val) {} + explicit ThresholdVariant(ThresholdType val) : float_val(val) {} + explicit ThresholdVariant(int val) : int_val(val) {} }; +template class NumericalConditionNode : public ConditionNode { public: NumericalConditionNode(unsigned split_index, bool default_left, bool quantized, Operator op, - ThresholdVariant threshold) + ThresholdVariant threshold) : ConditionNode(split_index, default_left), quantized(quantized), op(op), threshold(threshold) {} bool quantized; Operator op; - ThresholdVariant threshold; + ThresholdVariant threshold; std::string GetDump() const override { return fmt::format("NumericalConditionNode {{ {}, quantized: {}, op: {}, threshold: {} }}", ConditionNode::GetDump(), quantized, OpName(op), - (quantized ? fmt::format("{:d}", threshold.int_val) - : fmt::format("{:f}", threshold.float_val))); + (quantized ? fmt::format("{}", threshold.int_val) + : fmt::format("{}", threshold.float_val))); } }; @@ -168,15 +171,16 @@ class CategoricalConditionNode : public ConditionNode { } }; +template class OutputNode : public ASTNode { public: - explicit OutputNode(tl_float scalar) + explicit OutputNode(LeafOutputType scalar) : is_vector(false), scalar(scalar) {} - explicit OutputNode(const std::vector& vector) + explicit OutputNode(const std::vector& vector) : is_vector(true), vector(vector) {} bool is_vector; - tl_float scalar; - std::vector vector; + LeafOutputType scalar; + std::vector vector; std::string GetDump() const override { if (is_vector) { diff --git a/src/compiler/ast/build.cc b/src/compiler/ast/build.cc index 19244d77..dc3808c7 100644 --- a/src/compiler/ast/build.cc +++ b/src/compiler/ast/build.cc @@ -11,9 +11,11 @@ namespace compiler { DMLC_REGISTRY_FILE_TAG(build); -void ASTBuilder::BuildAST(const Model& model) { - this->output_vector_flag - = (model.num_output_group > 1 && model.random_forest_flag); +template +void +ASTBuilder::BuildAST( + const ModelImpl& model) { + this->output_vector_flag = (model.num_output_group > 1 && model.random_forest_flag); this->num_feature = model.num_feature; this->num_output_group = model.num_output_group; this->random_forest_flag = model.random_forest_flag; @@ -25,35 +27,32 @@ void ASTBuilder::BuildAST(const Model& model) { ASTNode* ac = AddNode(this->main_node); this->main_node->children.push_back(ac); for (int tree_id = 0; tree_id < model.trees.size(); ++tree_id) { - ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], tree_id, ac); + ASTNode* tree_head = BuildASTFromTree(model.trees[tree_id], tree_id, 0, ac); ac->children.push_back(tree_head); } this->model_param = model.param.__DICT__(); } -ASTNode* ASTBuilder::BuildASTFromTree(const Tree& tree, int tree_id, - ASTNode* parent) { - return BuildASTFromTree(tree, tree_id, 0, parent); -} - -ASTNode* 
ASTBuilder::BuildASTFromTree(const Tree& tree, int tree_id, int nid, - ASTNode* parent) { +template +ASTNode* +ASTBuilder::BuildASTFromTree( + const Tree& tree, int tree_id, int nid, ASTNode* parent) { ASTNode* ast_node = nullptr; if (tree.IsLeaf(nid)) { if (this->output_vector_flag) { - ast_node = AddNode(parent, tree.LeafVector(nid)); + ast_node = AddNode>(parent, tree.LeafVector(nid)); } else { - ast_node = AddNode(parent, tree.LeafValue(nid)); + ast_node = AddNode>(parent, tree.LeafValue(nid)); } } else { if (tree.SplitType(nid) == SplitFeatureType::kNumerical) { - ast_node = AddNode(parent, - tree.SplitIndex(nid), - tree.DefaultLeft(nid), - false, - tree.ComparisonOp(nid), - ThresholdVariant(static_cast( - tree.Threshold(nid)))); + ast_node = AddNode>( + parent, + tree.SplitIndex(nid), + tree.DefaultLeft(nid), + false, + tree.ComparisonOp(nid), + ThresholdVariant(tree.Threshold(nid))); } else { ast_node = AddNode(parent, tree.SplitIndex(nid), @@ -79,5 +78,19 @@ ASTNode* ASTBuilder::BuildASTFromTree(const Tree& tree, int tree_id, int nid, return ast_node; } + +template void ASTBuilder::BuildAST(const ModelImpl&); +template void ASTBuilder::BuildAST(const ModelImpl&); +template void ASTBuilder::BuildAST(const ModelImpl&); +template void ASTBuilder::BuildAST(const ModelImpl&); +template ASTNode* ASTBuilder::BuildASTFromTree( + const Tree&, int, int, ASTNode*); +template ASTNode* ASTBuilder::BuildASTFromTree( + const Tree&, int, int, ASTNode*); +template ASTNode* ASTBuilder::BuildASTFromTree( + const Tree&, int, int, ASTNode*); +template ASTNode* ASTBuilder::BuildASTFromTree( + const Tree&, int, int, ASTNode*); + } // namespace compiler } // namespace treelite diff --git a/src/compiler/ast/builder.cc b/src/compiler/ast/builder.cc new file mode 100644 index 00000000..45e19374 --- /dev/null +++ b/src/compiler/ast/builder.cc @@ -0,0 +1,21 @@ +/*! 
+ * Copyright (c) 2017-2020 by Contributors + * \file builder.cc + * \brief Explicit template specializations for the ASTBuilder class + * \author Hyunsu Cho + */ + +#include "./builder.h" + +namespace treelite { +namespace compiler { + +// Explicit template specializations +// (https://docs.microsoft.com/en-us/cpp/cpp/source-code-organization-cpp-templates) +template class ASTBuilder; +template class ASTBuilder; +template class ASTBuilder; +template class ASTBuilder; + +} // namespace compiler +} // namespace treelite diff --git a/src/compiler/ast/builder.h b/src/compiler/ast/builder.h index 8046b15d..00a75385 100644 --- a/src/compiler/ast/builder.h +++ b/src/compiler/ast/builder.h @@ -18,19 +18,21 @@ namespace treelite { namespace compiler { -// forward declaration +// forward declarations +template class ASTBuilder; struct CodeFoldingContext; -bool fold_code(ASTNode*, CodeFoldingContext*, ASTBuilder*); -bool breakup(ASTNode*, int, int*, ASTBuilder*); +template +bool fold_code(ASTNode*, CodeFoldingContext*, ASTBuilder*); +template class ASTBuilder { public: ASTBuilder() : output_vector_flag(false), main_node(nullptr), quantize_threshold_flag(false) {} /* \brief initially build AST from model */ - void BuildAST(const Model& model); + void BuildAST(const ModelImpl& model); /* \brief generate is_categorical[] array, which tells whether each feature is categorical or numerical */ std::vector GenerateIsCategoricalArray(); @@ -68,8 +70,7 @@ class ASTBuilder { } private: - friend bool treelite::compiler::fold_code(ASTNode*, CodeFoldingContext*, - ASTBuilder*); + friend bool treelite::compiler::fold_code(ASTNode*, CodeFoldingContext*, ASTBuilder*); template NodeType* AddNode(ASTNode* parent, Args&& ...args) { @@ -79,8 +80,8 @@ class ASTBuilder { nodes.push_back(std::move(node)); return ref; } - ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, ASTNode* parent); - ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, int nid, + + ASTNode* BuildASTFromTree(const Tree& tree, int tree_id, int nid, ASTNode* parent); // keep tract of all nodes built so far, to prevent memory leak diff --git a/src/compiler/ast/dump.cc b/src/compiler/ast/dump.cc index 89c15dcc..00ed2496 100644 --- a/src/compiler/ast/dump.cc +++ b/src/compiler/ast/dump.cc @@ -25,11 +25,18 @@ namespace compiler { DMLC_REGISTRY_FILE_TAG(dump); -std::string ASTBuilder::GetDump() const { +template +std::string +ASTBuilder::GetDump() const { std::ostringstream oss; get_dump_from_node(&oss, this->main_node, 0); return oss.str(); } +template std::string ASTBuilder::GetDump() const; +template std::string ASTBuilder::GetDump() const; +template std::string ASTBuilder::GetDump() const; +template std::string ASTBuilder::GetDump() const; + } // namespace compiler } // namespace treelite diff --git a/src/compiler/ast/fold_code.cc b/src/compiler/ast/fold_code.cc index c70cca8e..b9e076f2 100644 --- a/src/compiler/ast/fold_code.cc +++ b/src/compiler/ast/fold_code.cc @@ -22,8 +22,9 @@ struct CodeFoldingContext { int num_tu; }; +template bool fold_code(ASTNode* node, CodeFoldingContext* context, - ASTBuilder* builder) { + ASTBuilder* builder) { if (node->node_id == 0) { if (node->data_count) { context->log_root_data_count = std::log(node->data_count.value()); @@ -48,14 +49,13 @@ bool fold_code(ASTNode* node, CodeFoldingContext* context, ASTNode* folder_node = nullptr; ASTNode* tu_node = nullptr; if (context->create_new_translation_unit) { - tu_node - = builder->AddNode(parent_node, context->num_tu++); - ASTNode* ac = builder->AddNode(tu_node); - 
folder_node = builder->AddNode(ac); + tu_node = builder->template AddNode(parent_node, context->num_tu++); + ASTNode* ac = builder->template AddNode(tu_node); + folder_node = builder->template AddNode(ac); tu_node->children.push_back(ac); ac->children.push_back(folder_node); } else { - folder_node = builder->AddNode(parent_node); + folder_node = builder->template AddNode(parent_node); } size_t node_loc = -1; // is current node 1st child or 2nd child or so forth for (size_t i = 0; i < parent_node->children.size(); ++i) { @@ -81,8 +81,10 @@ bool fold_code(ASTNode* node, CodeFoldingContext* context, int count_tu_nodes(ASTNode* node); -bool ASTBuilder::FoldCode(double magnitude_req, - bool create_new_translation_unit) { +template +bool +ASTBuilder::FoldCode( + double magnitude_req, bool create_new_translation_unit) { CodeFoldingContext context{magnitude_req, std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), @@ -91,5 +93,10 @@ bool ASTBuilder::FoldCode(double magnitude_req, return fold_code(this->main_node, &context, this); } +template bool ASTBuilder::FoldCode(double, bool); +template bool ASTBuilder::FoldCode(double, bool); +template bool ASTBuilder::FoldCode(double, bool); +template bool ASTBuilder::FoldCode(double, bool); + } // namespace compiler } // namespace treelite diff --git a/src/compiler/ast/is_categorical_array.cc b/src/compiler/ast/is_categorical_array.cc index c940cbd3..38df8488 100644 --- a/src/compiler/ast/is_categorical_array.cc +++ b/src/compiler/ast/is_categorical_array.cc @@ -24,11 +24,18 @@ scan_thresholds(ASTNode* node, std::vector* is_categorical) { } } -std::vector ASTBuilder::GenerateIsCategoricalArray() { +template +std::vector +ASTBuilder::GenerateIsCategoricalArray() { this->is_categorical = std::vector(this->num_feature, false); scan_thresholds(this->main_node, &this->is_categorical); return this->is_categorical; } +template std::vector ASTBuilder::GenerateIsCategoricalArray(); +template std::vector ASTBuilder::GenerateIsCategoricalArray(); +template std::vector ASTBuilder::GenerateIsCategoricalArray(); +template std::vector ASTBuilder::GenerateIsCategoricalArray(); + } // namespace compiler } // namespace treelite diff --git a/src/compiler/ast/load_data_counts.cc b/src/compiler/ast/load_data_counts.cc index 8aceeb05..23693de2 100644 --- a/src/compiler/ast/load_data_counts.cc +++ b/src/compiler/ast/load_data_counts.cc @@ -13,8 +13,7 @@ namespace compiler { DMLC_REGISTRY_FILE_TAG(load_data_counts); -static void load_data_counts(ASTNode* node, - const std::vector>& counts) { +static void load_data_counts(ASTNode* node, const std::vector>& counts) { if (node->tree_id >= 0 && node->node_id >= 0) { node->data_count = counts[node->tree_id][node->node_id]; } @@ -23,10 +22,17 @@ static void load_data_counts(ASTNode* node, } } +template void -ASTBuilder::LoadDataCounts(const std::vector>& counts) { +ASTBuilder::LoadDataCounts( + const std::vector>& counts) { load_data_counts(this->main_node, counts); } +template void ASTBuilder::LoadDataCounts(const std::vector>&); +template void ASTBuilder::LoadDataCounts(const std::vector>&); +template void ASTBuilder::LoadDataCounts(const std::vector>&); +template void ASTBuilder::LoadDataCounts(const std::vector>&); + } // namespace compiler } // namespace treelite diff --git a/src/compiler/ast/quantize.cc b/src/compiler/ast/quantize.cc index 53f334df..cb4c9538 100644 --- a/src/compiler/ast/quantize.cc +++ b/src/compiler/ast/quantize.cc @@ -13,14 +13,14 @@ namespace compiler { DMLC_REGISTRY_FILE_TAG(quantize); 
+template static void -scan_thresholds(ASTNode* node, - std::vector>* cut_pts) { - NumericalConditionNode* num_cond; +scan_thresholds(ASTNode* node, std::vector>* cut_pts) { + NumericalConditionNode* num_cond; CategoricalConditionNode* cat_cond; - if ( (num_cond = dynamic_cast(node)) ) { + if ( (num_cond = dynamic_cast*>(node)) ) { CHECK(!num_cond->quantized) << "should not be already quantized"; - const tl_float threshold = num_cond->threshold.float_val; + const ThresholdType threshold = num_cond->threshold.float_val; if (std::isfinite(threshold)) { (*cut_pts)[num_cond->split_index].insert(threshold); } @@ -30,13 +30,13 @@ scan_thresholds(ASTNode* node, } } +template static void -rewrite_thresholds(ASTNode* node, - const std::vector>& cut_pts) { - NumericalConditionNode* num_cond; - if ( (num_cond = dynamic_cast(node)) ) { +rewrite_thresholds(ASTNode* node, const std::vector>& cut_pts) { + NumericalConditionNode* num_cond; + if ( (num_cond = dynamic_cast*>(node)) ) { CHECK(!num_cond->quantized) << "should not be already quantized"; - const tl_float threshold = num_cond->threshold.float_val; + const ThresholdType threshold = num_cond->threshold.float_val; if (std::isfinite(threshold)) { const auto& v = cut_pts[num_cond->split_index]; auto loc = math::binary_search(v.begin(), v.end(), threshold); @@ -50,17 +50,18 @@ rewrite_thresholds(ASTNode* node, } } -void ASTBuilder::QuantizeThresholds() { +template +void +ASTBuilder::QuantizeThresholds() { this->quantize_threshold_flag = true; - std::vector> cut_pts; - std::vector> cut_pts_vec; + std::vector> cut_pts; + std::vector> cut_pts_vec; cut_pts.resize(this->num_feature); cut_pts_vec.resize(this->num_feature); scan_thresholds(this->main_node, &cut_pts); // convert cut_pts into std::vector for (int i = 0; i < this->num_feature; ++i) { - std::copy(cut_pts[i].begin(), cut_pts[i].end(), - std::back_inserter(cut_pts_vec[i])); + std::copy(cut_pts[i].begin(), cut_pts[i].end(), std::back_inserter(cut_pts_vec[i])); } /* revise all numerical splits by quantizing thresholds */ @@ -72,12 +73,17 @@ void ASTBuilder::QuantizeThresholds() { /* dynamic_cast<> is used here to check node types. This is to ensure that we don't accidentally call QuantizeThresholds() twice. */ - ASTNode* quantizer_node = AddNode(this->main_node, - std::move(cut_pts_vec)); + ASTNode* quantizer_node + = AddNode>(this->main_node, std::move(cut_pts_vec)); quantizer_node->children.push_back(top_ac_node); top_ac_node->parent = quantizer_node; this->main_node->children[0] = quantizer_node; } +template void ASTBuilder::QuantizeThresholds(); +template void ASTBuilder::QuantizeThresholds(); +template void ASTBuilder::QuantizeThresholds(); +template void ASTBuilder::QuantizeThresholds(); + } // namespace compiler } // namespace treelite diff --git a/src/compiler/ast/split.cc b/src/compiler/ast/split.cc index db85901c..9bdb557b 100644 --- a/src/compiler/ast/split.cc +++ b/src/compiler/ast/split.cc @@ -19,7 +19,9 @@ int count_tu_nodes(ASTNode* node) { return accum; } -void ASTBuilder::Split(int parallel_comp) { +template +void +ASTBuilder::Split(int parallel_comp) { if (parallel_comp <= 0) { LOG(INFO) << "Parallel compilation disabled; all member trees will be " << "dumped to a single source file. 
This may increase " @@ -35,7 +37,7 @@ void ASTBuilder::Split(int parallel_comp) { /* tree_head[i] stores reference to head of tree i */ std::vector tree_head; for (ASTNode* node : top_ac_node->children) { - CHECK(dynamic_cast(node) || dynamic_cast(node) + CHECK(dynamic_cast(node) || dynamic_cast*>(node) || dynamic_cast(node)); tree_head.push_back(node); } @@ -66,5 +68,10 @@ void ASTBuilder::Split(int parallel_comp) { top_ac_node->children = tu_list; } +template void ASTBuilder::Split(int); +template void ASTBuilder::Split(int); +template void ASTBuilder::Split(int); +template void ASTBuilder::Split(int); + } // namespace compiler } // namespace treelite diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index 4a4175e2..1ebffb2a 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -49,7 +49,8 @@ class ASTNativeCompiler : public Compiler { } } - CompiledModel Compile(const Model& model) override { + template + CompiledModel CompileImpl(const ModelImpl& model) { CompiledModel cm; cm.backend = "native"; @@ -58,10 +59,9 @@ class ASTNativeCompiler : public Compiler { pred_transform_ = model.param.pred_transform; sigmoid_alpha_ = model.param.sigmoid_alpha; global_bias_ = model.param.global_bias; - pred_tranform_func_ = PredTransformFunction("native", model); files_.clear(); - ASTBuilder builder; + ASTBuilder builder; builder.BuildAST(model); if (builder.FoldCode(param.code_folding_req) || param.quantize > 0) { @@ -92,7 +92,7 @@ class ASTNativeCompiler : public Compiler { } } - WalkAST(builder.GetRootNode(), "main.c", 0); + WalkAST(builder.GetRootNode(), "main.c", 0); if (files_.count("arrays.c") > 0) { PrependToBuffer("arrays.c", "#include \"header.h\"\n", 0); } @@ -121,6 +121,13 @@ class ASTNativeCompiler : public Compiler { return cm; } + CompiledModel Compile(const Model& model) override { + this->pred_tranform_func_ = PredTransformFunction("native", model); + return model.Dispatch([this](const auto& model_handle) { + return CompileImpl(model_handle); + }); + } + private: CompilerParam param; int num_feature_; @@ -132,30 +139,31 @@ class ASTNativeCompiler : public Compiler { std::string array_is_categorical_; std::unordered_map files_; + template void WalkAST(const ASTNode* node, const std::string& dest, size_t indent) { const MainNode* t1; const AccumulatorContextNode* t2; const ConditionNode* t3; - const OutputNode* t4; + const OutputNode* t4; const TranslationUnitNode* t5; - const QuantizerNode* t6; + const QuantizerNode* t6; const CodeFolderNode* t7; if ( (t1 = dynamic_cast(node)) ) { - HandleMainNode(t1, dest, indent); + HandleMainNode(t1, dest, indent); } else if ( (t2 = dynamic_cast(node)) ) { - HandleACNode(t2, dest, indent); + HandleACNode(t2, dest, indent); } else if ( (t3 = dynamic_cast(node)) ) { - HandleCondNode(t3, dest, indent); - } else if ( (t4 = dynamic_cast(node)) ) { - HandleOutputNode(t4, dest, indent); + HandleCondNode(t3, dest, indent); + } else if ( (t4 = dynamic_cast*>(node)) ) { + HandleOutputNode(t4, dest, indent); } else if ( (t5 = dynamic_cast(node)) ) { - HandleTUNode(t5, dest, indent); - } else if ( (t6 = dynamic_cast(node)) ) { - HandleQNode(t6, dest, indent); + HandleTUNode(t5, dest, indent); + } else if ( (t6 = dynamic_cast*>(node)) ) { + HandleQNode(t6, dest, indent); } else if ( (t7 = dynamic_cast(node)) ) { - HandleCodeFolderNode(t7, dest, indent); + HandleCodeFolderNode(t7, dest, indent); } else { LOG(FATAL) << "Unrecognized AST node type"; } @@ -176,6 +184,7 @@ class ASTNativeCompiler : public Compiler { = 
common_util::IndentMultiLineString(content, indent) + files_[dest].content; } + template void HandleMainNode(const MainNode* node, const std::string& dest, size_t indent) { @@ -240,7 +249,7 @@ class ASTNativeCompiler : public Compiler { indent); CHECK_EQ(node->children.size(), 1); - WalkAST(node->children[0], dest, indent + 2); + WalkAST(node->children[0], dest, indent + 2); const std::string optional_average_field = (node->average_result) ? fmt::format(" / {}", node->num_tree) @@ -261,6 +270,7 @@ class ASTNativeCompiler : public Compiler { } } + template void HandleACNode(const AccumulatorContextNode* node, const std::string& dest, size_t indent) { @@ -277,16 +287,17 @@ class ASTNativeCompiler : public Compiler { "int nid, cond, fid; /* used for folded subtrees */\n", indent); } for (ASTNode* child : node->children) { - WalkAST(child, dest, indent); + WalkAST(child, dest, indent); } } + template void HandleCondNode(const ConditionNode* node, const std::string& dest, size_t indent) { - const NumericalConditionNode* t; + const NumericalConditionNode* t; std::string condition, condition_with_na_check; - if ( (t = dynamic_cast(node)) ) { + if ( (t = dynamic_cast*>(node)) ) { /* numerical split */ condition = ExtractNumericalCondition(t); const char* condition_with_na_check_template @@ -298,8 +309,7 @@ class ASTNativeCompiler : public Compiler { "split_index"_a = node->split_index, "condition"_a = condition); } else { /* categorical split */ - const CategoricalConditionNode* t2 - = dynamic_cast(node); + const CategoricalConditionNode* t2 = dynamic_cast(node); CHECK(t2); condition_with_na_check = ExtractCategoricalCondition(t2); } @@ -314,19 +324,21 @@ class ASTNativeCompiler : public Compiler { AppendToBuffer(dest, fmt::format("if ({}) {{\n", condition_with_na_check), indent); CHECK_EQ(node->children.size(), 2); - WalkAST(node->children[0], dest, indent + 2); + WalkAST(node->children[0], dest, indent + 2); AppendToBuffer(dest, "} else {\n", indent); - WalkAST(node->children[1], dest, indent + 2); + WalkAST(node->children[1], dest, indent + 2); AppendToBuffer(dest, "}\n", indent); } - void HandleOutputNode(const OutputNode* node, + template + void HandleOutputNode(const OutputNode* node, const std::string& dest, size_t indent) { AppendToBuffer(dest, RenderOutputStatement(node), indent); CHECK_EQ(node->children.size(), 0); } + template void HandleTUNode(const TranslationUnitNode* node, const std::string& dest, int indent) { @@ -356,7 +368,7 @@ class ASTNativeCompiler : public Compiler { fmt::format("#include \"header.h\"\n" "{} {{\n", unit_function_signature), 0); CHECK_EQ(node->children.size(), 1); - WalkAST(node->children[0], new_file, 2); + WalkAST(node->children[0], new_file, 2); if (num_output_group_ > 1) { AppendToBuffer(new_file, fmt::format(" for (int i = 0; i < {num_output_group}; ++i) {{\n" @@ -370,7 +382,8 @@ class ASTNativeCompiler : public Compiler { AppendToBuffer("header.h", fmt::format("{};\n", unit_function_signature), 0); } - void HandleQNode(const QuantizerNode* node, + template + void HandleQNode(const QuantizerNode* node, const std::string& dest, size_t indent) { /* render arrays needed to convert feature values into bin indices */ @@ -386,7 +399,7 @@ class ASTNativeCompiler : public Compiler { for (const auto& e : node->cut_pts) { // cut_pts had been generated in ASTBuilder::QuantizeThresholds // cut_pts[i][k] stores the k-th threshold of feature i. 
- for (tl_float v : e) { + for (auto v : e) { formatter << v; } } @@ -436,9 +449,10 @@ class ASTNativeCompiler : public Compiler { "}};\n", "array_th_len"_a = array_th_len), 0); } CHECK_EQ(node->children.size(), 1); - WalkAST(node->children[0], dest, indent); + WalkAST(node->children[0], dest, indent); } + template void HandleCodeFolderNode(const CodeFolderNode* node, const std::string& dest, size_t indent) { @@ -464,9 +478,9 @@ class ASTNativeCompiler : public Compiler { std::string output_switch_statement; Operator common_comp_op; - common_util::RenderCodeFolderArrays(node, param.quantize, false, + common_util::RenderCodeFolderArrays(node, param.quantize, false, "{{ {default_left}, {split_index}, {threshold}, {left_child}, {right_child} }}", - [this](const OutputNode* node) { return RenderOutputStatement(node); }, + [this](const OutputNode* node) { return RenderOutputStatement(node); }, &array_nodes, &array_cat_bitmap, &array_cat_begin, &output_switch_statement, &common_comp_op); if (!array_nodes.empty()) { @@ -533,8 +547,9 @@ class ASTNativeCompiler : public Compiler { } } + template inline std::string - ExtractNumericalCondition(const NumericalConditionNode* node) { + ExtractNumericalCondition(const NumericalConditionNode* node) { std::string result; if (node->quantized) { // quantized threshold result = fmt::format("data[{split_index}].qvalue {opname} {threshold}", @@ -544,7 +559,8 @@ class ASTNativeCompiler : public Compiler { } else if (std::isinf(node->threshold.float_val)) { // infinite threshold // According to IEEE 754, the result of comparison [lhs] < infinity // must be identical for all finite [lhs]. Same goes for operator >. - result = (CompareWithOp(0.0, node->op, node->threshold.float_val) ? "1" : "0"); + result = (CompareWithOp(static_cast(0), node->op, node->threshold.float_val) + ? 
"1" : "0"); } else { // finite threshold result = fmt::format("data[{split_index}].fvalue {opname} (float){threshold}", "split_index"_a = node->split_index, @@ -607,7 +623,8 @@ class ASTNativeCompiler : public Compiler { return formatter.str(); } - inline std::string RenderOutputStatement(const OutputNode* node) { + template + inline std::string RenderOutputStatement(const OutputNode* node) { std::string output_statement; if (num_output_group_ > 1) { if (node->is_vector) { diff --git a/src/compiler/common/code_folding_util.h b/src/compiler/common/code_folding_util.h index 54685de7..0913ddcd 100644 --- a/src/compiler/common/code_folding_util.h +++ b/src/compiler/common/code_folding_util.h @@ -24,7 +24,8 @@ namespace treelite { namespace compiler { namespace common_util { -template + +template inline void RenderCodeFolderArrays(const CodeFolderNode* node, bool quantize, @@ -41,7 +42,7 @@ RenderCodeFolderArrays(const CodeFolderNode* node, // list of descendants, with newly assigned ID's std::unordered_map descendants; // list of all OutputNode's among the descendants - std::vector output_nodes; + std::vector*> output_nodes; // two arrays used to store categorical split info std::vector cat_bitmap; std::vector cat_begin{0}; @@ -60,13 +61,13 @@ RenderCodeFolderArrays(const CodeFolderNode* node, CHECK_EQ(e->tree_id, tree_id); // sanity check: all descendants must be ConditionNode or OutputNode ConditionNode* t1 = dynamic_cast(e); - OutputNode* t2 = dynamic_cast(e); - NumericalConditionNode* t3; + OutputNode* t2 = dynamic_cast*>(e); + NumericalConditionNode* t3; CHECK(t1 || t2); if (t2) { // e is OutputNode descendants[e] = new_leaf_id--; } else { - if ( (t3 = dynamic_cast(t1)) ) { + if ( (t3 = dynamic_cast*>(t1)) ) { ops.insert(t3->op); } descendants[e] = new_node_id++; @@ -89,22 +90,22 @@ RenderCodeFolderArrays(const CodeFolderNode* node, std::string threshold; int left_child_id, right_child_id; unsigned int split_index; - OutputNode* t1; - NumericalConditionNode* t2; + OutputNode* t1; + NumericalConditionNode* t2; CategoricalConditionNode* t3; std::queue Q; Q.push(node->children[0]); while (!Q.empty()) { ASTNode* e = Q.front(); Q.pop(); - if ( (t1 = dynamic_cast(e)) ) { + if ( (t1 = dynamic_cast*>(e)) ) { output_nodes.push_back(t1); // don't render OutputNode but save it for later } else { CHECK_EQ(e->children.size(), 2U); left_child_id = descendants[ e->children[0] ]; right_child_id = descendants[ e->children[1] ]; - if ( (t2 = dynamic_cast(e)) ) { + if ( (t2 = dynamic_cast*>(e)) ) { default_left = t2->default_left; split_index = t2->split_index; threshold @@ -165,7 +166,7 @@ RenderCodeFolderArrays(const CodeFolderNode* node, } // 4. Render switch statement to associate each node ID with an output *output_switch_statements = "switch (nid) {\n"; - for (OutputNode* e : output_nodes) { + for (OutputNode* e : output_nodes) { const int node_id = descendants[static_cast(e)]; *output_switch_statements += fmt::format(" case {node_id}:\n" diff --git a/src/compiler/failsafe.cc b/src/compiler/failsafe.cc index 2ab16bcd..c09b88e1 100644 --- a/src/compiler/failsafe.cc +++ b/src/compiler/failsafe.cc @@ -136,12 +136,13 @@ const char* arrays_template = R"TREELITETEMPLATE( // nodes[]: stores nodes from all decision trees // nodes_row_ptr[]: marks bounaries between decision trees. 
The nodes belonging to Tree [i] are // found in nodes[nodes_row_ptr[i]:nodes_row_ptr[i+1]] -inline std::pair FormatNodesArray(const treelite::Model& model) { +inline std::pair FormatNodesArray( + const treelite::ModelImpl& model_handle) { treelite::compiler::common_util::ArrayFormatter nodes(100, 2); treelite::compiler::common_util::ArrayFormatter nodes_row_ptr(100, 2); int node_count = 0; nodes_row_ptr << "0"; - for (const auto& tree : model.trees) { + for (const auto& tree : model_handle.trees) { for (int nid = 0; nid < tree.num_nodes; ++nid) { if (tree.IsLeaf(nid)) { CHECK(!tree.HasLeafVector(nid)) @@ -172,7 +173,8 @@ inline std::pair FormatNodesArray(const treelite::Mode } // Variant of FormatNodesArray(), where nodes[] array is dumped as an ELF binary -inline std::pair, std::string> FormatNodesArrayELF(const treelite::Model& model) { +inline std::pair, std::string> FormatNodesArrayELF( + const treelite::ModelImpl& model_handle) { std::vector nodes_elf; treelite::compiler::AllocateELFHeader(&nodes_elf); @@ -180,7 +182,7 @@ inline std::pair, std::string> FormatNodesArrayELF(const treel NodeStructValue val; int node_count = 0; nodes_row_ptr << "0"; - for (const auto& tree : model.trees) { + for (const auto& tree : model_handle.trees) { for (int nid = 0; nid < tree.num_nodes; ++nid) { if (tree.IsLeaf(nid)) { CHECK(!tree.HasLeafVector(nid)) @@ -208,7 +210,7 @@ inline std::pair, std::string> FormatNodesArrayELF(const treel // Get the comparison op used in the tree ensemble model // If splits have more than one op, throw an error -inline std::string GetCommonOp(const treelite::Model& model) { +inline std::string GetCommonOp(const treelite::ModelImpl& model) { std::set ops; for (const auto& tree : model.trees) { for (int nid = 0; nid < tree.num_nodes; ++nid) { @@ -263,12 +265,16 @@ class FailSafeCompiler : public Compiler { } CompiledModel Compile(const Model& model) override { + CHECK(model.GetModelType() == ModelType::kFloat32ThresholdFloat32LeafOutput) + << "Failsafe compiler only supports models with float32 thresholds and float32 leaf outputs"; + const ModelImpl& model_handle = model.GetImpl(); + CompiledModel cm; cm.backend = "native"; - num_feature_ = model.num_feature; - num_output_group_ = model.num_output_group; - CHECK(!model.random_forest_flag) + num_feature_ = model_handle.num_feature; + num_output_group_ = model_handle.num_output_group; + CHECK(!model_handle.random_forest_flag) << "Only gradient boosted trees supported in FailSafeCompiler"; pred_tranform_func_ = PredTransformFunction("native", model); files_.clear(); @@ -297,10 +303,10 @@ class FailSafeCompiler : public Compiler { ? 
fmt::format(return_multiclass_template, "num_output_group"_a = num_output_group_, "global_bias"_a - = compiler::common_util::ToStringHighPrecision(model.param.global_bias)) + = compiler::common_util::ToStringHighPrecision(model_handle.param.global_bias)) : fmt::format(return_template, "global_bias"_a - = compiler::common_util::ToStringHighPrecision(model.param.global_bias))); + = compiler::common_util::ToStringHighPrecision(model_handle.param.global_bias))); std::string nodes, nodes_row_ptr; std::vector nodes_elf; @@ -308,9 +314,9 @@ class FailSafeCompiler : public Compiler { if (param.verbose > 0) { LOG(INFO) << "Dumping arrays as an ELF relocatable object..."; } - std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model); + std::tie(nodes_elf, nodes_row_ptr) = FormatNodesArrayELF(model_handle); } else { - std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model); + std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model_handle); } main_program << fmt::format(main_template, @@ -319,8 +325,8 @@ class FailSafeCompiler : public Compiler { "predict_function_signature"_a = predict_function_signature, "num_output_group"_a = num_output_group_, "num_feature"_a = num_feature_, - "num_tree"_a = model.trees.size(), - "compare_op"_a = GetCommonOp(model), + "num_tree"_a = model_handle.trees.size(), + "compare_op"_a = GetCommonOp(model_handle), "accumulator_definition"_a = accumulator_definition, "output_statement"_a = output_statement, "return_statement"_a = return_statement); diff --git a/src/compiler/native/pred_transform.h b/src/compiler/native/pred_transform.h index 9703b044..06dc1d2e 100644 --- a/src/compiler/native/pred_transform.h +++ b/src/compiler/native/pred_transform.h @@ -27,7 +27,7 @@ R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ } inline std::string sigmoid(const Model& model) { - const float alpha = model.param.sigmoid_alpha; + const float alpha = model.GetParam().sigmoid_alpha; CHECK_GT(alpha, 0.0f) << "sigmoid: alpha must be strictly positive"; return fmt::format( R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ @@ -52,17 +52,17 @@ R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ } inline std::string identity_multiclass(const Model& model) { - CHECK(model.num_output_group > 1) + CHECK(model.GetNumOutputGroup() > 1) << "identity_multiclass: model is not a proper multi-class classifier"; return fmt::format( R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ return {num_class}; }})TREELITETEMPLATE", - "num_class"_a = model.num_output_group); + "num_class"_a = model.GetNumOutputGroup()); } inline std::string max_index(const Model& model) { - CHECK(model.num_output_group > 1) + CHECK(model.GetNumOutputGroup() > 1) << "max_index: model is not a proper multi-class classifier"; return fmt::format( R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ @@ -78,11 +78,11 @@ R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ pred[0] = (float)max_index; return 1; }})TREELITETEMPLATE", - "num_class"_a = model.num_output_group); + "num_class"_a = model.GetNumOutputGroup()); } inline std::string softmax(const Model& model) { - CHECK(model.num_output_group > 1) + CHECK(model.GetNumOutputGroup() > 1) << "softmax: model is not a proper multi-class classifier"; return fmt::format( R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ @@ -105,14 +105,14 @@ R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ }} return (size_t)num_class; 
}})TREELITETEMPLATE", - "num_class"_a = model.num_output_group); + "num_class"_a = model.GetNumOutputGroup()); } inline std::string multiclass_ova(const Model& model) { - CHECK(model.num_output_group > 1) + CHECK(model.GetNumOutputGroup() > 1) << "multiclass_ova: model is not a proper multi-class classifier"; - const int num_class = model.num_output_group; - const float alpha = model.param.sigmoid_alpha; + const int num_class = model.GetNumOutputGroup(); + const float alpha = model.GetParam().sigmoid_alpha; CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive"; return fmt::format( R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ @@ -123,7 +123,7 @@ R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ }} return (size_t)num_class; }})TREELITETEMPLATE", - "num_class"_a = model.num_output_group, "alpha"_a = alpha); + "num_class"_a = model.GetNumOutputGroup(), "alpha"_a = alpha); } } // namespace pred_transform diff --git a/src/compiler/pred_transform.cc b/src/compiler/pred_transform.cc index fafd31bd..f05ef5f4 100644 --- a/src/compiler/pred_transform.cc +++ b/src/compiler/pred_transform.cc @@ -95,8 +95,9 @@ pred_transform_multiclass_db = { std::string treelite::compiler::PredTransformFunction(const std::string& backend, const Model& model) { - if (model.num_output_group > 1) { // multi-class classification - auto it = pred_transform_multiclass_db.find(model.param.pred_transform); + ModelParam param = model.GetParam(); + if (model.GetNumOutputGroup() > 1) { // multi-class classification + auto it = pred_transform_multiclass_db.find(param.pred_transform); if (it == pred_transform_multiclass_db.end()) { std::ostringstream oss; for (const auto& e : pred_transform_multiclass_db) { @@ -109,7 +110,7 @@ treelite::compiler::PredTransformFunction(const std::string& backend, } return (it->second)(backend, model); } else { - auto it = pred_transform_db.find(model.param.pred_transform); + auto it = pred_transform_db.find(param.pred_transform); if (it == pred_transform_db.end()) { std::ostringstream oss; for (const auto& e : pred_transform_db) { diff --git a/src/compiler/pred_transform.h b/src/compiler/pred_transform.h index cc614095..69513e2c 100644 --- a/src/compiler/pred_transform.h +++ b/src/compiler/pred_transform.h @@ -14,8 +14,7 @@ namespace treelite { namespace compiler { -std::string PredTransformFunction(const std::string& backend, - const Model& model); +std::string PredTransformFunction(const std::string& backend, const Model& model); } // namespace compiler } // namespace treelite diff --git a/src/frontend/lightgbm.cc b/src/frontend/lightgbm.cc index f907ef20..99e75e7e 100644 --- a/src/frontend/lightgbm.cc +++ b/src/frontend/lightgbm.cc @@ -436,17 +436,18 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } /* 2. 
Export model */ - treelite::Model model; - model.num_feature = max_feature_idx_ + 1; - model.num_output_group = num_tree_per_iteration_; - if (model.num_output_group > 1) { + treelite::Model model_wrapper = treelite::Model::Create(); + treelite::ModelImpl& model_handle = model_wrapper.GetImpl(); + model_handle.num_feature = max_feature_idx_ + 1; + model_handle.num_output_group = num_tree_per_iteration_; + if (model_handle.num_output_group > 1) { // multiclass classification with gradient boosted trees CHECK(!average_output_) << "Ill-formed LightGBM model file: cannot use random forest mode " << "for multi-class classification"; - model.random_forest_flag = false; + model_handle.random_forest_flag = false; } else { - model.random_forest_flag = average_output_; + model_handle.random_forest_flag = average_output_; } // set correct prediction transform function, depending on objective function @@ -462,10 +463,10 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { break; } } - CHECK(num_class >= 0 && num_class == model.num_output_group) + CHECK(num_class >= 0 && num_class == model_handle.num_output_group) << "Ill-formed LightGBM model file: not a valid multiclass objective"; - std::strncpy(model.param.pred_transform, "softmax", sizeof(model.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "softmax", sizeof(model_handle.param.pred_transform)); } else if (obj_name_ == "multiclassova") { // validate num_class and alpha parameters int num_class = -1; @@ -484,12 +485,12 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } } } - CHECK(num_class >= 0 && num_class == model.num_output_group + CHECK(num_class >= 0 && num_class == model_handle.num_output_group && alpha > 0.0f) << "Ill-formed LightGBM model file: not a valid multiclassova objective"; - std::strncpy(model.param.pred_transform, "multiclass_ova", sizeof(model.param.pred_transform)); - model.param.sigmoid_alpha = alpha; + std::strncpy(model_handle.param.pred_transform, "multiclass_ova", sizeof(model_handle.param.pred_transform)); + model_handle.param.sigmoid_alpha = alpha; } else if (obj_name_ == "binary") { // validate alpha parameter float alpha = -1.0f; @@ -505,22 +506,22 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { CHECK_GT(alpha, 0.0f) << "Ill-formed LightGBM model file: not a valid binary objective"; - std::strncpy(model.param.pred_transform, "sigmoid", sizeof(model.param.pred_transform)); - model.param.sigmoid_alpha = alpha; + std::strncpy(model_handle.param.pred_transform, "sigmoid", sizeof(model_handle.param.pred_transform)); + model_handle.param.sigmoid_alpha = alpha; } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") { - std::strncpy(model.param.pred_transform, "sigmoid", sizeof(model.param.pred_transform)); - model.param.sigmoid_alpha = 1.0f; + std::strncpy(model_handle.param.pred_transform, "sigmoid", sizeof(model_handle.param.pred_transform)); + model_handle.param.sigmoid_alpha = 1.0f; } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") { - std::strncpy(model.param.pred_transform, "logarithm_one_plus_exp", - sizeof(model.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "logarithm_one_plus_exp", + sizeof(model_handle.param.pred_transform)); } else { - std::strncpy(model.param.pred_transform, "identity", sizeof(model.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "identity", sizeof(model_handle.param.pred_transform)); } // traverse trees for (const auto& lgb_tree : lgb_trees_) { - 
model.trees.emplace_back(); - treelite::Tree& tree = model.trees.back(); + model_handle.trees.emplace_back(); + treelite::Tree& tree = model_handle.trees.back(); tree.Init(); // assign node ID's so that a breadth-wise traversal would yield @@ -539,15 +540,14 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { std::tie(old_id, new_id) = Q.front(); Q.pop(); if (old_id < 0) { // leaf const double leaf_value = lgb_tree.leaf_value[~old_id]; - tree.SetLeaf(new_id, static_cast(leaf_value)); + tree.SetLeaf(new_id, static_cast(leaf_value)); if (!lgb_tree.leaf_count.empty()) { const int data_count = lgb_tree.leaf_count[~old_id]; CHECK_GE(data_count, 0); tree.SetDataCount(new_id, static_cast(data_count)); } } else { // non-leaf - const auto split_index = - static_cast(lgb_tree.split_feature[old_id]); + const auto split_index = static_cast(lgb_tree.split_feature[old_id]); tree.AddChilds(new_id); if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) { @@ -564,7 +564,7 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { left_categories); } else { // numerical - const auto threshold = static_cast(lgb_tree.threshold[old_id]); + const auto threshold = static_cast(lgb_tree.threshold[old_id]); const bool default_left = GetDecisionType(lgb_tree.decision_type[old_id], kDefaultLeftMask); const treelite::Operator cmp_op = treelite::Operator::kLE; @@ -583,8 +583,7 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } } } - LOG(INFO) << "model.num_tree = " << model.trees.size(); - return model; + return model_wrapper; } } // anonymous namespace diff --git a/src/frontend/xgboost.cc b/src/frontend/xgboost.cc index aec5679b..2b957f07 100644 --- a/src/frontend/xgboost.cc +++ b/src/frontend/xgboost.cc @@ -392,44 +392,50 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { bool need_transform_to_margin = mparam_.major_version >= 1; /* 2. 
Export model */ - treelite::Model model; - model.num_feature = mparam_.num_feature; - model.num_output_group = std::max(mparam_.num_class, 1); - model.random_forest_flag = false; + treelite::Model model_wrapper = treelite::Model::Create(); + treelite::ModelImpl& model_handle = model_wrapper.GetImpl(); + model_handle.num_feature = static_cast(mparam_.num_feature); + model_handle.num_output_group = std::max(mparam_.num_class, 1); + model_handle.random_forest_flag = false; // set global bias - model.param.global_bias = static_cast(mparam_.base_score); + model_handle.param.global_bias = static_cast(mparam_.base_score); std::vector exponential_family { "count:poisson", "reg:gamma", "reg:tweedie" }; if (need_transform_to_margin) { if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") { - model.param.global_bias = ProbToMargin::Sigmoid(model.param.global_bias); + model_handle.param.global_bias = ProbToMargin::Sigmoid(model_handle.param.global_bias); } else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_) != exponential_family.cend()) { - model.param.global_bias = ProbToMargin::Exponential(model.param.global_bias); + model_handle.param.global_bias = ProbToMargin::Exponential(model_handle.param.global_bias); } } // set correct prediction transform function, depending on objective function if (name_obj_ == "multi:softmax") { - std::strncpy(model.param.pred_transform, "max_index", sizeof(model.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "max_index", + sizeof(model_handle.param.pred_transform)); } else if (name_obj_ == "multi:softprob") { - std::strncpy(model.param.pred_transform, "softmax", sizeof(model.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "softmax", + sizeof(model_handle.param.pred_transform)); } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") { - std::strncpy(model.param.pred_transform, "sigmoid", sizeof(model.param.pred_transform)); - model.param.sigmoid_alpha = 1.0f; + std::strncpy(model_handle.param.pred_transform, "sigmoid", + sizeof(model_handle.param.pred_transform)); + model_handle.param.sigmoid_alpha = 1.0f; } else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_) != exponential_family.cend()) { - std::strncpy(model.param.pred_transform, "exponential", sizeof(model.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "exponential", + sizeof(model_handle.param.pred_transform)); } else { - std::strncpy(model.param.pred_transform, "identity", sizeof(model.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "identity", + sizeof(model_handle.param.pred_transform)); } // traverse trees for (const auto& xgb_tree : xgb_trees_) { - model.trees.emplace_back(); - treelite::Tree& tree = model.trees.back(); + model_handle.trees.emplace_back(); + treelite::Tree& tree = model_handle.trees.back(); tree.Init(); // assign node ID's so that a breadth-wise traversal would yield @@ -444,13 +450,12 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { const NodeStat stat = xgb_tree.Stat(old_id); if (node.is_leaf()) { const bst_float leaf_value = node.leaf_value(); - tree.SetLeaf(new_id, static_cast(leaf_value)); + tree.SetLeaf(new_id, static_cast(leaf_value)); } else { const bst_float split_cond = node.split_cond(); tree.AddChilds(new_id); tree.SetNumericalSplit(new_id, node.split_index(), - static_cast(split_cond), node.default_left(), - treelite::Operator::kLT); + static_cast(split_cond), 
node.default_left(), treelite::Operator::kLT); tree.SetGain(new_id, stat.loss_chg); Q.push({node.cleft(), tree.LeftChild(new_id)}); Q.push({node.cright(), tree.RightChild(new_id)}); @@ -458,7 +463,7 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { tree.SetSumHess(new_id, stat.sum_hess); } } - return model; + return model_wrapper; } } // anonymous namespace diff --git a/src/reference_serializer.cc b/src/reference_serializer.cc index 6d60fb9b..a7fc7b44 100644 --- a/src/reference_serializer.cc +++ b/src/reference_serializer.cc @@ -36,7 +36,8 @@ struct Handler> { namespace treelite { -void Tree::ReferenceSerialize(dmlc::Stream* fo) const { +template +void Tree::ReferenceSerialize(dmlc::Stream* fo) const { fo->Write(num_nodes); fo->Write(leaf_vector_); fo->Write(leaf_vector_offset_); @@ -54,14 +55,15 @@ void Tree::ReferenceSerialize(dmlc::Stream* fo) const { CHECK_EQ(left_categories_offset_.Back(), left_categories_.Size()); } -void Model::ReferenceSerialize(dmlc::Stream* fo) const { +template +void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const { fo->Write(num_feature); fo->Write(num_output_group); fo->Write(random_forest_flag); fo->Write(¶m, sizeof(param)); uint64_t sz = static_cast(trees.size()); fo->Write(sz); - for (const Tree& tree : trees) { + for (const Tree& tree : trees) { tree.ReferenceSerialize(fo); } } diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 5b0d4685..2ea94e38 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -1,7 +1,7 @@ add_executable(treelite_cpp_test) set_target_properties(treelite_cpp_test PROPERTIES - CXX_STANDARD 11 + CXX_STANDARD 14 CXX_STANDARD_REQUIRED ON) target_link_libraries(treelite_cpp_test PRIVATE objtreelite objtreelite_runtime objtreelite_common GTest::GTest) @@ -17,7 +17,7 @@ endif() target_sources(treelite_cpp_test PRIVATE test_main.cc - test_serializer.cc + #test_serializer.cc ) msvc_use_static_runtime() From edcd1f429a714384260199dfc585ed0011dde953 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sun, 30 Aug 2020 01:18:40 -0700 Subject: [PATCH 02/38] Add back size check for Node class; define format string for Node class with all possible template args --- include/treelite/tree.h | 18 ++++++++++++++---- include/treelite/tree_impl.h | 15 ++++++++++++--- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/include/treelite/tree.h b/include/treelite/tree.h index 65be0f2c..ddf3ccc3 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -134,13 +134,22 @@ class Tree { bool sum_hess_present_; /*! 
\brief whether gain_present_ field is present */ bool gain_present_; - // padding - uint16_t pad_; }; static_assert(std::is_pod::value, "Node must be a POD type"); - // TODO(hcho3): Add back size check - //static_assert(sizeof(Node) == 48, "Node must be 48 bytes"); + static_assert(std::is_same::value + || std::is_same::value, + "ThresholdType must be either float32 or float64"); + static_assert(std::is_same::value + || std::is_same::value + || std::is_same::value, + "LeafOutputType must be either uint32_t, float32 or float64"); + static_assert(!std::is_same::value + || !std::is_same::value, + "LeafOutputType cannot be float64 when ThresholdType is float32"); + static_assert((std::is_same::value && sizeof(Node) == 48) + || (std::is_same::value && sizeof(Node) == 56), + "Node incorrect size"); Tree() = default; ~Tree() = default; @@ -150,6 +159,7 @@ class Tree { Tree& operator=(Tree&&) noexcept = default; inline Tree Clone() const; + inline const char* GetFormatStringForNode(); inline std::vector GetPyBuffer(); inline void InitFromPyBuffer(std::vector frames); diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index b81fe7e7..47d87405 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -360,6 +360,16 @@ inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) { *scalar = *t; } +template +inline const char* +Tree::GetFormatStringForNode() { + if (std::is_same::value) { + return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}"; + } else { + return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}"; + } +} + constexpr size_t kNumFramePerTree = 6; template @@ -367,7 +377,7 @@ inline std::vector Tree::GetPyBuffer() { return { GetPyBufferFromScalar(&num_nodes), - GetPyBufferFromArray(&nodes_, "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=H}"), + GetPyBufferFromArray(&nodes_, GetFormatStringForNode()), GetPyBufferFromArray(&leaf_vector_), GetPyBufferFromArray(&leaf_vector_offset_), GetPyBufferFromArray(&left_categories_), @@ -405,7 +415,7 @@ ModelImpl::GetPyBuffer() { }; /* Body */ - for (auto& tree : trees) { + for (Tree& tree : trees) { auto tree_frames = tree.GetPyBuffer(); frames.insert(frames.end(), tree_frames.begin(), tree_frames.end()); } @@ -447,7 +457,6 @@ inline void Tree::Node::Init() { data_count_present_ = sum_hess_present_ = gain_present_ = false; split_type_ = SplitFeatureType::kNone; cmp_ = Operator::kNone; - pad_ = 0; } template From c0d1abd1daaee4ee2d82bb65aa9f230c5419f019 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 02:52:17 -0700 Subject: [PATCH 03/38] Update ModelBuilder interface + zero-copy serializer --- include/treelite/base.h | 53 ++++- include/treelite/frontend.h | 81 ++++++-- include/treelite/frontend_impl.h | 51 +++++ include/treelite/tree.h | 25 ++- include/treelite/tree_impl.h | 215 +++++++++++++------- src/CMakeLists.txt | 3 +- src/frontend/builder.cc | 334 ++++++++++++++++++++----------- src/reference_serializer.cc | 10 + tests/cpp/CMakeLists.txt | 2 +- tests/cpp/test_serializer.cc | 70 ++++--- 10 files changed, 594 insertions(+), 250 deletions(-) create mode 100644 include/treelite/frontend_impl.h diff --git a/include/treelite/base.h b/include/treelite/base.h index b9a134f3..cca5682c 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -29,11 +29,22 @@ enum class Operator : int8_t { kGT, /*!< operator > */ kGE, /*!< operator >= */ }; + +/*! 
\brief Types used by thresholds and leaf outputs */ +enum class TypeInfo : uint8_t { + kInvalid = 0, + kUInt32 = 1, + kFloat32 = 2, + kFloat64 = 3 +}; +static_assert(std::is_same::type, uint8_t>::value, + "TypeInfo must use uint8_t as underlying type"); + /*! \brief conversion table from string to operator, defined in optable.cc */ extern const std::unordered_map optable; /*! - * \brief get string representation of comparsion operator + * \brief get string representation of comparison operator * \param op comparison operator * \return string representation */ @@ -48,6 +59,46 @@ inline std::string OpName(Operator op) { } } +/*! + * \brief get string representation of type info + * \param info a type info + * \return string representation + */ +inline std::string TypeInfoToString(treelite::TypeInfo type) { + switch (type) { + case treelite::TypeInfo::kInvalid: + return "invalid"; + case treelite::TypeInfo::kUInt32: + return "uint32"; + case treelite::TypeInfo::kFloat32: + return "float32"; + case treelite::TypeInfo::kFloat64: + return "float64"; + default: + throw std::runtime_error("Unrecognized type"); + return ""; + } +} + +/*! + * \brief Convert a template type into a type info + * \tparam template type to be converted + * \return TypeInfo corresponding to the template type arg + */ +template +inline TypeInfo InferTypeInfoOf() { + if (std::is_same::value) { + return TypeInfo::kUInt32; + } else if (std::is_same::value) { + return TypeInfo::kFloat32; + } else if (std::is_same::value) { + return TypeInfo::kFloat64; + } else { + throw std::runtime_error(std::string("Unrecognized Value type") + typeid(T).name()); + return TypeInfo::kInvalid; + } +} + /*! * \brief perform comparison between two float's using a comparsion operator * The comparison will be in the form [lhs] [op] [rhs]. diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index 67272ce5..9f7a815e 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -46,14 +46,47 @@ void LoadXGBoostModel(const void* buf, size_t len, Model* out); //-------------------------------------------------------------------------- // model builder interface: build trees incrementally //-------------------------------------------------------------------------- -struct TreeBuilderImpl; // forward declaration -struct ModelBuilderImpl; // ditto -class ModelBuilder; // ditto + +/* forward declarations */ +struct TreeBuilderImpl; +struct ModelBuilderImpl; +class ModelBuilder; + +class Value { + private: + std::shared_ptr handle_; + TypeInfo type_; + public: + Value(); + ~Value() = default; + Value(const Value&) = default; + Value(Value&&) noexcept = default; + Value& operator=(const Value&) = default; + Value& operator=(Value&&) noexcept = default; + template + static Value Create(T init_value); + template + T& Get(); + template + const T& Get() const; + template + inline auto Dispatch(Func func); + template + inline auto Dispatch(Func func) const; + TypeInfo GetValueType() const; +}; /*! \brief tree builder class */ class TreeBuilder { public: - TreeBuilder(); // constructor + /*! + * \brief Constructor + * \param threshold_type Type of thresholds in numerical splits. All thresholds in a given model + * must have the same type. + * \param leaf_output_type Type of leaf outputs. All leaf outputs in a given model must have the + * same type. 
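+ *
+ * Example (illustrative, not part of the original patch):
+ *   TreeBuilder builder(TypeInfo::kFloat32, TypeInfo::kFloat32);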
+ */ + TreeBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type); // constructor ~TreeBuilder(); // destructor // this class is only move-constructible and move-assignable TreeBuilder(const TreeBuilder&) = delete; @@ -88,12 +121,10 @@ class TreeBuilder { * \param left_child_key unique integer key to identify the left child node * \param right_child_key unique integer key to identify the right child node */ - void SetNumericalTestNode(int node_key, unsigned feature_id, - const char* op, tl_float threshold, bool default_left, - int left_child_key, int right_child_key); - void SetNumericalTestNode(int node_key, unsigned feature_id, - Operator op, tl_float threshold, bool default_left, - int left_child_key, int right_child_key); + void SetNumericalTestNode(int node_key, unsigned feature_id, const char* op, Value threshold, + bool default_left, int left_child_key, int right_child_key); + void SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, Value threshold, + bool default_left, int left_child_key, int right_child_key); /*! * \brief Turn an empty node into a categorical test node. * A list defines all categories that would be classified as the left side. @@ -107,18 +138,16 @@ class TreeBuilder { * \param left_child_key unique integer key to identify the left child node * \param right_child_key unique integer key to identify the right child node */ - void SetCategoricalTestNode(int node_key, - unsigned feature_id, - const std::vector& left_categories, - bool default_left, int left_child_key, - int right_child_key); + void SetCategoricalTestNode(int node_key, unsigned feature_id, + const std::vector& left_categories, bool default_left, + int left_child_key, int right_child_key); /*! * \brief Turn an empty node into a leaf node * \param node_key unique integer key to identify the node being modified; * this node needs to be empty * \param leaf_value leaf value (weight) of the leaf node */ - void SetLeafNode(int node_key, tl_float leaf_value); + void SetLeafNode(int node_key, Value leaf_value); /*! * \brief Turn an empty node into a leaf vector node * The leaf vector (collection of multiple leaf weights per leaf node) is @@ -127,13 +156,13 @@ class TreeBuilder { * this node needs to be empty * \param leaf_vector leaf vector of the leaf node */ - void SetLeafVectorNode(int node_key, - const std::vector& leaf_vector); + void SetLeafVectorNode(int node_key, const std::vector& leaf_vector); private: - std::unique_ptr pimpl; // Pimpl pattern - void* ensemble_id; // id of ensemble (nullptr if not part of any) + std::unique_ptr pimpl_; // Pimpl pattern + ModelBuilder* ensemble_id_; // id of ensemble (nullptr if not part of any) friend class ModelBuilder; + friend struct ModelBuilderImpl; }; /*! \brief model builder class */ @@ -149,8 +178,13 @@ class ModelBuilder { * classification * \param random_forest_flag whether the model is a random forest (true) or * gradient boosted trees (false) + * \param threshold_type Type of thresholds in numerical splits. All thresholds in a given model + * must have the same type. + * \param leaf_output_type Type of leaf outputs. All leaf outputs in a given model must have the + * same type. */ - ModelBuilder(int num_feature, int num_output_group, bool random_forest_flag); + ModelBuilder(int num_feature, int num_output_group, bool random_forest_flag, + TypeInfo threshold_type, TypeInfo leaf_output_type); ~ModelBuilder(); // destructor /*! 
* \brief Set a model parameter @@ -190,9 +224,12 @@ class ModelBuilder { void CommitModel(Model* out_model); private: - std::unique_ptr pimpl; // Pimpl pattern + std::unique_ptr pimpl_; // Pimpl pattern }; } // namespace frontend } // namespace treelite + +#include "frontend_impl.h" + #endif // TREELITE_FRONTEND_H_ diff --git a/include/treelite/frontend_impl.h b/include/treelite/frontend_impl.h new file mode 100644 index 00000000..4d5fb1c8 --- /dev/null +++ b/include/treelite/frontend_impl.h @@ -0,0 +1,51 @@ +/*! + * Copyright (c) 2020 by Contributors + * \file frontend_impl.h + * \brief Implementation for frontend.h + * \author Hyunsu Cho + */ + +#ifndef TREELITE_FRONTEND_IMPL_H_ +#define TREELITE_FRONTEND_IMPL_H_ + +namespace treelite { +namespace frontend { + +template +inline auto +Value::Dispatch(Func func) { + switch(type_) { + case TypeInfo::kUInt32: + return func(Get()); + case TypeInfo::kFloat32: + return func(Get()); + case TypeInfo::kFloat64: + return func(Get()); + default: + throw std::runtime_error(std::string("Unknown value type detected: ") + + std::to_string(static_cast(type_))); + return func(Get()); // avoid "missing return" warning + } +} + +template +inline auto +Value::Dispatch(Func func) const { + switch(type_) { + case TypeInfo::kUInt32: + return func(Get()); + case TypeInfo::kFloat32: + return func(Get()); + case TypeInfo::kFloat64: + return func(Get()); + default: + throw std::runtime_error(std::string("Unknown value type detected: ") + + std::to_string(static_cast(type_))); + return func(Get()); // avoid "missing return" warning + } +} + +} // namespace frontend +} // namespace treelite + +#endif // TREELITE_FRONTEND_IMPL_H_ diff --git a/include/treelite/tree.h b/include/treelite/tree.h index ddf3ccc3..b1087f81 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -160,8 +160,9 @@ class Tree { inline Tree Clone() const; inline const char* GetFormatStringForNode(); - inline std::vector GetPyBuffer(); - inline void InitFromPyBuffer(std::vector frames); + inline void GetPyBuffer(std::vector* dest); + inline void InitFromPyBuffer(std::vector::iterator begin, + std::vector::iterator end); private: // vector of nodes @@ -332,7 +333,7 @@ class Tree { * \param nid ID of node being updated * \param leaf_vector leaf vector */ - inline void SetLeafVector(int nid, const std::vector& leaf_vector); + inline void SetLeafVector(int nid, const std::vector& leaf_vector); /*! 
* \brief set the hessian sum of the node * \param nid ID of node being updated @@ -446,8 +447,9 @@ struct ModelImpl { void ReferenceSerialize(dmlc::Stream* fo) const; - inline std::vector GetPyBuffer(); - inline void InitFromPyBuffer(std::vector frames); + inline void GetPyBuffer(std::vector* dest); + inline void InitFromPyBuffer(std::vector::iterator begin, + std::vector::iterator end); inline ModelImpl Clone() const; }; @@ -456,7 +458,7 @@ enum class ModelType : uint16_t { kInvalid = 0, kFloat32ThresholdUInt32LeafOutput = 1, kFloat32ThresholdFloat32LeafOutput = 2, - kFloat64ThresholdUint32LeafOutput = 3, + kFloat64ThresholdUInt32LeafOutput = 3, kFloat64ThresholdFloat64LeafOutput = 4 }; @@ -464,18 +466,23 @@ struct Model { private: std::shared_ptr handle_; ModelType type_; + TypeInfo threshold_type_; + TypeInfo leaf_output_type_; public: + template + inline static ModelType InferModelTypeOf(); template inline static Model Create(); + inline static Model Create(TypeInfo threshold_type, TypeInfo leaf_output_type); template inline ModelImpl& GetImpl(); template inline const ModelImpl& GetImpl() const; inline ModelType GetModelType() const; template - inline auto Dispatch(Func func) const; - template inline auto Dispatch(Func func); + template + inline auto Dispatch(Func func) const; inline ModelParam GetParam() const; inline int GetNumFeature() const; inline int GetNumOutputGroup() const; @@ -484,7 +491,7 @@ struct Model { inline void SetTreeLimit(size_t limit); inline void ReferenceSerialize(dmlc::Stream* fo) const; inline std::vector GetPyBuffer(); - inline void InitFromPyBuffer(std::vector frames); + inline static Model CreateFromPyBuffer(std::vector frames); }; } // namespace treelite diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 47d87405..7c7131fd 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -332,6 +332,11 @@ inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) { return GetPyBufferFromScalar(static_cast(scalar), format, sizeof(T)); } +inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) { + using T = std::underlying_type::type; + return GetPyBufferFromScalar(reinterpret_cast(scalar), InferFormatString()); +} + template inline PyBufferFrame GetPyBufferFromScalar(T* scalar) { static_assert(std::is_arithmetic::value, @@ -348,6 +353,18 @@ inline void InitArrayFromPyBuffer(ContiguousArray* vec, PyBufferFrame buffer) vec->UseForeignBuffer(buffer.buf, buffer.nitem); } +inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) { + using T = std::underlying_type::type; + if (sizeof(T) != buffer.itemsize) { + throw std::runtime_error("Incorrect itemsize"); + } + if (buffer.nitem != 1) { + throw std::runtime_error("nitem must be 1 for a scalar"); + } + T* t = static_cast(buffer.buf); + *scalar = static_cast(*t); +} + template inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) { if (sizeof(T) != buffer.itemsize) { @@ -373,75 +390,72 @@ Tree::GetFormatStringForNode() { constexpr size_t kNumFramePerTree = 6; template -inline std::vector -Tree::GetPyBuffer() { - return { - GetPyBufferFromScalar(&num_nodes), - GetPyBufferFromArray(&nodes_, GetFormatStringForNode()), - GetPyBufferFromArray(&leaf_vector_), - GetPyBufferFromArray(&leaf_vector_offset_), - GetPyBufferFromArray(&left_categories_), - GetPyBufferFromArray(&left_categories_offset_) - }; +inline void +Tree::GetPyBuffer(std::vector* dest) { + dest->push_back(GetPyBufferFromScalar(&num_nodes)); + 
dest->push_back(GetPyBufferFromArray(&nodes_, GetFormatStringForNode())); + dest->push_back(GetPyBufferFromArray(&leaf_vector_)); + dest->push_back(GetPyBufferFromArray(&leaf_vector_offset_)); + dest->push_back(GetPyBufferFromArray(&left_categories_)); + dest->push_back(GetPyBufferFromArray(&left_categories_offset_)); } template inline void -Tree::InitFromPyBuffer(std::vector frames) { - size_t frame_id = 0; - InitScalarFromPyBuffer(&num_nodes, frames[frame_id++]); - InitArrayFromPyBuffer(&nodes_, frames[frame_id++]); +Tree::InitFromPyBuffer( + std::vector::iterator begin, std::vector::iterator end) { + if (std::distance(begin, end) != kNumFramePerTree) { + throw std::runtime_error("Wrong number of frames specified"); + } + InitScalarFromPyBuffer(&num_nodes, *begin++); + InitArrayFromPyBuffer(&nodes_, *begin++); if (num_nodes != nodes_.Size()) { throw std::runtime_error("Could not load the correct number of nodes"); } - InitArrayFromPyBuffer(&leaf_vector_, frames[frame_id++]); - InitArrayFromPyBuffer(&leaf_vector_offset_, frames[frame_id++]); - InitArrayFromPyBuffer(&left_categories_, frames[frame_id++]); - InitArrayFromPyBuffer(&left_categories_offset_, frames[frame_id++]); - if (frame_id != kNumFramePerTree) { - throw std::runtime_error("Wrong number of frames loaded"); - } + InitArrayFromPyBuffer(&leaf_vector_, *begin++); + InitArrayFromPyBuffer(&leaf_vector_offset_, *begin++); + InitArrayFromPyBuffer(&left_categories_, *begin++); + InitArrayFromPyBuffer(&left_categories_offset_, *begin++); } template -inline std::vector -ModelImpl::GetPyBuffer() { +inline void +ModelImpl::GetPyBuffer(std::vector* dest) { /* Header */ - std::vector frames{ - GetPyBufferFromScalar(&num_feature), - GetPyBufferFromScalar(&num_output_group), - GetPyBufferFromScalar(&random_forest_flag), - GetPyBufferFromScalar(¶m, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}") - }; + dest->push_back(GetPyBufferFromScalar(&num_feature)); + dest->push_back(GetPyBufferFromScalar(&num_output_group)); + dest->push_back(GetPyBufferFromScalar(&random_forest_flag)); + dest->push_back(GetPyBufferFromScalar( + ¶m, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}")); /* Body */ for (Tree& tree : trees) { - auto tree_frames = tree.GetPyBuffer(); - frames.insert(frames.end(), tree_frames.begin(), tree_frames.end()); + tree.GetPyBuffer(dest); } - return frames; } template inline void -ModelImpl::InitFromPyBuffer(std::vector frames) { +ModelImpl::InitFromPyBuffer( + std::vector::iterator begin, std::vector::iterator end) { + const size_t num_frame = std::distance(begin, end); /* Header */ - size_t frame_id = 0; - InitScalarFromPyBuffer(&num_feature, frames[frame_id++]); - InitScalarFromPyBuffer(&num_output_group, frames[frame_id++]); - InitScalarFromPyBuffer(&random_forest_flag, frames[frame_id++]); - InitScalarFromPyBuffer(¶m, frames[frame_id++]); + constexpr size_t kNumFrameInHeader = 4; + if (num_frame < kNumFrameInHeader) { + throw std::runtime_error("Wrong number of frames"); + } + InitScalarFromPyBuffer(&num_feature, *begin++); + InitScalarFromPyBuffer(&num_output_group, *begin++); + InitScalarFromPyBuffer(&random_forest_flag, *begin++); + InitScalarFromPyBuffer(¶m, *begin++); /* Body */ - const size_t num_frame = frames.size(); - if ((num_frame - frame_id) % kNumFramePerTree != 0) { + if ((num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) { throw std::runtime_error("Wrong number of frames"); } trees.clear(); - for (; frame_id < num_frame; frame_id += kNumFramePerTree) { - std::vector 
tree_frames(frames.begin() + frame_id, - frames.begin() + frame_id + kNumFramePerTree); + for (; begin < end; begin += kNumFramePerTree) { trees.emplace_back(); - trees.back().InitFromPyBuffer(tree_frames); + trees.back().InitFromPyBuffer(begin, begin + kNumFramePerTree); } } @@ -732,7 +746,7 @@ Tree::SetLeaf(int nid, LeafOutputType value) { template inline void Tree::SetLeafVector( - int nid, const std::vector& node_leaf_vector) { + int nid, const std::vector& node_leaf_vector) { const size_t end_oft = leaf_vector_offset_.Back(); const size_t new_end_oft = end_oft + node_leaf_vector.size(); if (end_oft != leaf_vector_.Size()) { @@ -812,12 +826,8 @@ Model::GetModelType() const { } template -inline Model -Model::Create() { - Model model; - model.handle_.reset(new ModelImpl()); - model.type_ = ModelType::kInvalid; - +inline ModelType +Model::InferModelTypeOf() { const char* error_msg = "Unsupported combination of ThresholdType and LeafOutputType"; static_assert(std::is_same::value || std::is_same::value, @@ -828,58 +838,113 @@ Model::Create() { "LeafOutputType should be uint32, float32 or float64"); if (std::is_same::value) { if (std::is_same::value) { - model.type_ = ModelType::kFloat32ThresholdUInt32LeafOutput; + return ModelType::kFloat32ThresholdUInt32LeafOutput; } else if (std::is_same::value) { - model.type_ = ModelType::kFloat32ThresholdFloat32LeafOutput; + return ModelType::kFloat32ThresholdFloat32LeafOutput; } else { throw std::runtime_error(error_msg); } } else if (std::is_same::value) { if (std::is_same::value) { - model.type_ = ModelType::kFloat64ThresholdUint32LeafOutput; + return ModelType::kFloat64ThresholdUInt32LeafOutput; } else if (std::is_same::value) { - model.type_ = ModelType::kFloat64ThresholdFloat64LeafOutput; + return ModelType::kFloat64ThresholdFloat64LeafOutput; } else { throw std::runtime_error(error_msg); } - } else { - throw std::runtime_error(error_msg); } + throw std::runtime_error(error_msg); + return ModelType::kInvalid; +} + +template +inline Model +Model::Create() { + Model model; + model.handle_.reset(new ModelImpl()); + model.type_ = InferModelTypeOf(); + model.threshold_type_ = InferTypeInfoOf(); + model.leaf_output_type_ = InferTypeInfoOf(); return model; } +inline Model +Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { + auto error_threshold_type = [threshold_type]() { + std::ostringstream oss; + oss << "Invalid threshold type: " << treelite::TypeInfoToString(threshold_type); + return oss.str(); + }; + auto error_leaf_output_type = [threshold_type, leaf_output_type]() { + std::ostringstream oss; + oss << "Cannot use leaf output type " << treelite::TypeInfoToString(leaf_output_type) + << "with threshold type " << treelite::TypeInfoToString(threshold_type); + return oss.str(); + }; + switch (threshold_type) { + case treelite::TypeInfo::kFloat32: + switch (leaf_output_type) { + case treelite::TypeInfo::kUInt32: + return treelite::Model::Create(); + case treelite::TypeInfo::kFloat32: + return treelite::Model::Create(); + default: + throw std::runtime_error(error_leaf_output_type()); + break; + } + break; + case treelite::TypeInfo::kFloat64: + switch (leaf_output_type) { + case treelite::TypeInfo::kUInt32: + return treelite::Model::Create(); + case treelite::TypeInfo::kFloat64: + return treelite::Model::Create(); + default: + throw std::runtime_error(error_leaf_output_type()); + break; + } + break; + default: + throw std::runtime_error(error_threshold_type()); + break; + } + return treelite::Model(); // avoid missing return value warning 
+} + template inline auto -Model::Dispatch(Func func) const { +Model::Dispatch(Func func) { switch(type_) { case ModelType::kFloat32ThresholdUInt32LeafOutput: return func(GetImpl()); case ModelType::kFloat32ThresholdFloat32LeafOutput: return func(GetImpl()); - case ModelType::kFloat64ThresholdUint32LeafOutput: + case ModelType::kFloat64ThresholdUInt32LeafOutput: return func(GetImpl()); case ModelType::kFloat64ThresholdFloat64LeafOutput: return func(GetImpl()); default: - throw std::runtime_error("Unknown type name"); + throw std::runtime_error(std::string("Unknown type detected: ") + + std::to_string(static_cast(type_))); return func(GetImpl()); // avoid "missing return" warning } } template inline auto -Model::Dispatch(Func func) { +Model::Dispatch(Func func) const { switch(type_) { case ModelType::kFloat32ThresholdUInt32LeafOutput: return func(GetImpl()); case ModelType::kFloat32ThresholdFloat32LeafOutput: return func(GetImpl()); - case ModelType::kFloat64ThresholdUint32LeafOutput: + case ModelType::kFloat64ThresholdUInt32LeafOutput: return func(GetImpl()); case ModelType::kFloat64ThresholdFloat64LeafOutput: return func(GetImpl()); default: - throw std::runtime_error("Unknown type name"); + throw std::runtime_error(std::string("Unknown type detected: ") + + std::to_string(static_cast(type_))); return func(GetImpl()); // avoid "missing return" warning } } @@ -921,12 +986,28 @@ Model::ReferenceSerialize(dmlc::Stream* fo) const { inline std::vector Model::GetPyBuffer() { - return Dispatch([](auto& handle) { return handle.GetPyBuffer(); }); + std::vector buffer; + buffer.push_back(GetPyBufferFromScalar(&threshold_type_)); + buffer.push_back(GetPyBufferFromScalar(&leaf_output_type_)); + Dispatch([&buffer](auto& handle) { return handle.GetPyBuffer(&buffer); }); + return buffer; } -inline void -Model::InitFromPyBuffer(std::vector frames) { - Dispatch([&frames](auto& handle) { handle.InitFromPyBuffer(frames); }); +inline Model +Model::CreateFromPyBuffer(std::vector frames) { + using TypeInfoInt = std::underlying_type::type; + TypeInfo threshold_type, leaf_output_type; + if (frames.size() < 2) { + throw std::runtime_error("Insufficient number of frames: there must be at least two"); + } + InitScalarFromPyBuffer(&threshold_type, frames[0]); + InitScalarFromPyBuffer(&leaf_output_type, frames[1]); + + Model model = Model::Create(threshold_type, leaf_output_type); + model.Dispatch([&frames](auto& handle) { + handle.InitFromPyBuffer(frames.begin() + 2, frames.end()); + }); + return model; } inline void InitParamAndCheck(ModelParam* param, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c1fb1efa..df239ddd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -81,7 +81,7 @@ target_sources(objtreelite compiler/failsafe.cc compiler/pred_transform.cc compiler/pred_transform.h - #frontend/builder.cc + frontend/builder.cc frontend/lightgbm.cc frontend/xgboost.cc annotator.cc @@ -97,6 +97,7 @@ target_sources(objtreelite ${PROJECT_SOURCE_DIR}/include/treelite/data.h ${PROJECT_SOURCE_DIR}/include/treelite/filesystem.h ${PROJECT_SOURCE_DIR}/include/treelite/frontend.h + ${PROJECT_SOURCE_DIR}/include/treelite/frontend_impl.h ${PROJECT_SOURCE_DIR}/include/treelite/omp.h ${PROJECT_SOURCE_DIR}/include/treelite/tree.h ${PROJECT_SOURCE_DIR}/include/treelite/tree_impl.h diff --git a/src/frontend/builder.cc b/src/frontend/builder.cc index 5b7ff227..0c365d3d 100644 --- a/src/frontend/builder.cc +++ b/src/frontend/builder.cc @@ -11,33 +11,30 @@ #include #include -/* data structures with underscore 
prefixes are internal use only and - do not have external linkage */ +/* data structures with underscore prefixes are internal use only and don't have external linkage */ namespace { -struct _Node { - enum class _Status : int8_t { +struct NodeDraft { + enum class Status : int8_t { kEmpty, kNumericalTest, kCategoricalTest, kLeaf }; - union _Info { - treelite::tl_float leaf_value; // for leaf nodes - treelite::tl_float threshold; // for non-leaf nodes - }; /* * leaf vector: only used for random forests with multi-class classification */ - std::vector leaf_vector; - _Status status; + std::vector leaf_vector; + Status status; /* pointers to parent, left and right children */ - _Node* parent; - _Node* left_child; - _Node* right_child; + NodeDraft* parent; + NodeDraft* left_child; + NodeDraft* right_child; // split feature index unsigned feature_id; // default direction for missing values bool default_left; - // extra info: leaf value or threshold - _Info info; + // leaf value (only for leaf nodes) + treelite::frontend::Value leaf_value; + // threshold (only for non-leaf nodes) + treelite::frontend::Value threshold; // (for numerical split) // operator to use for expression of form [fval] OP [threshold] // If the expression evaluates to true, take the left child; @@ -50,15 +47,17 @@ struct _Node { // categories in that particular feature. Let's assume n <= 64. std::vector left_categories; - inline _Node() - : status(_Status::kEmpty), - parent(nullptr), left_child(nullptr), right_child(nullptr) {} + inline NodeDraft() + : status(Status::kEmpty), parent(nullptr), left_child(nullptr), right_child(nullptr) {} }; -struct _Tree { - _Node* root; - std::unordered_map> nodes; - inline _Tree() : root(nullptr), nodes() {} +struct TreeDraft { + NodeDraft* root; + std::unordered_map> nodes; + treelite::TypeInfo threshold_type; + treelite::TypeInfo leaf_output_type; + inline TreeDraft(treelite::TypeInfo threshold_type, treelite::TypeInfo leaf_output_type) + : root(nullptr), nodes(), threshold_type(threshold_type), leaf_output_type(leaf_output_type) {} }; } // anonymous namespace @@ -69,8 +68,9 @@ namespace frontend { DMLC_REGISTRY_FILE_TAG(builder); struct TreeBuilderImpl { - _Tree tree; - inline TreeBuilderImpl() : tree() {} + TreeDraft tree; + inline TreeBuilderImpl(TypeInfo threshold_type, TypeInfo leaf_output_type) + : tree(threshold_type, leaf_output_type) {} }; struct ModelBuilderImpl { @@ -78,35 +78,89 @@ struct ModelBuilderImpl { int num_feature; int num_output_group; bool random_forest_flag; + TypeInfo threshold_type; + TypeInfo leaf_output_type; std::vector> cfg; - inline ModelBuilderImpl(int num_feature, int num_output_group, - bool random_forest_flag) - : trees(), num_feature(num_feature), - num_output_group(num_output_group), - random_forest_flag(random_forest_flag), cfg() { + inline ModelBuilderImpl(int num_feature, int num_output_group, bool random_forest_flag, + TypeInfo threshold_type, TypeInfo leaf_output_type) + : trees(), num_feature(num_feature), num_output_group(num_output_group), + random_forest_flag(random_forest_flag), threshold_type(threshold_type), + leaf_output_type(leaf_output_type), cfg() { CHECK_GT(num_feature, 0) << "ModelBuilder: num_feature must be positive"; - CHECK_GT(num_output_group, 0) - << "ModelBuilder: num_output_group must be positive"; + CHECK_GT(num_output_group, 0) << "ModelBuilder: num_output_group must be positive"; + CHECK(threshold_type != TypeInfo::kInvalid) + << "ModelBuilder: threshold_type can't be invalid"; + CHECK(leaf_output_type != TypeInfo::kInvalid) 
+ << "ModelBuilder: leaf_output_type can't be invalid"; } + // Templatized implementation of CommitModel() + template + void CommitModelImpl(ModelImpl* out_model); }; -TreeBuilder::TreeBuilder() - : pimpl(new TreeBuilderImpl()), ensemble_id(nullptr) {} -TreeBuilder::~TreeBuilder() {} +template +void SetLeafVector(Tree* tree, int nid, + const std::vector& leaf_vector) { + const size_t leaf_vector_size = leaf_vector.size(); + const TypeInfo expected_leaf_type = InferTypeInfoOf(); + std::vector out_leaf_vector; + for (size_t i = 0; i < leaf_vector_size; ++i) { + const Value& leaf_value = leaf_vector[i]; + CHECK(leaf_value.GetValueType() == expected_leaf_type) + << "Leaf value at index " << i << " has incorrect type. Expected: " + << TypeInfoToString(expected_leaf_type) << ", Given: " + << TypeInfoToString(leaf_value.GetValueType()); + out_leaf_vector.push_back(leaf_value.Get()); + } + tree->SetLeafVector(nid, out_leaf_vector); +} + +Value::Value() : handle_(nullptr), type_(TypeInfo::kInvalid) {} + +template +Value +Value::Create(T init_value) { + Value val; + std::unique_ptr ptr = std::make_unique(init_value); + val.handle_.reset(ptr.release()); + val.type_ = InferTypeInfoOf(); + return val; +} + +template +T& +Value::Get() { + return *static_cast(handle_.get()); +} + +template +const T& +Value::Get() const { + return *static_cast(handle_.get()); +} + +TypeInfo +Value::GetValueType() const { + return type_; +} + +TreeBuilder::TreeBuilder(TypeInfo threshold_type, TypeInfo leaf_output_type) + : pimpl_(new TreeBuilderImpl(threshold_type, leaf_output_type)), ensemble_id_(nullptr) {} +TreeBuilder::~TreeBuilder() = default; void TreeBuilder::CreateNode(int node_key) { - auto& nodes = pimpl->tree.nodes; + auto& nodes = pimpl_->tree.nodes; CHECK_EQ(nodes.count(node_key), 0) << "CreateNode: nodes with duplicate keys are not allowed"; - nodes[node_key].reset(new _Node()); + nodes[node_key] = std::make_unique(); } void TreeBuilder::DeleteNode(int node_key) { - auto& tree = pimpl->tree; + auto& tree = pimpl_->tree; auto& nodes = tree.nodes; CHECK_GT(nodes.count(node_key), 0) << "DeleteNode: no node found with node_key"; - _Node* node = nodes[node_key].get(); + NodeDraft* node = nodes[node_key].get(); if (tree.root == node) { // deleting root tree.root = nullptr; } @@ -126,43 +180,42 @@ TreeBuilder::DeleteNode(int node_key) { void TreeBuilder::SetRootNode(int node_key) { - auto& tree = pimpl->tree; + auto& tree = pimpl_->tree; auto& nodes = tree.nodes; CHECK_GT(nodes.count(node_key), 0) << "SetRootNode: no node found with node_key"; - _Node* node = nodes[node_key].get(); + NodeDraft* node = nodes[node_key].get(); CHECK(!node->parent) << "SetRootNode: a root node cannot have a parent"; tree.root = node; } void -TreeBuilder::SetNumericalTestNode(int node_key, - unsigned feature_id, - const char* opname, tl_float threshold, - bool default_left, int left_child_key, +TreeBuilder::SetNumericalTestNode(int node_key, unsigned feature_id, const char* opname, + Value threshold, bool default_left, int left_child_key, int right_child_key) { CHECK_GT(optable.count(opname), 0) << "No operator \"" << opname << "\" exists"; Operator op = optable.at(opname); - SetNumericalTestNode(node_key, feature_id, op, threshold, default_left, + SetNumericalTestNode(node_key, feature_id, op, std::move(threshold), default_left, left_child_key, right_child_key); } void -TreeBuilder::SetNumericalTestNode(int node_key, - unsigned feature_id, - Operator op, tl_float threshold, - bool default_left, int left_child_key, - int right_child_key) 
{ - auto& tree = pimpl->tree; +TreeBuilder::SetNumericalTestNode(int node_key, unsigned feature_id, Operator op, Value threshold, + bool default_left, int left_child_key, int right_child_key) { + auto& tree = pimpl_->tree; auto& nodes = tree.nodes; + CHECK(tree.threshold_type == threshold.GetValueType()) + << "SetNumericalTestNode: threshold has an incorrect type. " + << "Expected: " << TypeInfoToString(tree.threshold_type) + << ", Given: " << TypeInfoToString(threshold.GetValueType()); CHECK_GT(nodes.count(node_key), 0) << "SetNumericalTestNode: no node found with node_key"; CHECK_GT(nodes.count(left_child_key), 0) << "SetNumericalTestNode: no node found with left_child_key"; CHECK_GT(nodes.count(right_child_key), 0) << "SetNumericalTestNode: no node found with right_child_key"; - _Node* node = nodes[node_key].get(); - _Node* left_child = nodes[left_child_key].get(); - _Node* right_child = nodes[right_child_key].get(); - CHECK(node->status == _Node::_Status::kEmpty) + NodeDraft* node = nodes[node_key].get(); + NodeDraft* left_child = nodes[left_child_key].get(); + NodeDraft* right_child = nodes[right_child_key].get(); + CHECK(node->status == NodeDraft::Status::kEmpty) << "SetNumericalTestNode: cannot modify a non-empty node"; CHECK(!left_child->parent) << "SetNumericalTestNode: node designated as left child already has a parent"; @@ -170,34 +223,32 @@ TreeBuilder::SetNumericalTestNode(int node_key, << "SetNumericalTestNode: node designated as right child already has a parent"; CHECK(left_child != tree.root && right_child != tree.root) << "SetNumericalTestNode: the root node cannot be a child"; - node->status = _Node::_Status::kNumericalTest; + node->status = NodeDraft::Status::kNumericalTest; node->left_child = nodes[left_child_key].get(); node->left_child->parent = node; node->right_child = nodes[right_child_key].get(); node->right_child->parent = node; node->feature_id = feature_id; node->default_left = default_left; - node->info.threshold = threshold; + node->threshold = std::move(threshold); node->op = op; } void -TreeBuilder::SetCategoricalTestNode(int node_key, - unsigned feature_id, - const std::vector& left_categories, - bool default_left, int left_child_key, - int right_child_key) { - auto &tree = pimpl->tree; +TreeBuilder::SetCategoricalTestNode(int node_key, unsigned feature_id, + const std::vector& left_categories, bool default_left, + int left_child_key, int right_child_key) { + auto &tree = pimpl_->tree; auto &nodes = tree.nodes; CHECK_GT(nodes.count(node_key), 0) << "SetCategoricalTestNode: no node found with node_key"; CHECK_GT(nodes.count(left_child_key), 0) << "SetCategoricalTestNode: no node found with left_child_key"; CHECK_GT(nodes.count(right_child_key), 0) << "SetCategoricalTestNode: no node found with right_child_key"; - _Node* node = nodes[node_key].get(); - _Node* left_child = nodes[left_child_key].get(); - _Node* right_child = nodes[right_child_key].get(); - CHECK(node->status == _Node::_Status::kEmpty) + NodeDraft* node = nodes[node_key].get(); + NodeDraft* left_child = nodes[left_child_key].get(); + NodeDraft* right_child = nodes[right_child_key].get(); + CHECK(node->status == NodeDraft::Status::kEmpty) << "SetCategoricalTestNode: cannot modify a non-empty node"; CHECK(!left_child->parent) << "SetCategoricalTestNode: node designated as left child already has a parent"; @@ -205,7 +256,7 @@ TreeBuilder::SetCategoricalTestNode(int node_key, << "SetCategoricalTestNode: node designated as right child already has a parent"; CHECK(left_child != tree.root && 
right_child != tree.root) << "SetCategoricalTestNode: the root node cannot be a child"; - node->status = _Node::_Status::kCategoricalTest; + node->status = NodeDraft::Status::kCategoricalTest; node->left_child = nodes[left_child_key].get(); node->left_child->parent = node; node->right_child = nodes[right_child_key].get(); @@ -216,38 +267,49 @@ TreeBuilder::SetCategoricalTestNode(int node_key, } void -TreeBuilder::SetLeafNode(int node_key, tl_float leaf_value) { - auto& tree = pimpl->tree; +TreeBuilder::SetLeafNode(int node_key, Value leaf_value) { + auto& tree = pimpl_->tree; auto& nodes = tree.nodes; + CHECK(tree.leaf_output_type == leaf_value.GetValueType()) + << "SetLeafNode: leaf_value has an incorrect type. " + << "Expected: " << TypeInfoToString(tree.leaf_output_type) + << ", Given: " << TypeInfoToString(leaf_value.GetValueType()); CHECK_GT(nodes.count(node_key), 0) << "SetLeafNode: no node found with node_key"; - _Node* node = nodes[node_key].get(); - CHECK(node->status == _Node::_Status::kEmpty) << "SetLeafNode: cannot modify a non-empty node"; - node->status = _Node::_Status::kLeaf; - node->info.leaf_value = leaf_value; + NodeDraft* node = nodes[node_key].get(); + CHECK(node->status == NodeDraft::Status::kEmpty) << "SetLeafNode: cannot modify a non-empty node"; + node->status = NodeDraft::Status::kLeaf; + node->leaf_value = std::move(leaf_value); } void -TreeBuilder::SetLeafVectorNode(int node_key, const std::vector& leaf_vector) { - auto& tree = pimpl->tree; +TreeBuilder::SetLeafVectorNode(int node_key, const std::vector& leaf_vector) { + auto& tree = pimpl_->tree; auto& nodes = tree.nodes; + const size_t leaf_vector_len = leaf_vector.size(); + for (size_t i = 0; i < leaf_vector_len; ++i) { + const Value& leaf_value = leaf_vector[i]; + CHECK(tree.leaf_output_type == leaf_value.GetValueType()) + << "SetLeafVectorNode: the element " << i << " in leaf_vector has an incorrect type. 
" + << "Expected: " << TypeInfoToString(tree.leaf_output_type) + << ", Given: " << TypeInfoToString(leaf_value.GetValueType()); + } CHECK_GT(nodes.count(node_key), 0) << "SetLeafVectorNode: no node found with node_key"; - _Node* node = nodes[node_key].get(); - CHECK(node->status == _Node::_Status::kEmpty) + NodeDraft* node = nodes[node_key].get(); + CHECK(node->status == NodeDraft::Status::kEmpty) << "SetLeafVectorNode: cannot modify a non-empty node"; - node->status = _Node::_Status::kLeaf; + node->status = NodeDraft::Status::kLeaf; node->leaf_vector = leaf_vector; } -ModelBuilder::ModelBuilder(int num_feature, int num_output_group, - bool random_forest_flag) - : pimpl(new ModelBuilderImpl(num_feature, - num_output_group, - random_forest_flag)) {} +ModelBuilder::ModelBuilder(int num_feature, int num_output_group, bool random_forest_flag, + TypeInfo threshold_type, TypeInfo leaf_output_type) + : pimpl_(new ModelBuilderImpl(num_feature, num_output_group, random_forest_flag, + threshold_type, leaf_output_type)) {} ModelBuilder::~ModelBuilder() = default; void ModelBuilder::SetModelParam(const char* name, const char* value) { - pimpl->cfg.emplace_back(name, value); + pimpl_->cfg.emplace_back(name, value); } int @@ -256,18 +318,34 @@ ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) { LOG(FATAL) << "InsertTree: not a valid tree builder"; return -1; } - if (tree_builder->ensemble_id != nullptr) { + if (tree_builder->ensemble_id_ != nullptr) { LOG(FATAL) << "InsertTree: tree is already part of another ensemble"; return -1; } + if (tree_builder->pimpl_->tree.threshold_type != this->pimpl_->threshold_type) { + LOG(FATAL) + << "InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all " + << "member trees to use " << TypeInfoToString(this->pimpl_->threshold_type) + << " type for split thresholds whereas the tree is using " + << TypeInfoToString(tree_builder->pimpl_->tree.threshold_type); + return -1; + } + if (tree_builder->pimpl_->tree.leaf_output_type != this->pimpl_->leaf_output_type) { + LOG(FATAL) + << "InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all " + << "member trees to use " << TypeInfoToString(this->pimpl_->leaf_output_type) + << " type for leaf outputs whereas the tree is using " + << TypeInfoToString(tree_builder->pimpl_->tree.leaf_output_type); + return -1; + } // check bounds for feature indices - for (const auto& kv : tree_builder->pimpl->tree.nodes) { - const _Node::_Status status = kv.second->status; - if (status == _Node::_Status::kNumericalTest || - status == _Node::_Status::kCategoricalTest) { + for (const auto& kv : tree_builder->pimpl_->tree.nodes) { + const NodeDraft::Status status = kv.second->status; + if (status == NodeDraft::Status::kNumericalTest || + status == NodeDraft::Status::kCategoricalTest) { const int fid = static_cast(kv.second->feature_id); - if (fid < 0 || fid >= pimpl->num_feature) { + if (fid < 0 || fid >= this->pimpl_->num_feature) { LOG(FATAL) << "InsertTree: tree has an invalid split at node " << kv.first << ": feature id " << kv.second->feature_id << " is out of bound"; return -1; @@ -276,18 +354,18 @@ ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) { } // perform insertion - auto& trees = pimpl->trees; + auto& trees = pimpl_->trees; if (index == -1) { trees.push_back(std::move(*tree_builder)); - tree_builder->ensemble_id = static_cast(this); + tree_builder->ensemble_id_ = this; return static_cast(trees.size()); } else { if (static_cast(index) <= 
trees.size()) { trees.insert(trees.begin() + index, std::move(*tree_builder)); - tree_builder->ensemble_id = static_cast(this); + tree_builder->ensemble_id_ = this; return index; } else { - LOG(FATAL) << "CreateTree: index out of bound"; + LOG(FATAL) << "InsertTree: index out of bound"; return -1; } } @@ -295,29 +373,39 @@ ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) { TreeBuilder* ModelBuilder::GetTree(int index) { - return &pimpl->trees.at(index); + return &pimpl_->trees.at(index); } const TreeBuilder* ModelBuilder::GetTree(int index) const { - return &pimpl->trees.at(index); + return &pimpl_->trees.at(index); } void ModelBuilder::DeleteTree(int index) { - auto& trees = pimpl->trees; + auto& trees = pimpl_->trees; CHECK_LT(static_cast(index), trees.size()) << "DeleteTree: index out of bound"; trees.erase(trees.begin() + index); } void ModelBuilder::CommitModel(Model* out_model) { - Model model; - model.num_feature = pimpl->num_feature; - model.num_output_group = pimpl->num_output_group; - model.random_forest_flag = pimpl->random_forest_flag; + Model& model = *out_model; + model = Model::Create(pimpl_->threshold_type, pimpl_->leaf_output_type); + model.Dispatch([this](auto& model_handle) { + this->pimpl_->CommitModelImpl(&model_handle); + }); +} + +template +void +ModelBuilderImpl::CommitModelImpl(ModelImpl* out_model) { + ModelImpl& model = *out_model; + model.num_feature = this->num_feature; + model.num_output_group = this->num_output_group; + model.random_forest_flag = this->random_forest_flag; // extra parameters - InitParamAndCheck(&model.param, pimpl->cfg); + InitParamAndCheck(&model.param, this->cfg); // flag to check consistent use of leaf vector // 0: no leaf should use leaf vector @@ -325,44 +413,45 @@ ModelBuilder::CommitModel(Model* out_model) { // -1: indeterminate int8_t flag_leaf_vector = -1; - for (const auto& _tree_builder : pimpl->trees) { - const auto& _tree = _tree_builder.pimpl->tree; + for (const auto& tree_builder : this->trees) { + const auto& _tree = tree_builder.pimpl_->tree; CHECK(_tree.root) << "CommitModel: a tree has no root node"; - CHECK(_tree.root->status != _Node::_Status::kEmpty) + CHECK(_tree.root->status != NodeDraft::Status::kEmpty) << "SetRootNode: cannot set an empty node as root"; model.trees.emplace_back(); - Tree& tree = model.trees.back(); + Tree& tree = model.trees.back(); tree.Init(); // assign node ID's so that a breadth-wise traversal would yield // the monotonic sequence 0, 1, 2, ... 
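    // (Illustrative note, not part of the original comment: for a complete
    //  tree of depth 2 this numbering is
    //        0
    //      /   \
    //     1     2
    //    / \   / \
    //   3   4 5   6
    //  i.e. the root is assigned 0, and each node's children receive the next
    //  available IDs at the moment the node is popped from the queue below.)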
- std::queue> Q; // (internal pointer, ID) + std::queue> Q; // (internal pointer, ID) Q.push({_tree.root, 0}); // assign 0 to root while (!Q.empty()) { - const _Node* node; + const NodeDraft* node; int nid; std::tie(node, nid) = Q.front(); Q.pop(); - CHECK(node->status != _Node::_Status::kEmpty) + CHECK(node->status != NodeDraft::Status::kEmpty) << "CommitModel: encountered an empty node in the middle of a tree"; - if (node->status == _Node::_Status::kNumericalTest) { + if (node->status == NodeDraft::Status::kNumericalTest) { CHECK(node->left_child) << "CommitModel: a test node lacks a left child"; CHECK(node->right_child) << "CommitModel: a test node lacks a right child"; CHECK(node->left_child->parent == node) << "CommitModel: left child has wrong parent"; CHECK(node->right_child->parent == node) << "CommitModel: right child has wrong parent"; tree.AddChilds(nid); - tree.SetNumericalSplit(nid, node->feature_id, node->info.threshold, - node->default_left, node->op); + node->threshold.Dispatch([&tree, nid, node](const auto& threshold) { + tree.SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, node->op); + }); Q.push({node->left_child, tree.LeftChild(nid)}); Q.push({node->right_child, tree.RightChild(nid)}); - } else if (node->status == _Node::_Status::kCategoricalTest) { + } else if (node->status == NodeDraft::Status::kCategoricalTest) { CHECK(node->left_child) << "CommitModel: a test node lacks a left child"; CHECK(node->right_child) << "CommitModel: a test node lacks a right child"; CHECK(node->left_child->parent == node) << "CommitModel: left child has wrong parent"; CHECK(node->right_child->parent == node) << "CommitModel: right child has wrong parent"; tree.AddChilds(nid); - tree.SetCategoricalSplit(nid, node->feature_id, node->default_left, - false, node->left_categories); + tree.SetCategoricalSplit(nid, node->feature_id, node->default_left, false, + node->left_categories); Q.push({node->left_child, tree.LeftChild(nid)}); Q.push({node->right_child, tree.RightChild(nid)}); } else { // leaf node @@ -376,13 +465,15 @@ ModelBuilder::CommitModel(Model* out_model) { CHECK_EQ(node->leaf_vector.size(), model.num_output_group) << "CommitModel: The length of leaf vector must be identical to the number of output " << "groups"; - tree.SetLeafVector(nid, node->leaf_vector); + SetLeafVector(&tree, nid, node->leaf_vector); } else { // ordinary leaf CHECK_NE(flag_leaf_vector, 1) << "CommitModel: Inconsistent use of leaf vector: if one leaf node does not use a leaf " << "vector, *no other* leaf node can use a leaf vector"; flag_leaf_vector = 0; // now no leaf can use leaf vector - tree.SetLeaf(nid, node->info.leaf_value); + node->leaf_value.Dispatch([&tree, nid](const auto& leaf_value) { + tree.SetLeaf(nid, leaf_value); + }); } } } @@ -393,7 +484,7 @@ ModelBuilder::CommitModel(Model* out_model) { CHECK(!model.random_forest_flag) << "To use a random forest for multi-class classification, each leaf node must output a " << "leaf vector specifying a probability distribution"; - CHECK_EQ(pimpl->trees.size() % model.num_output_group, 0) + CHECK_EQ(this->trees.size() % model.num_output_group, 0) << "For multi-class classifiers with gradient boosted trees, the number of trees must be " << "evenly divisible by the number of output groups"; } @@ -405,8 +496,11 @@ ModelBuilder::CommitModel(Model* out_model) { } else { LOG(FATAL) << "Impossible thing happened: model has no leaf node!"; } - *out_model = std::move(model); } +template Value Value::Create(uint32_t init_value); +template Value 
Value::Create(float init_value); +template Value Value::Create(double init_value); + } // namespace frontend } // namespace treelite diff --git a/src/reference_serializer.cc b/src/reference_serializer.cc index a7fc7b44..53878657 100644 --- a/src/reference_serializer.cc +++ b/src/reference_serializer.cc @@ -68,4 +68,14 @@ void ModelImpl::ReferenceSerialize(dmlc::Stream* } } +template void Tree::ReferenceSerialize(dmlc::Stream* fo) const; +template void Tree::ReferenceSerialize(dmlc::Stream* fo) const; +template void Tree::ReferenceSerialize(dmlc::Stream* fo) const; +template void Tree::ReferenceSerialize(dmlc::Stream* fo) const; + +template void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const; +template void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const; +template void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const; +template void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const; + } // namespace treelite diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 2ea94e38..5a8d5ab4 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -17,7 +17,7 @@ endif() target_sources(treelite_cpp_test PRIVATE test_main.cc - #test_serializer.cc + test_serializer.cc ) msvc_use_static_runtime() diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc index 3dc494fb..37f691e1 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -23,10 +23,9 @@ inline std::string TreeliteToBytes(treelite::Model* model) { inline void TestRoundTrip(treelite::Model* model) { auto buffer = model->GetPyBuffer(); - std::unique_ptr received_model{new treelite::Model()}; - received_model->InitFromPyBuffer(buffer); + treelite::Model received_model = treelite::Model::CreateFromPyBuffer(buffer); - ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(received_model.get())); + ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(&received_model)); } } // anonymous namespace @@ -35,16 +34,18 @@ namespace treelite { TEST(PyBufferInterfaceRoundTrip, TreeStump) { std::unique_ptr builder{ - new frontend::ModelBuilder(2, 1, false) + new frontend::ModelBuilder(2, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) }; - std::unique_ptr tree{new frontend::TreeBuilder()}; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); - tree->SetNumericalTestNode(0, 0, "<", 0.0f, true, 1, 2); + tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0.0f), true, 1, 2); tree->SetRootNode(0); - tree->SetLeafNode(1, -1.0f); - tree->SetLeafNode(2, 1.0f); + tree->SetLeafNode(1, frontend::Value::Create(-1.0f)); + tree->SetLeafNode(2, frontend::Value::Create(1.0f)); builder->InsertTree(tree.get()); std::unique_ptr model{new Model()}; @@ -54,16 +55,20 @@ TEST(PyBufferInterfaceRoundTrip, TreeStump) { TEST(PyBufferInterfaceRoundTrip, TreeStumpLeafVec) { std::unique_ptr builder{ - new frontend::ModelBuilder(2, 2, true) + new frontend::ModelBuilder(2, 2, true, TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) }; - std::unique_ptr tree{new frontend::TreeBuilder()}; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); - tree->SetNumericalTestNode(0, 0, "<", 0.0f, true, 1, 2); + tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0.0f), true, 1, 2); tree->SetRootNode(0); - tree->SetLeafVectorNode(1, {-1.0f, 1.0f}); - tree->SetLeafVectorNode(2, {1.0f, -1.0f}); + 
tree->SetLeafVectorNode(1, {frontend::Value::Create(-1.0f), + frontend::Value::Create(1.0f)}); + tree->SetLeafVectorNode(2, {frontend::Value::Create(1.0f), + frontend::Value::Create(-1.0f)}); builder->InsertTree(tree.get()); std::unique_ptr model{new Model()}; @@ -73,16 +78,18 @@ TEST(PyBufferInterfaceRoundTrip, TreeStumpLeafVec) { TEST(PyBufferInterfaceRoundTrip, TreeStumpCategoricalSplit) { std::unique_ptr builder{ - new frontend::ModelBuilder(2, 1, false) + new frontend::ModelBuilder(2, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) }; - std::unique_ptr tree{new frontend::TreeBuilder()}; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); tree->SetCategoricalTestNode(0, 0, {0, 1}, true, 1, 2); tree->SetRootNode(0); - tree->SetLeafNode(1, -1.0f); - tree->SetLeafNode(2, 1.0f); + tree->SetLeafNode(1, frontend::Value::Create(-1.0f)); + tree->SetLeafNode(2, frontend::Value::Create(1.0f)); builder->InsertTree(tree.get()); std::unique_ptr model{new Model()}; @@ -92,23 +99,25 @@ TEST(PyBufferInterfaceRoundTrip, TreeStumpCategoricalSplit) { TEST(PyBufferInterfaceRoundTrip, TreeDepth2) { std::unique_ptr builder{ - new frontend::ModelBuilder(2, 1, false) + new frontend::ModelBuilder(2, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) }; builder->SetModelParam("pred_transform", "sigmoid"); builder->SetModelParam("global_bias", "0.5"); for (int tree_id = 0; tree_id < 2; ++tree_id) { - std::unique_ptr tree{new frontend::TreeBuilder()}; + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + }; for (int i = 0; i < 7; ++i) { tree->CreateNode(i); } - tree->SetNumericalTestNode(0, 0, "<", 0.0f, true, 1, 2); + tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0.0f), true, 1, 2); tree->SetCategoricalTestNode(1, 0, {0, 1}, true, 3, 4); tree->SetCategoricalTestNode(2, 1, {0}, true, 5, 6); tree->SetRootNode(0); - tree->SetLeafNode(3, -2.0f); - tree->SetLeafNode(4, 1.0f); - tree->SetLeafNode(5, -1.0f); - tree->SetLeafNode(6, 2.0f); + tree->SetLeafNode(3, frontend::Value::Create(-2.0f)); + tree->SetLeafNode(4, frontend::Value::Create(1.0f)); + tree->SetLeafNode(5, frontend::Value::Create(-1.0f)); + tree->SetLeafNode(6, frontend::Value::Create(2.0f)); builder->InsertTree(tree.get()); } @@ -121,9 +130,11 @@ TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { const int depth = 19; std::unique_ptr builder{ - new frontend::ModelBuilder(3, 1, false) + new frontend::ModelBuilder(3, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) + }; + std::unique_ptr tree{ + new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) }; - std::unique_ptr tree{new frontend::TreeBuilder()}; for (int level = 0; level <= depth; ++level) { for (int i = 0; i < (1 << level); ++i) { const int nid = (1 << level) - 1 + i; @@ -133,11 +144,12 @@ TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { for (int level = 0; level <= depth; ++level) { for (int i = 0; i < (1 << level); ++i) { const int nid = (1 << level) - 1 + i; - const float leaf_value = 0.5; + const float leaf_value = 0.5f; if (level == depth) { - tree->SetLeafNode(nid, leaf_value); + tree->SetLeafNode(nid, frontend::Value::Create(leaf_value)); } else { - tree->SetNumericalTestNode(nid, (level % 2), "<", 0.0, true, 2 * nid + 1, 2 * nid + 2); + tree->SetNumericalTestNode(nid, (level % 2), "<", frontend::Value::Create(0.0f), + true, 2 * nid + 1, 2 * nid + 2); } } } From acc20bef1e4b111e402de01c064bf8da130bd93e Mon Sep 17 
00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 11:57:38 -0700 Subject: [PATCH 04/38] Update C interface for model builder API --- include/treelite/base.h | 5 +- include/treelite/c_api.h | 91 ++++++++++++++++-------------- include/treelite/frontend.h | 1 + src/CMakeLists.txt | 2 +- src/c_api/c_api.cc | 102 ++++++++++++++++++---------------- src/frontend/builder.cc | 21 +++++-- src/{optable.cc => tables.cc} | 10 +++- 7 files changed, 135 insertions(+), 97 deletions(-) rename src/{optable.cc => tables.cc} (56%) diff --git a/include/treelite/base.h b/include/treelite/base.h index cca5682c..bc0f16cd 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -40,9 +40,12 @@ enum class TypeInfo : uint8_t { static_assert(std::is_same::type, uint8_t>::value, "TypeInfo must use uint8_t as underlying type"); -/*! \brief conversion table from string to operator, defined in optable.cc */ +/*! \brief conversion table from string to Operator, defined in tables.cc */ extern const std::unordered_map optable; +/*! \brief conversion table from string to TypeInfo, defined in tables.cc */ +extern const std::unordered_map typeinfo_table; + /*! * \brief get string representation of comparison operator * \param op comparison operator diff --git a/include/treelite/c_api.h b/include/treelite/c_api.h index 8f0f5285..441b7355 100644 --- a/include/treelite/c_api.h +++ b/include/treelite/c_api.h @@ -31,6 +31,8 @@ typedef void* ModelBuilderHandle; typedef void* AnnotationHandle; /*! \brief handle to compiler class */ typedef void* CompilerHandle; +/*! \brief handle to a polymorphic value type, used in the model builder API */ +typedef void* ValueHandle; /*! \} */ /*! @@ -293,12 +295,33 @@ TREELITE_DLL int TreeliteFreeModel(ModelHandle handle); * Model builder interface: build trees incrementally * \{ */ +/*! + * \brief Create a new Value object. Some model builder API functions accept this Value type to + * accommodate values of multiple types. + * \param init_value pointer to the value to be stored + * \param type Type of the value to be stored + * \param out newly created Value object + * \return 0 for success; -1 for failure + */ +TREELITE_DLL int TreeliteTreeBuilderCreateValue(const void* init_value, const char* type, + ValueHandle* out); +/*! + * \brief Delete a Value object from memory + * \param handle pointer to the Value object to be deleted + * \return 0 for success; -1 for failure + */ +TREELITE_DLL int TreeliteTreeBuilderDeleteValue(ValueHandle handle); /*! * \brief Create a new tree builder + * \param threshold_type Type of thresholds in numerical splits. All thresholds in a given model + * must have the same type. + * \param leaf_output_type Type of leaf outputs. All leaf outputs in a given model must have the + * same type. * \param out newly created tree builder * \return 0 for success; -1 for failure */ -TREELITE_DLL int TreeliteCreateTreeBuilder(TreeBuilderHandle* out); +TREELITE_DLL int TreeliteCreateTreeBuilder(const char* threshold_type, const char* leaf_output_type, + TreeBuilderHandle* out); /*! * \brief Delete a tree builder from memory * \param handle tree builder to remove @@ -311,24 +334,21 @@ TREELITE_DLL int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle); * \param node_key unique integer key to identify the new node * \return 0 for success; -1 for failure */ -TREELITE_DLL int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, - int node_key); +TREELITE_DLL int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key); /*! 
* \brief Remove a node from a tree * \param handle tree builder * \param node_key unique integer key to identify the node to be removed * \return 0 for success; -1 for failure */ -TREELITE_DLL int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, - int node_key); +TREELITE_DLL int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key); /*! * \brief Set a node as the root of a tree * \param handle tree builder * \param node_key unique integer key to identify the root node * \return 0 for success; -1 for failure */ -TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, - int node_key); +TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key); /*! * \brief Turn an empty node into a test node with numerical split. * The test is in the form [feature value] OP [threshold]. Depending on the @@ -345,12 +365,8 @@ TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, * \return 0 for success; -1 for failure */ TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode( - TreeBuilderHandle handle, - int node_key, unsigned feature_id, - const char* opname, - float threshold, int default_left, - int left_child_key, - int right_child_key); + TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname, + ValueHandle threshold, int default_left, int left_child_key, int right_child_key); /*! * \brief Turn an empty node into a test node with categorical split. * A list defines all categories that would be classified as the left side. @@ -368,13 +384,9 @@ TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode( * \return 0 for success; -1 for failure */ TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode( - TreeBuilderHandle handle, - int node_key, unsigned feature_id, - const unsigned int* left_categories, - size_t left_categories_len, - int default_left, - int left_child_key, - int right_child_key); + TreeBuilderHandle handle, int node_key, unsigned feature_id, + const unsigned int* left_categories, size_t left_categories_len, int default_left, + int left_child_key, int right_child_key); /*! * \brief Turn an empty node into a leaf node * \param handle tree builder @@ -383,9 +395,8 @@ TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode( * \param leaf_value leaf value (weight) of the leaf node * \return 0 for success; -1 for failure */ -TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, - int node_key, - float leaf_value); +TREELITE_DLL int TreeliteTreeBuilderSetLeafNode( + TreeBuilderHandle handle, int node_key, ValueHandle leaf_value); /*! * \brief Turn an empty node into a leaf vector node * The leaf vector (collection of multiple leaf weights per leaf node) is @@ -397,29 +408,27 @@ TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, * \param leaf_vector_len length of leaf_vector * \return 0 for success; -1 for failure */ -TREELITE_DLL int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, - int node_key, - const float* leaf_vector, - size_t leaf_vector_len); +TREELITE_DLL int TreeliteTreeBuilderSetLeafVectorNode( + TreeBuilderHandle handle, int node_key, const ValueHandle* leaf_vector, size_t leaf_vector_len); /*! * \brief Create a new model builder - * \param num_feature number of features used in model being built. We assume - * that all feature indices are between 0 and - * (num_feature - 1). - * \param num_output_group number of output groups. 
Set to 1 for binary - * classification and regression; >1 for multiclass - * classification - * \param random_forest_flag whether the model is a random forest. Set to 0 if - * the model is gradient boosted trees. Any nonzero - * value shall indicate that the model is a - * random forest. + * \param num_feature number of features used in model being built. We assume that all feature + * indices are between 0 and (num_feature - 1). + * \param num_output_group number of output groups. Set to 1 for binary classification and + * regression; >1 for multiclass classification + * \param random_forest_flag whether the model is a random forest. Set to 0 if the model is + * gradient boosted trees. Any nonzero value shall indicate that the + * model is a random forest. + * \param threshold_type Type of thresholds in numerical splits. All thresholds in a given model + * must have the same type. + * \param leaf_output_type Type of leaf outputs. All leaf outputs in a given model must have the + * same type. * \param out newly created model builder * \return 0 for success; -1 for failure */ -TREELITE_DLL int TreeliteCreateModelBuilder(int num_feature, - int num_output_group, - int random_forest_flag, - ModelBuilderHandle* out); +TREELITE_DLL int TreeliteCreateModelBuilder( + int num_feature, int num_output_group, int random_forest_flag, const char* threshold_type, + const char* leaf_output_type, ModelBuilderHandle* out); /*! * \brief Set a model parameter * \param handle model builder diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index 9f7a815e..d703cb5d 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -65,6 +65,7 @@ class Value { Value& operator=(Value&&) noexcept = default; template static Value Create(T init_value); + static Value Create(const void* init_value, TypeInfo type); template T& Get(); template diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index df239ddd..244f78b2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -87,7 +87,7 @@ target_sources(objtreelite annotator.cc data.cc filesystem.cc - optable.cc + tables.cc reference_serializer.cc ${PROJECT_SOURCE_DIR}/include/treelite/annotator.h ${PROJECT_SOURCE_DIR}/include/treelite/base.h diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 5aba3bc3..ba0c0652 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -371,10 +371,27 @@ int TreeliteSetTreeLimit(ModelHandle handle, size_t limit) { API_END(); } -#if 0 -int TreeliteCreateTreeBuilder(TreeBuilderHandle* out) { +int TreeliteTreeBuilderCreateValue(const void* init_value, const char* type, ValueHandle* out) { API_BEGIN(); - std::unique_ptr builder{new frontend::TreeBuilder()}; + std::unique_ptr value = std::make_unique(); + *value = frontend::Value::Create(init_value, typeinfo_table.at(type)); + *out = static_cast(value.release()); + API_END(); +} + +int TreeliteTreeBuilderDeleteValue(ValueHandle handle) { + API_BEGIN(); + delete static_cast(handle); + API_END(); +} + +int TreeliteCreateTreeBuilder(const char* threshold_type, const char* leaf_output_type, + TreeBuilderHandle* out) { + API_BEGIN(); + std::unique_ptr builder{ + new frontend::TreeBuilder(typeinfo_table.at(threshold_type), + typeinfo_table.at(leaf_output_type)) + }; *out = static_cast(builder.release()); API_END(); } @@ -387,7 +404,7 @@ int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle) { int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = 
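/* A minimal usage sketch of the typed builder C API declared above, assuming a
 * float32/float32 model holding a single tree stump; return-code checks and
 * TreeliteTreeBuilderDeleteValue cleanup are abbreviated. All function names and
 * signatures come from the declarations in c_api.h; the node keys and values are
 * illustrative only:
 *
 *   ValueHandle threshold, left_leaf, right_leaf;
 *   float v;
 *   v = 0.0f;  TreeliteTreeBuilderCreateValue(&v, "float32", &threshold);
 *   v = -1.0f; TreeliteTreeBuilderCreateValue(&v, "float32", &left_leaf);
 *   v = 1.0f;  TreeliteTreeBuilderCreateValue(&v, "float32", &right_leaf);
 *
 *   TreeBuilderHandle tree;
 *   TreeliteCreateTreeBuilder("float32", "float32", &tree);
 *   TreeliteTreeBuilderCreateNode(tree, 0);
 *   TreeliteTreeBuilderCreateNode(tree, 1);
 *   TreeliteTreeBuilderCreateNode(tree, 2);
 *   TreeliteTreeBuilderSetNumericalTestNode(tree, 0, 0, "<", threshold, 1, 1, 2);
 *   TreeliteTreeBuilderSetLeafNode(tree, 1, left_leaf);
 *   TreeliteTreeBuilderSetLeafNode(tree, 2, right_leaf);
 *   TreeliteTreeBuilderSetRootNode(tree, 0);
 *
 *   ModelBuilderHandle model_builder;
 *   TreeliteCreateModelBuilder(2, 1, 0, "float32", "float32", &model_builder);
 *   TreeliteModelBuilderInsertTree(model_builder, tree, -1);
 *   ModelHandle model;
 *   TreeliteModelBuilderCommitModel(model_builder, &model);
 *
 *   (After InsertTree succeeds, the tree builder's contents belong to the
 *    ensemble and the tree handle should not be inserted again.)
 */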
static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object"; builder->CreateNode(node_key); API_END(); @@ -395,7 +412,7 @@ int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key) { int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object"; builder->DeleteNode(node_key); API_END(); @@ -403,36 +420,30 @@ int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key) { int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object"; builder->SetRootNode(node_key); API_END(); } -int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle, - int node_key, unsigned feature_id, - const char* opname, - float threshold, int default_left, - int left_child_key, - int right_child_key) { +int TreeliteTreeBuilderSetNumericalTestNode( + TreeBuilderHandle handle,int node_key, unsigned feature_id, const char* opname, + ValueHandle threshold, int default_left, int left_child_key, int right_child_key) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object"; - builder->SetNumericalTestNode(node_key, feature_id, opname, static_cast(threshold), + builder->SetNumericalTestNode(node_key, feature_id, opname, + *static_cast(threshold), (default_left != 0), left_child_key, right_child_key); API_END(); } int TreeliteTreeBuilderSetCategoricalTestNode( - TreeBuilderHandle handle, - int node_key, unsigned feature_id, - const unsigned int* left_categories, - size_t left_categories_len, - int default_left, - int left_child_key, - int right_child_key) { - API_BEGIN(); - auto builder = static_cast(handle); + TreeBuilderHandle handle, int node_key, unsigned feature_id, + const unsigned int* left_categories, size_t left_categories_len, int default_left, + int left_child_key, int right_child_key) { + API_BEGIN(); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object"; std::vector vec(left_categories_len); for (size_t i = 0; i < left_categories_len; ++i) { @@ -444,38 +455,34 @@ int TreeliteTreeBuilderSetCategoricalTestNode( API_END(); } -int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, float leaf_value) { +int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle, int node_key, ValueHandle leaf_value) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object"; - builder->SetLeafNode(node_key, static_cast(leaf_value)); + builder->SetLeafNode(node_key, *static_cast(leaf_value)); API_END(); } -int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, - int node_key, - const float* leaf_vector, - size_t leaf_vector_len) { +int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle, int node_key, + const ValueHandle* leaf_vector, size_t leaf_vector_len) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted TreeBuilder object"; - std::vector vec(leaf_vector_len); + std::vector 
vec(leaf_vector_len); for (size_t i = 0; i < leaf_vector_len; ++i) { - vec[i] = static_cast(leaf_vector[i]); + vec[i] = *static_cast(leaf_vector[i]); } builder->SetLeafVectorNode(node_key, vec); API_END(); } -int TreeliteCreateModelBuilder(int num_feature, - int num_output_group, - int random_forest_flag, - ModelBuilderHandle* out) { +int TreeliteCreateModelBuilder( + int num_feature, int num_output_group, int random_forest_flag, const char* threshold_type, + const char* leaf_output_type, ModelBuilderHandle* out) { API_BEGIN(); std::unique_ptr builder{new frontend::ModelBuilder( - num_feature, - num_output_group, - (random_forest_flag != 0))}; + num_feature, num_output_group, (random_forest_flag != 0), typeinfo_table.at(threshold_type), + typeinfo_table.at(leaf_output_type))}; *out = static_cast(builder.release()); API_END(); } @@ -484,7 +491,7 @@ int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char* name, const char* value) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object"; builder->SetModelParam(name, value); API_END(); @@ -500,9 +507,9 @@ int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index) { API_BEGIN(); - auto model_builder = static_cast(handle); + auto* model_builder = static_cast(handle); CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object"; - auto tree_builder = static_cast(tree_builder_handle); + auto* tree_builder = static_cast(tree_builder_handle); CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object"; return model_builder->InsertTree(tree_builder, index); API_END(); @@ -511,9 +518,9 @@ int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out) { API_BEGIN(); - auto model_builder = static_cast(handle); + auto* model_builder = static_cast(handle); CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object"; - auto tree_builder = model_builder->GetTree(index); + auto* tree_builder = model_builder->GetTree(index); CHECK(tree_builder) << "Detected dangling reference to deleted TreeBuilder object"; *out = static_cast(tree_builder); API_END(); @@ -521,7 +528,7 @@ int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object"; builder->DeleteTree(index); API_END(); @@ -530,11 +537,10 @@ int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index) { int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle* out) { API_BEGIN(); - auto builder = static_cast(handle); + auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object"; std::unique_ptr model{new Model()}; builder->CommitModel(model.get()); *out = static_cast(model.release()); API_END(); } -#endif diff --git a/src/frontend/builder.cc b/src/frontend/builder.cc index 0c365d3d..3592e137 100644 --- a/src/frontend/builder.cc +++ b/src/frontend/builder.cc @@ -120,11 +120,24 @@ Value::Value() : handle_(nullptr), type_(TypeInfo::kInvalid) {} template Value Value::Create(T init_value) { - Value val; + Value value; std::unique_ptr 
ptr = std::make_unique(init_value); - val.handle_.reset(ptr.release()); - val.type_ = InferTypeInfoOf(); - return val; + value.handle_.reset(ptr.release()); + value.type_ = InferTypeInfoOf(); + return value; +} + +Value +Value::Create(const void* init_value, TypeInfo type) { + Value value; + CHECK(type != TypeInfo::kInvalid) << "Type must be valid"; + value.type_ = type; + value.Dispatch([init_value](auto& value_handle) { + using T = std::remove_reference_t; + T t = *static_cast(init_value); + value_handle = t; + }); + return value; } template diff --git a/src/optable.cc b/src/tables.cc similarity index 56% rename from src/optable.cc rename to src/tables.cc index cc96ac94..603c3f1f 100644 --- a/src/optable.cc +++ b/src/tables.cc @@ -1,8 +1,8 @@ /*! * Copyright (c) 2017-2020 by Contributors - * \file optable.cc + * \file tables.cc * \author Hyunsu Cho - * \brief Conversion table from string to Operator + * \brief Conversion tables to obtain Operator and TypeInfo from strings */ #include @@ -17,4 +17,10 @@ const std::unordered_map optable{ {">=", Operator::kGE} }; +const std::unordered_map typeinfo_table{ + {"uint32", TypeInfo::kUInt32}, + {"float32", TypeInfo::kFloat32}, + {"float64", TypeInfo::kFloat64} +}; + } // namespace treelite From 34c2231b16c15abad6d402ffbdbcabaf015c7376 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 12:39:35 -0700 Subject: [PATCH 05/38] Fix lint checks --- include/treelite/frontend_impl.h | 6 ++++-- include/treelite/tree.h | 2 ++ include/treelite/tree_impl.h | 4 ++-- src/c_api/c_api.cc | 2 +- src/compiler/ast/builder.h | 2 +- src/frontend/builder.cc | 8 ++++---- src/frontend/lightgbm.cc | 15 ++++++++++----- 7 files changed, 24 insertions(+), 15 deletions(-) diff --git a/include/treelite/frontend_impl.h b/include/treelite/frontend_impl.h index 4d5fb1c8..da6e15a1 100644 --- a/include/treelite/frontend_impl.h +++ b/include/treelite/frontend_impl.h @@ -8,13 +8,15 @@ #ifndef TREELITE_FRONTEND_IMPL_H_ #define TREELITE_FRONTEND_IMPL_H_ +#include + namespace treelite { namespace frontend { template inline auto Value::Dispatch(Func func) { - switch(type_) { + switch (type_) { case TypeInfo::kUInt32: return func(Get()); case TypeInfo::kFloat32: @@ -31,7 +33,7 @@ Value::Dispatch(Func func) { template inline auto Value::Dispatch(Func func) const { - switch(type_) { + switch (type_) { case TypeInfo::kUInt32: return func(Get()); case TypeInfo::kFloat32: diff --git a/include/treelite/tree.h b/include/treelite/tree.h index b1087f81..c33037ec 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -468,6 +469,7 @@ struct Model { ModelType type_; TypeInfo threshold_type_; TypeInfo leaf_output_type_; + public: template inline static ModelType InferModelTypeOf(); diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 7c7131fd..4aaa03a7 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -914,7 +914,7 @@ Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { template inline auto Model::Dispatch(Func func) { - switch(type_) { + switch (type_) { case ModelType::kFloat32ThresholdUInt32LeafOutput: return func(GetImpl()); case ModelType::kFloat32ThresholdFloat32LeafOutput: @@ -933,7 +933,7 @@ Model::Dispatch(Func func) { template inline auto Model::Dispatch(Func func) const { - switch(type_) { + switch (type_) { case ModelType::kFloat32ThresholdUInt32LeafOutput: return func(GetImpl()); case 
ModelType::kFloat32ThresholdFloat32LeafOutput: diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ba0c0652..b801f8af 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -427,7 +427,7 @@ int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key) { } int TreeliteTreeBuilderSetNumericalTestNode( - TreeBuilderHandle handle,int node_key, unsigned feature_id, const char* opname, + TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname, ValueHandle threshold, int default_left, int left_child_key, int right_child_key) { API_BEGIN(); auto* builder = static_cast(handle); diff --git a/src/compiler/ast/builder.h b/src/compiler/ast/builder.h index 00a75385..7552c6bd 100644 --- a/src/compiler/ast/builder.h +++ b/src/compiler/ast/builder.h @@ -70,7 +70,7 @@ class ASTBuilder { } private: - friend bool treelite::compiler::fold_code(ASTNode*, CodeFoldingContext*, ASTBuilder*); + friend bool treelite::compiler::fold_code<>(ASTNode*, CodeFoldingContext*, ASTBuilder*); template NodeType* AddNode(ASTNode* parent, Args&& ...args) { diff --git a/src/frontend/builder.cc b/src/frontend/builder.cc index 3592e137..fa838093 100644 --- a/src/frontend/builder.cc +++ b/src/frontend/builder.cc @@ -345,10 +345,10 @@ ModelBuilder::InsertTree(TreeBuilder* tree_builder, int index) { } if (tree_builder->pimpl_->tree.leaf_output_type != this->pimpl_->leaf_output_type) { LOG(FATAL) - << "InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all " - << "member trees to use " << TypeInfoToString(this->pimpl_->leaf_output_type) - << " type for leaf outputs whereas the tree is using " - << TypeInfoToString(tree_builder->pimpl_->tree.leaf_output_type); + << "InsertTree: cannot insert the tree into the ensemble, because the ensemble requires all " + << "member trees to use " << TypeInfoToString(this->pimpl_->leaf_output_type) + << " type for leaf outputs whereas the tree is using " + << TypeInfoToString(tree_builder->pimpl_->tree.leaf_output_type); return -1; } diff --git a/src/frontend/lightgbm.cc b/src/frontend/lightgbm.cc index 99e75e7e..d94e3775 100644 --- a/src/frontend/lightgbm.cc +++ b/src/frontend/lightgbm.cc @@ -466,7 +466,8 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { CHECK(num_class >= 0 && num_class == model_handle.num_output_group) << "Ill-formed LightGBM model file: not a valid multiclass objective"; - std::strncpy(model_handle.param.pred_transform, "softmax", sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "softmax", + sizeof(model_handle.param.pred_transform)); } else if (obj_name_ == "multiclassova") { // validate num_class and alpha parameters int num_class = -1; @@ -489,7 +490,8 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { && alpha > 0.0f) << "Ill-formed LightGBM model file: not a valid multiclassova objective"; - std::strncpy(model_handle.param.pred_transform, "multiclass_ova", sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "multiclass_ova", + sizeof(model_handle.param.pred_transform)); model_handle.param.sigmoid_alpha = alpha; } else if (obj_name_ == "binary") { // validate alpha parameter @@ -506,16 +508,19 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { CHECK_GT(alpha, 0.0f) << "Ill-formed LightGBM model file: not a valid binary objective"; - std::strncpy(model_handle.param.pred_transform, "sigmoid", sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, 
"sigmoid", + sizeof(model_handle.param.pred_transform)); model_handle.param.sigmoid_alpha = alpha; } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") { - std::strncpy(model_handle.param.pred_transform, "sigmoid", sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "sigmoid", + sizeof(model_handle.param.pred_transform)); model_handle.param.sigmoid_alpha = 1.0f; } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") { std::strncpy(model_handle.param.pred_transform, "logarithm_one_plus_exp", sizeof(model_handle.param.pred_transform)); } else { - std::strncpy(model_handle.param.pred_transform, "identity", sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle.param.pred_transform, "identity", + sizeof(model_handle.param.pred_transform)); } // traverse trees From 86104d69c33f5fdccb9591f1fc54bccf00c51473 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 14:10:39 -0700 Subject: [PATCH 06/38] Add a missing header --- include/treelite/base.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/treelite/base.h b/include/treelite/base.h index bc0f16cd..3c9a89ca 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -8,6 +8,7 @@ #define TREELITE_BASE_H_ #include +#include #include #include #include From 77477e7aef1641ccb360768ab4071f1290e66050 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 14:49:53 -0700 Subject: [PATCH 07/38] Perform round-trip serialization test with all possible types --- include/treelite/tree_impl.h | 10 ++- tests/cpp/test_serializer.cc | 148 +++++++++++++++++++++++++++-------- 2 files changed, 121 insertions(+), 37 deletions(-) diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 4aaa03a7..595bf9e1 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -828,7 +829,10 @@ Model::GetModelType() const { template inline ModelType Model::InferModelTypeOf() { - const char* error_msg = "Unsupported combination of ThresholdType and LeafOutputType"; + const std::string error_msg + = std::string("Unsupported combination of ThresholdType (") + + TypeInfoToString(InferTypeInfoOf()) + ") and LeafOutputType (" + + TypeInfoToString(InferTypeInfoOf()) + ")"; static_assert(std::is_same::value || std::is_same::value, "ThresholdType should be either float32 or float64"); @@ -878,7 +882,7 @@ Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { auto error_leaf_output_type = [threshold_type, leaf_output_type]() { std::ostringstream oss; oss << "Cannot use leaf output type " << treelite::TypeInfoToString(leaf_output_type) - << "with threshold type " << treelite::TypeInfoToString(threshold_type); + << " with threshold type " << treelite::TypeInfoToString(threshold_type); return oss.str(); }; switch (threshold_type) { @@ -898,7 +902,7 @@ Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { case treelite::TypeInfo::kUInt32: return treelite::Model::Create(); case treelite::TypeInfo::kFloat64: - return treelite::Model::Create(); + return treelite::Model::Create(); default: throw std::runtime_error(error_leaf_output_type()); break; diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc index 37f691e1..662b981e 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -10,6 +10,7 @@ #include #include #include +#include namespace { @@ -32,20 +33,23 @@ inline void 
TestRoundTrip(treelite::Model* model) { namespace treelite { -TEST(PyBufferInterfaceRoundTrip, TreeStump) { +template +void PyBufferInterfaceRoundTrip_TreeStump() { + TypeInfo threshold_type = InferTypeInfoOf(); + TypeInfo leaf_output_type = InferTypeInfoOf(); std::unique_ptr builder{ - new frontend::ModelBuilder(2, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::ModelBuilder(2, 1, false, threshold_type, leaf_output_type) }; std::unique_ptr tree{ - new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::TreeBuilder(threshold_type, leaf_output_type) }; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); - tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0.0f), true, 1, 2); + tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0), true, 1, 2); tree->SetRootNode(0); - tree->SetLeafNode(1, frontend::Value::Create(-1.0f)); - tree->SetLeafNode(2, frontend::Value::Create(1.0f)); + tree->SetLeafNode(1, frontend::Value::Create(-1)); + tree->SetLeafNode(2, frontend::Value::Create(1)); builder->InsertTree(tree.get()); std::unique_ptr model{new Model()}; @@ -53,22 +57,37 @@ TEST(PyBufferInterfaceRoundTrip, TreeStump) { TestRoundTrip(model.get()); } -TEST(PyBufferInterfaceRoundTrip, TreeStumpLeafVec) { +TEST(PyBufferInterfaceRoundTrip, TreeStump) { + PyBufferInterfaceRoundTrip_TreeStump(); + PyBufferInterfaceRoundTrip_TreeStump(); + PyBufferInterfaceRoundTrip_TreeStump(); + PyBufferInterfaceRoundTrip_TreeStump(); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStump()), std::runtime_error); +} + +template +void PyBufferInterfaceRoundTrip_TreeStumpLeafVec() { + TypeInfo threshold_type = InferTypeInfoOf(); + TypeInfo leaf_output_type = InferTypeInfoOf(); std::unique_ptr builder{ - new frontend::ModelBuilder(2, 2, true, TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::ModelBuilder(2, 2, true, threshold_type, leaf_output_type) }; std::unique_ptr tree{ - new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::TreeBuilder(threshold_type, leaf_output_type) }; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); - tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0.0f), true, 1, 2); + tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0), true, 1, 2); tree->SetRootNode(0); - tree->SetLeafVectorNode(1, {frontend::Value::Create(-1.0f), - frontend::Value::Create(1.0f)}); - tree->SetLeafVectorNode(2, {frontend::Value::Create(1.0f), - frontend::Value::Create(-1.0f)}); + tree->SetLeafVectorNode(1, {frontend::Value::Create(-1), + frontend::Value::Create(1)}); + tree->SetLeafVectorNode(2, {frontend::Value::Create(1), + frontend::Value::Create(-1)}); builder->InsertTree(tree.get()); std::unique_ptr model{new Model()}; @@ -76,20 +95,40 @@ TEST(PyBufferInterfaceRoundTrip, TreeStumpLeafVec) { TestRoundTrip(model.get()); } -TEST(PyBufferInterfaceRoundTrip, TreeStumpCategoricalSplit) { +TEST(PyBufferInterfaceRoundTrip, TreeStumpLeafVec) { + PyBufferInterfaceRoundTrip_TreeStumpLeafVec(); + PyBufferInterfaceRoundTrip_TreeStumpLeafVec(); + PyBufferInterfaceRoundTrip_TreeStumpLeafVec(); + PyBufferInterfaceRoundTrip_TreeStumpLeafVec(); + 
ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), + std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), + std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), + std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), + std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpLeafVec()), + std::runtime_error); +} + +template +void PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit() { + TypeInfo threshold_type = InferTypeInfoOf(); + TypeInfo leaf_output_type = InferTypeInfoOf(); std::unique_ptr builder{ - new frontend::ModelBuilder(2, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::ModelBuilder(2, 1, false, threshold_type, leaf_output_type) }; std::unique_ptr tree{ - new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::TreeBuilder(threshold_type, leaf_output_type) }; tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); tree->SetCategoricalTestNode(0, 0, {0, 1}, true, 1, 2); tree->SetRootNode(0); - tree->SetLeafNode(1, frontend::Value::Create(-1.0f)); - tree->SetLeafNode(2, frontend::Value::Create(1.0f)); + tree->SetLeafNode(1, frontend::Value::Create(-1)); + tree->SetLeafNode(2, frontend::Value::Create(1)); builder->InsertTree(tree.get()); std::unique_ptr model{new Model()}; @@ -97,27 +136,47 @@ TEST(PyBufferInterfaceRoundTrip, TreeStumpCategoricalSplit) { TestRoundTrip(model.get()); } -TEST(PyBufferInterfaceRoundTrip, TreeDepth2) { +TEST(PyBufferInterfaceRoundTrip, TreeStumpCategoricalSplit) { + PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit(); + PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit(); + PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit(); + PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit(); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), + std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), + std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), + std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), + std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit()), + std::runtime_error); +} + +template +void PyBufferInterfaceRoundTrip_TreeDepth2() { + TypeInfo threshold_type = InferTypeInfoOf(); + TypeInfo leaf_output_type = InferTypeInfoOf(); std::unique_ptr builder{ - new frontend::ModelBuilder(2, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::ModelBuilder(2, 1, false, threshold_type, leaf_output_type) }; builder->SetModelParam("pred_transform", "sigmoid"); builder->SetModelParam("global_bias", "0.5"); for (int tree_id = 0; tree_id < 2; ++tree_id) { std::unique_ptr tree{ - new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::TreeBuilder(threshold_type, leaf_output_type) }; for (int i = 0; i < 7; ++i) { tree->CreateNode(i); } - tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0.0f), true, 1, 2); + tree->SetNumericalTestNode(0, 0, "<", frontend::Value::Create(0), true, 1, 2); tree->SetCategoricalTestNode(1, 0, {0, 1}, true, 3, 4); tree->SetCategoricalTestNode(2, 1, {0}, true, 5, 6); tree->SetRootNode(0); - tree->SetLeafNode(3, frontend::Value::Create(-2.0f)); - tree->SetLeafNode(4, frontend::Value::Create(1.0f)); - tree->SetLeafNode(5, frontend::Value::Create(-1.0f)); - tree->SetLeafNode(6, frontend::Value::Create(2.0f)); + tree->SetLeafNode(3, 
frontend::Value::Create(-2)); + tree->SetLeafNode(4, frontend::Value::Create(1)); + tree->SetLeafNode(5, frontend::Value::Create(-1)); + tree->SetLeafNode(6, frontend::Value::Create(2)); builder->InsertTree(tree.get()); } @@ -126,14 +185,29 @@ TEST(PyBufferInterfaceRoundTrip, TreeDepth2) { TestRoundTrip(model.get()); } -TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { +TEST(PyBufferInterfaceRoundTrip, TreeDepth2) { + PyBufferInterfaceRoundTrip_TreeDepth2(); + PyBufferInterfaceRoundTrip_TreeDepth2(); + PyBufferInterfaceRoundTrip_TreeDepth2(); + PyBufferInterfaceRoundTrip_TreeDepth2(); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); + ASSERT_THROW((PyBufferInterfaceRoundTrip_TreeDepth2()), std::runtime_error); +} + +template +void PyBufferInterfaceRoundTrip_DeepFullTree() { + TypeInfo threshold_type = InferTypeInfoOf(); + TypeInfo leaf_output_type = InferTypeInfoOf(); const int depth = 19; std::unique_ptr builder{ - new frontend::ModelBuilder(3, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::ModelBuilder(3, 1, false, threshold_type, leaf_output_type) }; std::unique_ptr tree{ - new frontend::TreeBuilder(TypeInfo::kFloat32, TypeInfo::kFloat32) + new frontend::TreeBuilder(threshold_type, leaf_output_type) }; for (int level = 0; level <= depth; ++level) { for (int i = 0; i < (1 << level); ++i) { @@ -144,12 +218,11 @@ TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { for (int level = 0; level <= depth; ++level) { for (int i = 0; i < (1 << level); ++i) { const int nid = (1 << level) - 1 + i; - const float leaf_value = 0.5f; if (level == depth) { - tree->SetLeafNode(nid, frontend::Value::Create(leaf_value)); + tree->SetLeafNode(nid, frontend::Value::Create(1)); } else { - tree->SetNumericalTestNode(nid, (level % 2), "<", frontend::Value::Create(0.0f), - true, 2 * nid + 1, 2 * nid + 2); + tree->SetNumericalTestNode(nid, (level % 2), "<", frontend::Value::Create(0), + true, 2 * nid + 1, 2 * nid + 2); } } } @@ -161,4 +234,11 @@ TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { TestRoundTrip(model.get()); } +TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { + PyBufferInterfaceRoundTrip_DeepFullTree(); + PyBufferInterfaceRoundTrip_DeepFullTree(); + PyBufferInterfaceRoundTrip_DeepFullTree(); + PyBufferInterfaceRoundTrip_DeepFullTree(); +} + } // namespace treelite From b846f6af41f3ec8eb4e781931b9072aba04d2bd5 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 16:06:14 -0700 Subject: [PATCH 08/38] Update native code template --- include/treelite/base.h | 2 +- include/treelite/tree.h | 2 + include/treelite/tree_impl.h | 10 +++ src/CMakeLists.txt | 1 + src/compiler/ast_native.cc | 85 +++++++++++++++------- src/compiler/native/code_folder_template.h | 4 +- src/compiler/native/header_template.h | 6 +- src/compiler/native/main_template.h | 10 +-- src/compiler/native/pred_transform.h | 83 +++++++++++++-------- src/compiler/native/qnode_template.h | 10 +-- src/compiler/native/typeinfo_ctypes.h | 85 ++++++++++++++++++++++ 11 files changed, 223 insertions(+), 75 deletions(-) create mode 100644 src/compiler/native/typeinfo_ctypes.h diff --git a/include/treelite/base.h b/include/treelite/base.h index 3c9a89ca..25e34d3a 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -64,7 +64,7 @@ inline std::string 
OpName(Operator op) { } /*! - * \brief get string representation of type info + * \brief Get string representation of type info * \param info a type info * \return string representation */ diff --git a/include/treelite/tree.h b/include/treelite/tree.h index c33037ec..e8171f1a 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -481,6 +481,8 @@ struct Model { template inline const ModelImpl& GetImpl() const; inline ModelType GetModelType() const; + inline TypeInfo GetThresholdType() const; + inline TypeInfo GetLeafOutputType() const; template inline auto Dispatch(Func func); template diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 595bf9e1..b2f421dd 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -826,6 +826,16 @@ Model::GetModelType() const { return type_; } +inline TypeInfo +Model::GetThresholdType() const { + return threshold_type_; +} + +inline TypeInfo +Model::GetLeafOutputType() const { + return leaf_output_type_; +} + template inline ModelType Model::InferModelTypeOf() { diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 244f78b2..fc09ce9e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -76,6 +76,7 @@ target_sources(objtreelite compiler/native/main_template.h compiler/native/pred_transform.h compiler/native/qnode_template.h + compiler/native/typeinfo_ctypes.h compiler/ast_native.cc compiler/compiler.cc compiler/failsafe.cc diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index 1ebffb2a..dcae797b 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -19,6 +19,7 @@ #include "./native/header_template.h" #include "./native/qnode_template.h" #include "./native/code_folder_template.h" +#include "./native/typeinfo_ctypes.h" #include "./common/format_util.h" #include "./common/code_folding_util.h" #include "./common/categorical_bitmap.h" @@ -63,8 +64,7 @@ class ASTNativeCompiler : public Compiler { ASTBuilder builder; builder.BuildAST(model); - if (builder.FoldCode(param.code_folding_req) - || param.quantize > 0) { + if (builder.FoldCode(param.code_folding_req) || param.quantize > 0) { // is_categorical[i] : is i-th feature categorical? array_is_categorical_ = RenderIsCategoricalArray(builder.GenerateIsCategoricalArray()); @@ -188,6 +188,10 @@ class ASTNativeCompiler : public Compiler { void HandleMainNode(const MainNode* node, const std::string& dest, size_t indent) { + const std::string threshold_type + = native::TypeInfoToCTypeString(InferTypeInfoOf()); + const std::string leaf_output_type + = native::TypeInfoToCTypeString(InferTypeInfoOf()); const char* get_num_output_group_function_signature = "size_t get_num_output_group(void)"; const char* get_num_feature_function_signature @@ -198,11 +202,12 @@ class ASTNativeCompiler : public Compiler { = "float get_sigmoid_alpha(void)"; const char* get_global_bias_function_signature = "float get_global_bias(void)"; - const char* predict_function_signature + const std::string predict_function_signature = (num_output_group_ > 1) ? 
- "size_t predict_multiclass(union Entry* data, int pred_margin, " - "float* result)" - : "float predict(union Entry* data, int pred_margin)"; + fmt::format("size_t predict_multiclass(union Entry* data, int pred_margin, {}* result)", + leaf_output_type) + : fmt::format("{} predict(union Entry* data, int pred_margin)", + leaf_output_type); if (!array_is_categorical_.empty()) { array_is_categorical_ @@ -245,27 +250,29 @@ class ASTNativeCompiler : public Compiler { "get_global_bias_function_signature"_a = get_global_bias_function_signature, "predict_function_signature"_a = predict_function_signature, - "threshold_type"_a = (param.quantize > 0 ? "int" : "float")), + "threshold_type"_a = threshold_type, + "threshold_type_Node"_a = (param.quantize > 0 ? std::string("int") : threshold_type)), indent); CHECK_EQ(node->children.size(), 1); WalkAST(node->children[0], dest, indent + 2); const std::string optional_average_field - = (node->average_result) ? fmt::format(" / {}", node->num_tree) - : std::string(""); + = (node->average_result) ? fmt::format(" / {}", node->num_tree) : std::string(""); if (num_output_group_ > 1) { AppendToBuffer(dest, fmt::format(native::main_end_multiclass_template, "num_output_group"_a = num_output_group_, "optional_average_field"_a = optional_average_field, - "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias)), + "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias), + "leaf_output_type"_a = leaf_output_type), indent); } else { AppendToBuffer(dest, fmt::format(native::main_end_template, "optional_average_field"_a = optional_average_field, - "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias)), + "global_bias"_a = common_util::ToStringHighPrecision(node->global_bias), + "leaf_output_type"_a = leaf_output_type), indent); } } @@ -274,17 +281,22 @@ class ASTNativeCompiler : public Compiler { void HandleACNode(const AccumulatorContextNode* node, const std::string& dest, size_t indent) { + const std::string leaf_output_type + = native::TypeInfoToCTypeString(InferTypeInfoOf()); if (num_output_group_ > 1) { AppendToBuffer(dest, - fmt::format("float sum[{num_output_group}] = {{0.0f}};\n" + fmt::format("{leaf_output_type} sum[{num_output_group}] = {{0}};\n" "unsigned int tmp;\n" "int nid, cond, fid; /* used for folded subtrees */\n", - "num_output_group"_a = num_output_group_), indent); + "num_output_group"_a = num_output_group_, + "leaf_output_type"_a = leaf_output_type), indent); } else { AppendToBuffer(dest, - "float sum = 0.0f;\n" - "unsigned int tmp;\n" - "int nid, cond, fid; /* used for folded subtrees */\n", indent); + fmt::format("{leaf_output_type} sum = ({leaf_output_type})0;\n" + "unsigned int tmp;\n" + "int nid, cond, fid; /* used for folded subtrees */\n", + "leaf_output_type"_a = leaf_output_type), + indent); } for (ASTNode* child : node->children) { WalkAST(child, dest, indent); @@ -344,6 +356,8 @@ class ASTNativeCompiler : public Compiler { int indent) { const int unit_id = node->unit_id; const std::string new_file = fmt::format("tu{}.c", unit_id); + const std::string leaf_output_type + = native::TypeInfoToCTypeString(InferTypeInfoOf()); std::string unit_function_name, unit_function_signature, unit_function_call_signature; @@ -351,15 +365,18 @@ class ASTNativeCompiler : public Compiler { unit_function_name = fmt::format("predict_margin_multiclass_unit{}", unit_id); unit_function_signature - = fmt::format("void {}(union Entry* data, float* result)", - unit_function_name); + = fmt::format("void 
{function_name}(union Entry* data, {leaf_output_type}* result)", + "function_name"_a = unit_function_name, + "leaf_output_type"_a = leaf_output_type); unit_function_call_signature = fmt::format("{}(data, sum);\n", unit_function_name); } else { unit_function_name = fmt::format("predict_margin_unit{}", unit_id); unit_function_signature - = fmt::format("float {}(union Entry* data)", unit_function_name); + = fmt::format("{leaf_output_type} {function_name}(union Entry* data)", + "function_name"_a = unit_function_name, + "leaf_output_type"_a = leaf_output_type); unit_function_call_signature = fmt::format("sum += {}(data);\n", unit_function_name); } @@ -386,6 +403,8 @@ class ASTNativeCompiler : public Compiler { void HandleQNode(const QuantizerNode* node, const std::string& dest, size_t indent) { + const std::string threshold_type + = native::TypeInfoToCTypeString(InferTypeInfoOf()); /* render arrays needed to convert feature values into bin indices */ std::string array_threshold, array_th_begin, array_th_len; // threshold[] : list of all thresholds that occur at least once in the @@ -425,16 +444,21 @@ class ASTNativeCompiler : public Compiler { if (!array_threshold.empty() && !array_th_begin.empty() && !array_th_len.empty()) { PrependToBuffer(dest, fmt::format(native::qnode_template, - "total_num_threshold"_a = total_num_threshold), 0); + "total_num_threshold"_a = total_num_threshold, + "threshold_type"_a = threshold_type), + 0); AppendToBuffer(dest, fmt::format(native::quantize_loop_template, "num_feature"_a = num_feature_), indent); } if (!array_threshold.empty()) { PrependToBuffer(dest, - fmt::format("static const float threshold[] = {{\n" + fmt::format("static const {threshold_type} threshold[] = {{\n" "{array_threshold}\n" - "}};\n", "array_threshold"_a = array_threshold), 0); + "}};\n", + "array_threshold"_a = array_threshold, + "threshold_type"_a = threshold_type), + 0); } if (!array_th_begin.empty()) { PrependToBuffer(dest, @@ -625,6 +649,8 @@ class ASTNativeCompiler : public Compiler { template inline std::string RenderOutputStatement(const OutputNode* node) { + const std::string leaf_output_type + = native::TypeInfoToCTypeString(InferTypeInfoOf()); std::string output_statement; if (num_output_group_ > 1) { if (node->is_vector) { @@ -633,21 +659,24 @@ class ASTNativeCompiler : public Compiler { << "Ill-formed model: leaf vector must be of length [num_output_group]"; for (int group_id = 0; group_id < num_output_group_; ++group_id) { output_statement - += fmt::format("sum[{group_id}] += (float){output};\n", + += fmt::format("sum[{group_id}] += ({leaf_output_type}){output};\n", "group_id"_a = group_id, - "output"_a = common_util::ToStringHighPrecision(node->vector[group_id])); + "output"_a = common_util::ToStringHighPrecision(node->vector[group_id]), + "leaf_output_type"_a = leaf_output_type); } } else { // multi-class classification with gradient boosted trees output_statement - = fmt::format("sum[{group_id}] += (float){output};\n", + = fmt::format("sum[{group_id}] += ({leaf_output_type}){output};\n", "group_id"_a = node->tree_id % num_output_group_, - "output"_a = common_util::ToStringHighPrecision(node->scalar)); + "output"_a = common_util::ToStringHighPrecision(node->scalar), + "leaf_output_type"_a = leaf_output_type); } } else { output_statement - = fmt::format("sum += (float){output};\n", - "output"_a = common_util::ToStringHighPrecision(node->scalar)); + = fmt::format("sum += ({leaf_output_type}){output};\n", + "output"_a = common_util::ToStringHighPrecision(node->scalar), + 
"leaf_output_type"_a = leaf_output_type); } return output_statement; } diff --git a/src/compiler/native/code_folder_template.h b/src/compiler/native/code_folder_template.h index 628ccd93..42924186 100644 --- a/src/compiler/native/code_folder_template.h +++ b/src/compiler/native/code_folder_template.h @@ -12,7 +12,7 @@ namespace treelite { namespace compiler { namespace native { -const char* eval_loop_template = +const char* const eval_loop_template = R"TREELITETEMPLATE( nid = 0; while (nid >= 0) {{ /* negative nid implies leaf */ @@ -31,7 +31,7 @@ while (nid >= 0) {{ /* negative nid implies leaf */ {output_switch_statement} )TREELITETEMPLATE"; -const char* eval_loop_template_without_categorical_feature = +const char* const eval_loop_template_without_categorical_feature = R"TREELITETEMPLATE( nid = 0; while (nid >= 0) {{ /* negative nid implies leaf */ diff --git a/src/compiler/native/header_template.h b/src/compiler/native/header_template.h index 06b1061d..a822f54f 100644 --- a/src/compiler/native/header_template.h +++ b/src/compiler/native/header_template.h @@ -12,7 +12,7 @@ namespace treelite { namespace compiler { namespace native { -const char* header_template = +const char* const header_template = R"TREELITETEMPLATE( #include #include @@ -29,14 +29,14 @@ R"TREELITETEMPLATE( union Entry {{ int missing; - float fvalue; + {threshold_type} fvalue; int qvalue; }}; struct Node {{ uint8_t default_left; unsigned int split_index; - {threshold_type} threshold; + {threshold_type_Node} threshold; int left_child; int right_child; }}; diff --git a/src/compiler/native/main_template.h b/src/compiler/native/main_template.h index 5f04cba8..a3176e65 100644 --- a/src/compiler/native/main_template.h +++ b/src/compiler/native/main_template.h @@ -12,7 +12,7 @@ namespace treelite { namespace compiler { namespace native { -const char* main_start_template = +const char* const main_start_template = R"TREELITETEMPLATE( #include "header.h" @@ -42,10 +42,10 @@ R"TREELITETEMPLATE( {predict_function_signature} {{ )TREELITETEMPLATE"; -const char* main_end_multiclass_template = +const char* const main_end_multiclass_template = R"TREELITETEMPLATE( for (int i = 0; i < {num_output_group}; ++i) {{ - result[i] = sum[i]{optional_average_field} + (float)({global_bias}); + result[i] = sum[i]{optional_average_field} + ({leaf_output_type})({global_bias}); }} if (!pred_margin) {{ return pred_transform(result); @@ -55,9 +55,9 @@ R"TREELITETEMPLATE( }} )TREELITETEMPLATE"; // only for multiclass classification -const char* main_end_template = +const char* const main_end_template = R"TREELITETEMPLATE( - sum = sum{optional_average_field} + (float)({global_bias}); + sum = sum{optional_average_field} + ({leaf_output_type})({global_bias}); if (!pred_margin) {{ return pred_transform(sum); }} else {{ diff --git a/src/compiler/native/pred_transform.h b/src/compiler/native/pred_transform.h index 06dc1d2e..bfdf16fc 100644 --- a/src/compiler/native/pred_transform.h +++ b/src/compiler/native/pred_transform.h @@ -11,6 +11,7 @@ #include #include #include +#include "./typeinfo_ctypes.h" using namespace fmt::literals; @@ -21,91 +22,108 @@ namespace pred_transform { inline std::string identity(const Model& model) { return fmt::format( -R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ +R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ return margin; -}})TREELITETEMPLATE"); +}})TREELITETEMPLATE", +"threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType())); } inline 
std::string sigmoid(const Model& model) { const float alpha = model.GetParam().sigmoid_alpha; + const TypeInfo threshold_type = model.GetThresholdType(); CHECK_GT(alpha, 0.0f) << "sigmoid: alpha must be strictly positive"; return fmt::format( -R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ - const float alpha = (float){alpha}; - return 1.0f / (1 + expf(-alpha * margin)); +R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ + const {threshold_type} alpha = ({threshold_type}){alpha}; + return ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * margin)); }})TREELITETEMPLATE", - "alpha"_a = alpha); + "alpha"_a = alpha, + "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type), + "exp"_a = native::CExpForTypeInfo(threshold_type)); } inline std::string exponential(const Model& model) { + const TypeInfo threshold_type = model.GetThresholdType(); return fmt::format( -R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ - return expf(margin); -}})TREELITETEMPLATE"); +R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ + return {exp}(margin); +}})TREELITETEMPLATE", + "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type), + "exp"_a = native::CExpForTypeInfo(threshold_type)); } inline std::string logarithm_one_plus_exp(const Model& model) { + const TypeInfo threshold_type = model.GetThresholdType(); return fmt::format( -R"TREELITETEMPLATE(static inline float pred_transform(float margin) {{ - return log1pf(expf(margin)); -}})TREELITETEMPLATE"); +R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{ + return {log1p}({exp}(margin)); +}})TREELITETEMPLATE", + "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type), + "exp"_a = native::CExpForTypeInfo(threshold_type), + "log1p"_a = native::CLog1PForTypeInfo(threshold_type)); } inline std::string identity_multiclass(const Model& model) { - CHECK(model.GetNumOutputGroup() > 1) + CHECK_GT(model.GetNumOutputGroup(), 1) << "identity_multiclass: model is not a proper multi-class classifier"; return fmt::format( -R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ +R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ return {num_class}; }})TREELITETEMPLATE", - "num_class"_a = model.GetNumOutputGroup()); + "num_class"_a = model.GetNumOutputGroup(), + "threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType())); } inline std::string max_index(const Model& model) { - CHECK(model.GetNumOutputGroup() > 1) + CHECK_GT(model.GetNumOutputGroup(), 1) << "max_index: model is not a proper multi-class classifier"; + const TypeInfo threshold_type = model.GetThresholdType(); return fmt::format( -R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ +R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ const int num_class = {num_class}; int max_index = 0; - float max_margin = pred[0]; + {threshold_type} max_margin = pred[0]; for (int k = 1; k < num_class; ++k) {{ if (pred[k] > max_margin) {{ max_margin = pred[k]; max_index = k; }} }} - pred[0] = (float)max_index; + pred[0] = ({threshold_type})max_index; return 1; }})TREELITETEMPLATE", - "num_class"_a = model.GetNumOutputGroup()); + "num_class"_a = model.GetNumOutputGroup(), + "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type)); } inline std::string softmax(const Model& model) { - 
CHECK(model.GetNumOutputGroup() > 1) + CHECK_GT(model.GetNumOutputGroup(), 1) << "softmax: model is not a proper multi-class classifier"; + const TypeInfo threshold_type = model.GetThresholdType(); return fmt::format( -R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ +R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ const int num_class = {num_class}; - float max_margin = pred[0]; + {threshold_type} max_margin = pred[0]; double norm_const = 0.0; - float t; + {threshold_type} t; for (int k = 1; k < num_class; ++k) {{ if (pred[k] > max_margin) {{ max_margin = pred[k]; }} }} for (int k = 0; k < num_class; ++k) {{ - t = expf(pred[k] - max_margin); + t = {exp}(pred[k] - max_margin); norm_const += t; pred[k] = t; }} for (int k = 0; k < num_class; ++k) {{ - pred[k] /= (float)norm_const; + pred[k] /= ({threshold_type})norm_const; }} return (size_t)num_class; }})TREELITETEMPLATE", - "num_class"_a = model.GetNumOutputGroup()); + "num_class"_a = model.GetNumOutputGroup(), + "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type), + "exp"_a = native::CExpForTypeInfo(threshold_type)); } inline std::string multiclass_ova(const Model& model) { @@ -113,17 +131,20 @@ inline std::string multiclass_ova(const Model& model) { << "multiclass_ova: model is not a proper multi-class classifier"; const int num_class = model.GetNumOutputGroup(); const float alpha = model.GetParam().sigmoid_alpha; + const TypeInfo threshold_type = model.GetThresholdType(); CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive"; return fmt::format( -R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ - const float alpha = (float){alpha}; +R"TREELITETEMPLATE(static inline size_t pred_transform({threshold_type}* pred) {{ + const {threshold_type} alpha = ({threshold_type}){alpha}; const int num_class = {num_class}; for (int k = 0; k < num_class; ++k) {{ - pred[k] = 1.0f / (1.0f + expf(-alpha * pred[k])); + pred[k] = ({threshold_type})(1) / (({threshold_type})(1) + {exp}(-alpha * pred[k])); }} return (size_t)num_class; }})TREELITETEMPLATE", - "num_class"_a = model.GetNumOutputGroup(), "alpha"_a = alpha); + "num_class"_a = model.GetNumOutputGroup(), "alpha"_a = alpha, + "threshold_type"_a = native::TypeInfoToCTypeString(threshold_type), + "exp"_a = native::CExpForTypeInfo(threshold_type)); } } // namespace pred_transform diff --git a/src/compiler/native/qnode_template.h b/src/compiler/native/qnode_template.h index a2a60b93..7f046874 100644 --- a/src/compiler/native/qnode_template.h +++ b/src/compiler/native/qnode_template.h @@ -12,7 +12,7 @@ namespace treelite { namespace compiler { namespace native { -const char* qnode_template = +const char* const qnode_template = R"TREELITETEMPLATE( #include @@ -22,14 +22,14 @@ R"TREELITETEMPLATE( * \param fid feature identifier * \return bin index corresponding to given feature value */ -static inline int quantize(float val, unsigned fid) {{ +static inline int quantize({threshold_type} val, unsigned fid) {{ const size_t offset = th_begin[fid]; - const float* array = &threshold[offset]; + const {threshold_type}* array = &threshold[offset]; int len = th_len[fid]; int low = 0; int high = len; int mid; - float mval; + {threshold_type} mval; // It is possible th_begin[i] == [total_num_threshold]. This means that // all features i, (i+1), ... are not used for any of the splits in the model. 
// So in this case, just return something @@ -57,7 +57,7 @@ static inline int quantize(float val, unsigned fid) {{ }} )TREELITETEMPLATE"; -const char* quantize_loop_template = +const char* const quantize_loop_template = R"TREELITETEMPLATE( for (int i = 0; i < {num_feature}; ++i) {{ if (data[i].missing != -1 && !is_categorical[i]) {{ diff --git a/src/compiler/native/typeinfo_ctypes.h b/src/compiler/native/typeinfo_ctypes.h new file mode 100644 index 00000000..ef41f44c --- /dev/null +++ b/src/compiler/native/typeinfo_ctypes.h @@ -0,0 +1,85 @@ +// +// Created by Philip Hyunsu Cho on 8/31/20. +// + +#ifndef TREELITE_COMPILER_NATIVE_TYPEINFO_CTYPES_H_ +#define TREELITE_COMPILER_NATIVE_TYPEINFO_CTYPES_H_ + +#include +#include + +namespace treelite { +namespace compiler { +namespace native { + +/*! + * \brief Get string representation of the C type that's equivalent to the given type info + * \param info a type info + * \return string representation + */ +inline std::string TypeInfoToCTypeString(TypeInfo type) { + switch (type) { + case TypeInfo::kInvalid: + throw std::runtime_error("Invalid type"); + return ""; + case TypeInfo::kUInt32: + return "uint32_t"; + case TypeInfo::kFloat32: + return "float"; + case TypeInfo::kFloat64: + return "double"; + default: + throw std::runtime_error(std::string("Unrecognized type: ") + + std::to_string(static_cast(type))); + return ""; + } +} + +/*! + * \brief Look up the correct variant of exp() in C that should be used with a given type + * \param info a type info + * \return string representation + */ +inline std::string CExpForTypeInfo(TypeInfo type) { + switch (type) { + case TypeInfo::kInvalid: + case TypeInfo::kUInt32: + throw std::runtime_error(std::string("Invalid type: ") + TypeInfoToString(type)); + return ""; + case TypeInfo::kFloat32: + return "expf"; + case TypeInfo::kFloat64: + return "exp"; + default: + throw std::runtime_error(std::string("Unrecognized type: ") + + std::to_string(static_cast(type))); + return ""; + } +} + +/*! 
+ * \brief Look up the correct variant of log1p() in C that should be used with a given type + * \param info a type info + * \return string representation + */ +inline std::string CLog1PForTypeInfo(TypeInfo type) { + switch (type) { + case TypeInfo::kInvalid: + case TypeInfo::kUInt32: + throw std::runtime_error(std::string("Invalid type: ") + TypeInfoToString(type)); + return ""; + case TypeInfo::kFloat32: + return "log1pf"; + case TypeInfo::kFloat64: + return "log1p"; + default: + throw std::runtime_error(std::string("Unrecognized type: ") + + std::to_string(static_cast(type))); + return ""; + } +} + +} // namespace native +} // namespace compiler +} // namespace treelite +#endif // TREELITE_COMPILER_NATIVE_TYPEINFO_CTYPES_H_ From 0f88e96ce89e0d21313b28572cb4589541397e1f Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 20:22:51 -0700 Subject: [PATCH 09/38] Move TypeInfo to a separate header + objtreelite_common --- include/treelite/base.h | 54 +------------------------- include/treelite/predictor.h | 1 + include/treelite/typeinfo.h | 73 +++++++++++++++++++++++++++++++++++ src/CMakeLists.txt | 4 +- src/{tables.cc => optable.cc} | 12 ++---- src/typeinfo.cc | 22 +++++++++++ 6 files changed, 104 insertions(+), 62 deletions(-) create mode 100644 include/treelite/typeinfo.h rename src/{tables.cc => optable.cc} (56%) create mode 100644 src/typeinfo.cc diff --git a/include/treelite/base.h b/include/treelite/base.h index 25e34d3a..748473f6 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -12,6 +12,7 @@ #include #include #include +#include "./typeinfo.h" namespace treelite { @@ -31,22 +32,9 @@ enum class Operator : int8_t { kGE, /*!< operator >= */ }; -/*! \brief Types used by thresholds and leaf outputs */ -enum class TypeInfo : uint8_t { - kInvalid = 0, - kUInt32 = 1, - kFloat32 = 2, - kFloat64 = 3 -}; -static_assert(std::is_same::type, uint8_t>::value, - "TypeInfo must use uint8_t as underlying type"); - /*! \brief conversion table from string to Operator, defined in tables.cc */ extern const std::unordered_map optable; -/*! \brief conversion table from string to TypeInfo, defined in tables.cc */ -extern const std::unordered_map typeinfo_table; - /*! * \brief get string representation of comparison operator * \param op comparison operator @@ -63,46 +51,6 @@ inline std::string OpName(Operator op) { } } -/*! - * \brief Get string representation of type info - * \param info a type info - * \return string representation - */ -inline std::string TypeInfoToString(treelite::TypeInfo type) { - switch (type) { - case treelite::TypeInfo::kInvalid: - return "invalid"; - case treelite::TypeInfo::kUInt32: - return "uint32"; - case treelite::TypeInfo::kFloat32: - return "float32"; - case treelite::TypeInfo::kFloat64: - return "float64"; - default: - throw std::runtime_error("Unrecognized type"); - return ""; - } -} - -/*! - * \brief Convert a template type into a type info - * \tparam template type to be converted - * \return TypeInfo corresponding to the template type arg - */ -template -inline TypeInfo InferTypeInfoOf() { - if (std::is_same::value) { - return TypeInfo::kUInt32; - } else if (std::is_same::value) { - return TypeInfo::kFloat32; - } else if (std::is_same::value) { - return TypeInfo::kFloat64; - } else { - throw std::runtime_error(std::string("Unrecognized Value type") + typeid(T).name()); - return TypeInfo::kInvalid; - } -} - /*! * \brief perform comparison between two float's using a comparsion operator * The comparison will be in the form [lhs] [op] [rhs]. 
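The hunk above moves the TypeInfo enum and its helpers (TypeInfoToString, InferTypeInfoOf, typeinfo_table) out of base.h; they reappear in the new typeinfo.h shown further down, while the previous patch's typeinfo_ctypes.h supplies the mapping from a TypeInfo tag to C source fragments. A minimal sketch of how the pieces compose when the native templates are rendered (ShowRenderedTypes and the expected-output comments are illustrative only, not part of the Treelite sources; the relative include assumes a file under src/compiler/native/):

    #include <treelite/typeinfo.h>
    #include "./typeinfo_ctypes.h"
    #include <iostream>

    template <typename ThresholdType>
    void ShowRenderedTypes() {
      // Map the C++ template parameter to a TypeInfo tag, then to C code fragments
      const treelite::TypeInfo t = treelite::InferTypeInfoOf<ThresholdType>();
      std::cout << treelite::TypeInfoToString(t) << " -> "
                << treelite::compiler::native::TypeInfoToCTypeString(t) << ", "
                << treelite::compiler::native::CExpForTypeInfo(t) << ", "
                << treelite::compiler::native::CLog1PForTypeInfo(t) << std::endl;
    }

    // ShowRenderedTypes<float>();   // prints: float32 -> float, expf, log1pf
    // ShowRenderedTypes<double>();  // prints: float64 -> double, exp, log1p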
diff --git a/include/treelite/predictor.h b/include/treelite/predictor.h index 00222298..9f668cb8 100644 --- a/include/treelite/predictor.h +++ b/include/treelite/predictor.h @@ -9,6 +9,7 @@ #include #include +#include #include #include diff --git a/include/treelite/typeinfo.h b/include/treelite/typeinfo.h new file mode 100644 index 00000000..5fb8b7c7 --- /dev/null +++ b/include/treelite/typeinfo.h @@ -0,0 +1,73 @@ +/*! + * Copyright (c) 2017-2020 by Contributors + * \file typeinfo.h + * \brief Defines TypeInfo class and utilities + * \author Hyunsu Cho + */ + +#ifndef TREELITE_TYPEINFO_H_ +#define TREELITE_TYPEINFO_H_ + +#include +#include +#include +#include + +namespace treelite { + +/*! \brief Types used by thresholds and leaf outputs */ +enum class TypeInfo : uint8_t { + kInvalid = 0, + kUInt32 = 1, + kFloat32 = 2, + kFloat64 = 3 +}; +static_assert(std::is_same::type, uint8_t>::value, + "TypeInfo must use uint8_t as underlying type"); + +/*! \brief conversion table from string to TypeInfo, defined in tables.cc */ +extern const std::unordered_map typeinfo_table; + +/*! + * \brief Get string representation of type info + * \param info a type info + * \return string representation + */ +inline std::string TypeInfoToString(treelite::TypeInfo type) { + switch (type) { + case treelite::TypeInfo::kInvalid: + return "invalid"; + case treelite::TypeInfo::kUInt32: + return "uint32"; + case treelite::TypeInfo::kFloat32: + return "float32"; + case treelite::TypeInfo::kFloat64: + return "float64"; + default: + throw std::runtime_error("Unrecognized type"); + return ""; + } +} + +/*! + * \brief Convert a template type into a type info + * \tparam template type to be converted + * \return TypeInfo corresponding to the template type arg + */ +template +inline TypeInfo InferTypeInfoOf() { + if (std::is_same::value) { + return TypeInfo::kUInt32; + } else if (std::is_same::value) { + return TypeInfo::kFloat32; + } else if (std::is_same::value) { + return TypeInfo::kFloat64; + } else { + throw std::runtime_error(std::string("Unrecognized Value type") + typeid(T).name()); + return TypeInfo::kInvalid; + } +} + +} // namespace treelite + +#endif // TREELITE_TYPEINFO_H_ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fc09ce9e..10e49af3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -88,7 +88,7 @@ target_sources(objtreelite annotator.cc data.cc filesystem.cc - tables.cc + optable.cc reference_serializer.cc ${PROJECT_SOURCE_DIR}/include/treelite/annotator.h ${PROJECT_SOURCE_DIR}/include/treelite/base.h @@ -121,9 +121,11 @@ target_sources(objtreelite_common c_api/c_api_error.cc c_api/c_api_error.h logging.cc + typeinfo.cc ${PROJECT_SOURCE_DIR}/include/treelite/c_api_common.h ${PROJECT_SOURCE_DIR}/include/treelite/logging.h ${PROJECT_SOURCE_DIR}/include/treelite/math.h + ${PROJECT_SOURCE_DIR}/include/treelite/typeinfo.h ) msvc_use_static_runtime() diff --git a/src/tables.cc b/src/optable.cc similarity index 56% rename from src/tables.cc rename to src/optable.cc index 603c3f1f..af4e4581 100644 --- a/src/tables.cc +++ b/src/optable.cc @@ -1,10 +1,12 @@ /*! 
* Copyright (c) 2017-2020 by Contributors - * \file tables.cc + * \file optable.cc * \author Hyunsu Cho - * \brief Conversion tables to obtain Operator and TypeInfo from strings + * \brief Conversion tables to obtain Operator from string */ +#include +#include #include namespace treelite { @@ -17,10 +19,4 @@ const std::unordered_map optable{ {">=", Operator::kGE} }; -const std::unordered_map typeinfo_table{ - {"uint32", TypeInfo::kUInt32}, - {"float32", TypeInfo::kFloat32}, - {"float64", TypeInfo::kFloat64} -}; - } // namespace treelite diff --git a/src/typeinfo.cc b/src/typeinfo.cc new file mode 100644 index 00000000..a58b2904 --- /dev/null +++ b/src/typeinfo.cc @@ -0,0 +1,22 @@ +/*! + * Copyright (c) 2017-2020 by Contributors + * \file typeinfo.cc + * \author Hyunsu Cho + * \brief Conversion tables to obtain TypeInfo from string + */ + +// Do not include other Treelite headers here, to minimize cross-header dependencies + +#include +#include +#include + +namespace treelite { + +const std::unordered_map typeinfo_table{ + {"uint32", TypeInfo::kUInt32}, + {"float32", TypeInfo::kFloat32}, + {"float64", TypeInfo::kFloat64} +}; + +} // namespace treelite From 81e9c7ca633881caaf318961de485c4437179412 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 20:32:00 -0700 Subject: [PATCH 10/38] Move data.cc to objtreelite_common; start new DMatrix design --- include/treelite/annotator.h | 4 ++-- include/treelite/data.h | 39 +++++++++++++++++++++++++++++++----- src/CMakeLists.txt | 4 ++-- src/annotator.cc | 6 +++--- src/c_api/c_api.cc | 18 ++++++++--------- src/{ => data}/data.cc | 14 ++++++------- 6 files changed, 57 insertions(+), 28 deletions(-) rename src/{ => data}/data.cc (89%) diff --git a/include/treelite/annotator.h b/include/treelite/annotator.h index accdc46a..9b106a8d 100644 --- a/include/treelite/annotator.h +++ b/include/treelite/annotator.h @@ -18,7 +18,7 @@ class BranchAnnotator { public: template void AnnotateImpl(const treelite::ModelImpl& model, - const treelite::DMatrix* dmat, int nthread, int verbose); + const treelite::LegacyDMatrix* dmat, int nthread, int verbose); /*! * \brief annotate branches in a given model using frequency patterns in the * training data. The annotation can be accessed through Get() method. @@ -27,7 +27,7 @@ class BranchAnnotator { * \param nthread number of threads to use * \param verbose whether to produce extra messages */ - void Annotate(const Model& model, const DMatrix* dmat, + void Annotate(const Model& model, const LegacyDMatrix* dmat, int nthread, int verbose); /*! * \brief load branch annotation from a JSON file diff --git a/include/treelite/data.h b/include/treelite/data.h index 88234ea9..f6635275 100644 --- a/include/treelite/data.h +++ b/include/treelite/data.h @@ -8,12 +8,15 @@ #define TREELITE_DATA_H_ #include +#include #include +#include +#include namespace treelite { /*! \brief a simple data matrix in CSR (Compressed Sparse Row) storage */ -struct DMatrix { +struct LegacyDMatrix { /*! \brief feature values */ std::vector data; /*! \brief feature indices */ @@ -45,8 +48,8 @@ struct DMatrix { * \param verbose whether to produce extra messages * \return newly built DMatrix */ - static DMatrix* Create(const char* filename, const char* format, - int nthread, int verbose); + static LegacyDMatrix* Create(const char* filename, const char* format, + int nthread, int verbose); /*! * \brief construct a new DMatrix from a data parser. 
The data parser here * refers to any iterable object that streams input data in small @@ -56,8 +59,34 @@ struct DMatrix { * \param verbose whether to produce extra messages * \return newly built DMatrix */ - static DMatrix* Create(dmlc::Parser* parser, - int nthread, int verbose); + static LegacyDMatrix* Create(dmlc::Parser* parser, + int nthread, int verbose); +}; + +class DenseDMatrix { + +}; + +template +class DenseDMatrixImpl : public DenseDMatrix { + static_assert(std::is_same::value || std::is_same::value, + "T must be either float32 or float64"); +}; + +class CSRDMatrix { + private: + std::shared_ptr handle_; + TypeInfo type_; + public: + template + static CSRDMatrix Create(); + static CSRDMatrix Create(TypeInfo type); +}; + +template +class CSRDMatrixImpl : public CSRDMatrix { + static_assert(std::is_same::value || std::is_same::value, + "T must be either float32 or float64"); }; } // namespace treelite diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 10e49af3..51d9d7ec 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -86,7 +86,6 @@ target_sources(objtreelite frontend/lightgbm.cc frontend/xgboost.cc annotator.cc - data.cc filesystem.cc optable.cc reference_serializer.cc @@ -95,7 +94,6 @@ target_sources(objtreelite ${PROJECT_SOURCE_DIR}/include/treelite/c_api.h ${PROJECT_SOURCE_DIR}/include/treelite/compiler.h ${PROJECT_SOURCE_DIR}/include/treelite/compiler_param.h - ${PROJECT_SOURCE_DIR}/include/treelite/data.h ${PROJECT_SOURCE_DIR}/include/treelite/filesystem.h ${PROJECT_SOURCE_DIR}/include/treelite/frontend.h ${PROJECT_SOURCE_DIR}/include/treelite/frontend_impl.h @@ -120,12 +118,14 @@ target_sources(objtreelite_common c_api/c_api_common.cc c_api/c_api_error.cc c_api/c_api_error.h + data/data.cc logging.cc typeinfo.cc ${PROJECT_SOURCE_DIR}/include/treelite/c_api_common.h ${PROJECT_SOURCE_DIR}/include/treelite/logging.h ${PROJECT_SOURCE_DIR}/include/treelite/math.h ${PROJECT_SOURCE_DIR}/include/treelite/typeinfo.h + ${PROJECT_SOURCE_DIR}/include/treelite/data.h ) msvc_use_static_runtime() diff --git a/src/annotator.cc b/src/annotator.cc index d8a9c1b3..0456a493 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -58,7 +58,7 @@ void Traverse(const treelite::Tree& tree, const E template inline void ComputeBranchLoop(const treelite::ModelImpl& model, - const treelite::DMatrix* dmat, size_t rbegin, size_t rend, + const treelite::LegacyDMatrix* dmat, size_t rbegin, size_t rend, int nthread, const size_t* count_row_ptr, size_t* counts_tloc, Entry* inst) { const size_t ntree = model.trees.size(); @@ -92,7 +92,7 @@ namespace treelite { template void BranchAnnotator::AnnotateImpl(const treelite::ModelImpl& model, - const treelite::DMatrix* dmat, int nthread, int verbose) { + const treelite::LegacyDMatrix* dmat, int nthread, int verbose) { std::vector new_counts; std::vector counts_tloc; std::vector count_row_ptr; @@ -133,7 +133,7 @@ BranchAnnotator::AnnotateImpl(const treelite::ModelImpl(DMatrix::Create(path, format, - nthread, verbose)); + *out = static_cast(LegacyDMatrix::Create(path, format, + nthread, verbose)); API_END(); } @@ -62,7 +62,7 @@ int TreeliteDMatrixCreateFromCSR(const float* data, size_t num_col, DMatrixHandle* out) { API_BEGIN(); - std::unique_ptr dmat{new DMatrix()}; + std::unique_ptr dmat{new LegacyDMatrix()}; dmat->Clear(); auto& data_ = dmat->data; auto& col_ind_ = dmat->col_ind; @@ -102,7 +102,7 @@ int TreeliteDMatrixCreateFromMat(const float* data, API_BEGIN(); CHECK_LT(num_col, std::numeric_limits::max()) << "num_col argument is too big"; - 
std::unique_ptr dmat{new DMatrix()}; + std::unique_ptr dmat{new LegacyDMatrix()}; dmat->Clear(); auto& data_ = dmat->data; auto& col_ind_ = dmat->col_ind; @@ -145,7 +145,7 @@ int TreeliteDMatrixGetDimension(DMatrixHandle handle, size_t* out_num_col, size_t* out_nelem) { API_BEGIN(); - const DMatrix* dmat = static_cast(handle); + const LegacyDMatrix* dmat = static_cast(handle); *out_num_row = dmat->num_row; *out_num_col = dmat->num_col; *out_nelem = dmat->nelem; @@ -155,7 +155,7 @@ int TreeliteDMatrixGetDimension(DMatrixHandle handle, int TreeliteDMatrixGetPreview(DMatrixHandle handle, const char** out_preview) { API_BEGIN(); - const DMatrix* dmat = static_cast(handle); + const LegacyDMatrix* dmat = static_cast(handle); std::string& ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str; std::ostringstream oss; const size_t iend = (dmat->nelem <= 50) ? dmat->nelem : 25; @@ -186,7 +186,7 @@ int TreeliteDMatrixGetArrays(DMatrixHandle handle, const uint32_t** out_col_ind, const size_t** out_row_ptr) { API_BEGIN(); - const DMatrix* dmat_ = static_cast(handle); + const LegacyDMatrix* dmat_ = static_cast(handle); *out_data = &dmat_->data[0]; *out_col_ind = &dmat_->col_ind[0]; *out_row_ptr = &dmat_->row_ptr[0]; @@ -195,7 +195,7 @@ int TreeliteDMatrixGetArrays(DMatrixHandle handle, int TreeliteDMatrixFree(DMatrixHandle handle) { API_BEGIN(); - delete static_cast(handle); + delete static_cast(handle); API_END(); } @@ -207,7 +207,7 @@ int TreeliteAnnotateBranch(ModelHandle model, API_BEGIN(); std::unique_ptr annotator{new BranchAnnotator()}; const Model* model_ = static_cast(model); - const DMatrix* dmat_ = static_cast(dmat); + const LegacyDMatrix* dmat_ = static_cast(dmat); annotator->Annotate(*model_, dmat_, nthread, verbose); *out = static_cast(annotator.release()); API_END(); diff --git a/src/data.cc b/src/data/data.cc similarity index 89% rename from src/data.cc rename to src/data/data.cc index a7d370e0..df4836f6 100644 --- a/src/data.cc +++ b/src/data/data.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2017-2020 by Contributors - * \file data.h + * \file data.cc * \author Hyunsu Cho * \brief Input data structure of Treelite */ @@ -13,20 +13,20 @@ namespace treelite { -DMatrix* -DMatrix::Create(const char* filename, const char* format, - int nthread, int verbose) { +LegacyDMatrix* +LegacyDMatrix::Create(const char* filename, const char* format, + int nthread, int verbose) { std::unique_ptr> parser( dmlc::Parser::Create(filename, 0, 1, format)); return Create(parser.get(), nthread, verbose); } -DMatrix* -DMatrix::Create(dmlc::Parser* parser, int nthread, int verbose) { +LegacyDMatrix* +LegacyDMatrix::Create(dmlc::Parser* parser, int nthread, int verbose) { const int max_thread = omp_get_max_threads(); nthread = (nthread == 0) ? 
max_thread : std::min(nthread, max_thread); - DMatrix* dmat = new DMatrix(); + LegacyDMatrix* dmat = new LegacyDMatrix(); dmat->Clear(); auto& data_ = dmat->data; auto& col_ind_ = dmat->col_ind; From e4bb3cb84c4bbff090f1d81c601b44ca9fb9292c Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Mon, 31 Aug 2020 21:36:18 -0700 Subject: [PATCH 11/38] New DMatrix design --- include/treelite/data.h | 78 +++++++++++++++++++++++++++++----- src/data/data.cc | 93 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 160 insertions(+), 11 deletions(-) diff --git a/include/treelite/data.h b/include/treelite/data.h index f6635275..b85b7933 100644 --- a/include/treelite/data.h +++ b/include/treelite/data.h @@ -64,29 +64,85 @@ struct LegacyDMatrix { }; class DenseDMatrix { - + private: + TypeInfo type_; + public: + template + static std::unique_ptr Create( + std::vector data, ElementType missing_value, size_t num_row, size_t num_col); + template + static std::unique_ptr Create( + const void* data, const void* missing_value, size_t num_row, size_t num_col); + static std::unique_ptr Create( + TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col); }; -template +template class DenseDMatrixImpl : public DenseDMatrix { - static_assert(std::is_same::value || std::is_same::value, - "T must be either float32 or float64"); + private: + /*! \brief feature values */ + std::vector data; + /*! \brief value representing the missing value (usually NaN) */ + ElementType missing_value; + /*! \brief number of rows */ + size_t num_row; + /*! \brief number of columns (i.e. # of features used) */ + size_t num_col; + public: + DenseDMatrixImpl() = delete; + DenseDMatrixImpl(std::vector data, ElementType missing_value, size_t num_row, + size_t num_col); + ~DenseDMatrixImpl() = default; + DenseDMatrixImpl(const DenseDMatrixImpl&) = default; + DenseDMatrixImpl(DenseDMatrixImpl&&) noexcept = default; + + friend class DenseDMatrix; + static_assert(std::is_same::value || std::is_same::value, + "ElementType must be either float32 or float64"); }; class CSRDMatrix { private: - std::shared_ptr handle_; TypeInfo type_; public: - template - static CSRDMatrix Create(); - static CSRDMatrix Create(TypeInfo type); + template + static std::unique_ptr Create( + std::vector data, std::vector col_ind, std::vector row_ptr, + size_t num_row, size_t num_col); + template + static std::unique_ptr Create( + const void* data, const uint32_t* col_ind, const size_t* row_ptr, size_t num_row, + size_t num_col, size_t num_elem); + static std::unique_ptr Create( + TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr, + size_t num_row, size_t num_col, size_t num_elem); }; -template +template class CSRDMatrixImpl : public CSRDMatrix { - static_assert(std::is_same::value || std::is_same::value, - "T must be either float32 or float64"); + private: + /*! \brief feature values */ + std::vector data; + /*! \brief feature indices. col_ind[i] indicates the feature index associated with data[i]. */ + std::vector col_ind; + /*! \brief pointer to row headers; length is [num_row] + 1. */ + std::vector row_ptr; + /*! \brief number of rows */ + size_t num_row; + /*! \brief number of columns (i.e. 
# of features used) */ + size_t num_col; + + public: + CSRDMatrixImpl() = delete; + CSRDMatrixImpl(std::vector data, std::vector col_ind, + std::vector row_ptr, size_t num_row, size_t num_col); + ~CSRDMatrixImpl() = default; + CSRDMatrixImpl(const CSRDMatrixImpl&) = default; + CSRDMatrixImpl(CSRDMatrixImpl&&) noexcept = default; + + friend class CSRDMatrix; + static_assert(std::is_same::value || std::is_same::value, + "ElementType must be either float32 or float64"); }; } // namespace treelite diff --git a/src/data/data.cc b/src/data/data.cc index df4836f6..9784c472 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -76,4 +76,97 @@ LegacyDMatrix::Create(dmlc::Parser* parser, int nthread, int verbose) return dmat; } +template +std::unique_ptr +DenseDMatrix::Create( + std::vector data, ElementType missing_value, size_t num_row, size_t num_col) { + std::unique_ptr matrix = std::make_unique>( + std::move(data), missing_value, num_row, num_col + ); + matrix->type_ = InferTypeInfoOf(); + return matrix; +} + +template +std::unique_ptr +DenseDMatrix::Create(const void* data, const void* missing_value, size_t num_row, size_t num_col) { + auto* data_ptr = static_cast(data); + const size_t num_elem = num_row * num_col; + return DenseDMatrix::Create(std::vector(data_ptr, data_ptr + num_elem), + *static_cast(missing_value), num_row, num_col); +} + +std::unique_ptr +DenseDMatrix::Create( + TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col) { + CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid"; + switch (type) { + case TypeInfo::kFloat32: + return Create(data, missing_value, num_row, num_col); + case TypeInfo::kFloat64: + return Create(data, missing_value, num_row, num_col); + case TypeInfo::kInvalid: + case TypeInfo::kUInt32: + default: + LOG(FATAL) << "Invalid type for DenseDMatrix: " << TypeInfoToString(type); + } + return std::unique_ptr(nullptr); +} + +template +DenseDMatrixImpl::DenseDMatrixImpl( + std::vector data, ElementType missing_value, size_t num_row, size_t num_col) + : DenseDMatrix(), data(std::move(data)), missing_value(missing_value), num_row(num_row), + num_col(num_col) {} + +template +std::unique_ptr +CSRDMatrix::Create(std::vector data, std::vector col_ind, + std::vector row_ptr, size_t num_row, size_t num_col) { + std::unique_ptr matrix = std::make_unique>( + std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col + ); + matrix->type_ = InferTypeInfoOf(); + return matrix; +} + +template +std::unique_ptr +CSRDMatrix::Create(const void* data, const uint32_t* col_ind, + const size_t* row_ptr, size_t num_row, size_t num_col, size_t num_elem) { + auto* data_ptr = static_cast(data); + return CSRDMatrix::Create( + std::vector(data_ptr, data_ptr + num_elem), + std::vector(col_ind, col_ind + num_elem), + std::vector(row_ptr, row_ptr + num_row + 1), + num_row, + num_col + ); +} + +std::unique_ptr +CSRDMatrix::Create(TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr, + size_t num_row, size_t num_col, size_t num_elem) { + CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid"; + switch (type) { + case TypeInfo::kFloat32: + return Create(data, col_ind, row_ptr, num_row, num_col, num_elem); + case TypeInfo::kFloat64: + return Create(data, col_ind, row_ptr, num_row, num_col, num_elem); + case TypeInfo::kInvalid: + case TypeInfo::kUInt32: + default: + LOG(FATAL) << "Invalid type for CSRDMatrix: " << TypeInfoToString(type); + } + return std::unique_ptr(nullptr); +} + 
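/*
 * Usage sketch for the Create() factories above (illustrative values only; this
 * snippet is not part of the library sources). A 2x3 CSR matrix with entries
 * {1.0, 2.0} in row 0 (columns 0 and 2) and {3.0} in row 1 (column 1) can be
 * built either through the typed overload or through the type-erased overload
 * that dispatches on TypeInfo:
 *
 *   std::vector<float> data{1.0f, 2.0f, 3.0f};
 *   std::vector<uint32_t> col_ind{0, 2, 1};
 *   std::vector<size_t> row_ptr{0, 2, 3};
 *   std::unique_ptr<CSRDMatrix> typed =
 *       CSRDMatrix::Create<float>(data, col_ind, row_ptr, 2, 3);
 *   std::unique_ptr<CSRDMatrix> erased =
 *       CSRDMatrix::Create(TypeInfo::kFloat32, data.data(), col_ind.data(),
 *                          row_ptr.data(), 2, 3, data.size());
 *
 * Requesting TypeInfo::kUInt32 falls through to the default branch above and
 * aborts via LOG(FATAL), since only float32 and float64 element types are
 * supported by CSRDMatrix.
 */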
+template +CSRDMatrixImpl::CSRDMatrixImpl( + std::vector data, std::vector col_ind, std::vector row_ptr, + size_t num_row, size_t num_col) + : CSRDMatrix(), data(std::move(data)), col_ind(std::move(col_ind)), row_ptr(std::move(row_ptr)), + num_row(num_col), num_col(num_col) +{} + } // namespace treelite From f64ddc835dbcf59df07cfe9118f1eaa5d685b732 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 1 Sep 2020 17:27:28 -0700 Subject: [PATCH 12/38] Use new DMatrix in branch annotator; move DMatrix to objtreelite_common --- include/treelite/annotator.h | 6 +- include/treelite/c_api.h | 103 +----------------- include/treelite/c_api_common.h | 67 ++++++++++++ include/treelite/data.h | 109 +++++++++---------- src/annotator.cc | 53 ++++++---- src/c_api/c_api.cc | 179 ++------------------------------ src/c_api/c_api_common.cc | 60 +++++++++++ src/data/data.cc | 134 +++++++++++++++--------- 8 files changed, 311 insertions(+), 400 deletions(-) diff --git a/include/treelite/annotator.h b/include/treelite/annotator.h index 9b106a8d..a0b4ab2d 100644 --- a/include/treelite/annotator.h +++ b/include/treelite/annotator.h @@ -16,9 +16,6 @@ namespace treelite { /*! \brief branch annotator class */ class BranchAnnotator { public: - template - void AnnotateImpl(const treelite::ModelImpl& model, - const treelite::LegacyDMatrix* dmat, int nthread, int verbose); /*! * \brief annotate branches in a given model using frequency patterns in the * training data. The annotation can be accessed through Get() method. @@ -27,8 +24,7 @@ class BranchAnnotator { * \param nthread number of threads to use * \param verbose whether to produce extra messages */ - void Annotate(const Model& model, const LegacyDMatrix* dmat, - int nthread, int verbose); + void Annotate(const Model& model, const CSRDMatrix* dmat, int nthread, int verbose); /*! * \brief load branch annotation from a JSON file * \param fi input stream diff --git a/include/treelite/c_api.h b/include/treelite/c_api.h index 441b7355..e1a715d3 100644 --- a/include/treelite/c_api.h +++ b/include/treelite/c_api.h @@ -19,8 +19,6 @@ * opaque handles * \{ */ -/*! \brief handle to a data matrix */ -typedef void* DMatrixHandle; /*! \brief handle to a decision tree ensemble model */ typedef void* ModelHandle; /*! \brief handle to tree builder class */ @@ -35,100 +33,6 @@ typedef void* CompilerHandle; typedef void* ValueHandle; /*! \} */ -/*! - * \defgroup dmatrix - * Data matrix interface - * \{ - */ -/*! - * \brief create DMatrix from a file - * \param path file path - * \param format file format - * \param nthread number of threads to use - * \param verbose whether to produce extra messages - * \param out the created DMatrix - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDMatrixCreateFromFile(const char* path, - const char* format, - int nthread, - int verbose, - DMatrixHandle* out); -/*! - * \brief create DMatrix from a (in-memory) CSR matrix - * \param data feature values - * \param col_ind feature indices - * \param row_ptr pointer to row headers - * \param num_row number of rows - * \param num_col number of columns - * \param out the created DMatrix - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDMatrixCreateFromCSR(const float* data, - const unsigned* col_ind, - const size_t* row_ptr, - size_t num_row, - size_t num_col, - DMatrixHandle* out); -/*! 
- * \brief create DMatrix from a (in-memory) dense matrix - * \param data feature values - * \param num_row number of rows - * \param num_col number of columns - * \param missing_value value to represent missing value - * \param out the created DMatrix - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDMatrixCreateFromMat(const float* data, - size_t num_row, - size_t num_col, - float missing_value, - DMatrixHandle* out); -/*! - * \brief get dimensions of a DMatrix - * \param handle handle to DMatrix - * \param out_num_row used to set number of rows - * \param out_num_col used to set number of columns - * \param out_nelem used to set number of nonzero entries - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDMatrixGetDimension(DMatrixHandle handle, - size_t* out_num_row, - size_t* out_num_col, - size_t* out_nelem); - -/*! - * \brief produce a human-readable preview of a DMatrix - * Will print first and last 25 non-zero entries, along with their locations - * \param handle handle to DMatrix - * \param out_preview used to save the address of the string literal - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDMatrixGetPreview(DMatrixHandle handle, - const char** out_preview); - -/*! - * \brief extract three arrays (data, col_ind, row_ptr) that define a DMatrix. - * \param handle handle to DMatrix - * \param out_data used to save pointer to array containing feature values - * \param out_col_ind used to save pointer to array containing feature indices - * \param out_row_ptr used to save pointer to array containing pointers to - * row headers - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDMatrixGetArrays(DMatrixHandle handle, - const float** out_data, - const uint32_t** out_col_ind, - const size_t** out_row_ptr); - -/*! - * \brief delete DMatrix from memory - * \param handle handle to DMatrix - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDMatrixFree(DMatrixHandle handle); -/*! \} */ - /*! * \defgroup annotator * Branch annotator interface @@ -144,11 +48,8 @@ TREELITE_DLL int TreeliteDMatrixFree(DMatrixHandle handle); * \param out used to save handle for the created annotation * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreeliteAnnotateBranch(ModelHandle model, - DMatrixHandle dmat, - int nthread, - int verbose, - AnnotationHandle* out); +TREELITE_DLL int TreeliteAnnotateBranch( + ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle* out); /*! * \brief save branch annotation to a JSON file * \param handle annotation to save diff --git a/include/treelite/c_api_common.h b/include/treelite/c_api_common.h index c598309d..b3c21364 100644 --- a/include/treelite/c_api_common.h +++ b/include/treelite/c_api_common.h @@ -26,6 +26,9 @@ #define TREELITE_DLL TREELITE_EXTERN_C #endif +/*! \brief handle to a data matrix */ +typedef void* DMatrixHandle; + /*! * \brief display last error; can be called by multiple threads * Note. Each thread will get the last error occured in its own context. @@ -42,4 +45,68 @@ TREELITE_DLL const char* TreeliteGetLastError(void); */ TREELITE_DLL int TreeliteRegisterLogCallback(void (*callback)(const char*)); +/*! + * \defgroup dmatrix + * Data matrix interface + * \{ + */ +/*! 
+ * \brief create a sparse DMatrix from a file + * \param path file path + * \param format file format + * \param nthread number of threads to use + * \param verbose whether to produce extra messages + * \param out the created DMatrix + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixCreateFromFile( + const char* path, const char* format, int nthread, int verbose, DMatrixHandle* out); +/*! + * \brief create DMatrix from a (in-memory) CSR matrix + * \param data feature values + * \param data_type Type of data elements + * \param col_ind feature indices + * \param row_ptr pointer to row headers + * \param num_row number of rows + * \param num_col number of columns + * \param out the created DMatrix + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixCreateFromCSR( + const void* data, const char* data_type, const uint32_t* col_ind, const size_t* row_ptr, + size_t num_row, size_t num_col, DMatrixHandle* out); +/*! + * \brief create DMatrix from a (in-memory) dense matrix + * \param data feature values + * \param data_type Type of data elements + * \param num_row number of rows + * \param num_col number of columns + * \param missing_value value to represent missing value + * \param out the created DMatrix + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixCreateFromMat( + const void* data, const char* data_type, size_t num_row, size_t num_col, + const void* missing_value, DMatrixHandle* out); +/*! + * \brief get dimensions of a DMatrix + * \param handle handle to DMatrix + * \param out_num_row used to set number of rows + * \param out_num_col used to set number of columns + * \param out_nelem used to set number of nonzero entries + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixGetDimension(DMatrixHandle handle, + size_t* out_num_row, + size_t* out_num_col, + size_t* out_nelem); + +/*! + * \brief delete DMatrix from memory + * \param handle handle to DMatrix + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixFree(DMatrixHandle handle); +/*! \} */ + #endif // TREELITE_C_API_COMMON_H_ diff --git a/include/treelite/data.h b/include/treelite/data.h index b85b7933..d32078c7 100644 --- a/include/treelite/data.h +++ b/include/treelite/data.h @@ -15,57 +15,18 @@ namespace treelite { -/*! \brief a simple data matrix in CSR (Compressed Sparse Row) storage */ -struct LegacyDMatrix { - /*! \brief feature values */ - std::vector data; - /*! \brief feature indices */ - std::vector col_ind; - /*! \brief pointer to row headers; length of [num_row] + 1 */ - std::vector row_ptr; - /*! \brief number of rows */ - size_t num_row; - /*! \brief number of columns */ - size_t num_col; - /*! \brief number of nonzero entries */ - size_t nelem; - - /*! - * \brief clear all data fields - */ - inline void Clear() { - data.clear(); - row_ptr.clear(); - col_ind.clear(); - row_ptr.resize(1, 0); - num_row = num_col = nelem = 0; - } - /*! - * \brief construct a new DMatrix from a file - * \param filename name of file - * \param format format of file (libsvm/libfm/csv) - * \param nthread number of threads to use - * \param verbose whether to produce extra messages - * \return newly built DMatrix - */ - static LegacyDMatrix* Create(const char* filename, const char* format, - int nthread, int verbose); - /*! - * \brief construct a new DMatrix from a data parser. The data parser here - * refers to any iterable object that streams input data in small - * batches. 
- * \param parser pointer to data parser - * \param nthread number of threads to use - * \param verbose whether to produce extra messages - * \return newly built DMatrix - */ - static LegacyDMatrix* Create(dmlc::Parser* parser, - int nthread, int verbose); +class DMatrix { + public: + virtual size_t GetNumRow() const = 0; + virtual size_t GetNumCol() const = 0; + virtual size_t GetNumElem() const = 0; + DMatrix() = default; + virtual ~DMatrix() = default; }; -class DenseDMatrix { +class DenseDMatrix : public DMatrix { private: - TypeInfo type_; + TypeInfo element_type_; public: template static std::unique_ptr Create( @@ -75,6 +36,9 @@ class DenseDMatrix { const void* data, const void* missing_value, size_t num_row, size_t num_col); static std::unique_ptr Create( TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col); + size_t GetNumRow() const override = 0; + size_t GetNumCol() const override = 0; + size_t GetNumElem() const override = 0; }; template @@ -95,15 +59,21 @@ class DenseDMatrixImpl : public DenseDMatrix { ~DenseDMatrixImpl() = default; DenseDMatrixImpl(const DenseDMatrixImpl&) = default; DenseDMatrixImpl(DenseDMatrixImpl&&) noexcept = default; + DenseDMatrixImpl& operator=(const DenseDMatrixImpl&) = default; + DenseDMatrixImpl& operator=(DenseDMatrixImpl&&) noexcept = default; + + size_t GetNumRow() const override; + size_t GetNumCol() const override; + size_t GetNumElem() const override; friend class DenseDMatrix; static_assert(std::is_same::value || std::is_same::value, "ElementType must be either float32 or float64"); }; -class CSRDMatrix { +class CSRDMatrix : public DMatrix { private: - TypeInfo type_; + TypeInfo element_type_; public: template static std::unique_ptr Create( @@ -112,15 +82,23 @@ class CSRDMatrix { template static std::unique_ptr Create( const void* data, const uint32_t* col_ind, const size_t* row_ptr, size_t num_row, - size_t num_col, size_t num_elem); + size_t num_col); static std::unique_ptr Create( TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr, - size_t num_row, size_t num_col, size_t num_elem); + size_t num_row, size_t num_col); + static std::unique_ptr Create( + const char* filename, const char* format, int nthread, int verbose); + TypeInfo GetMatrixType() const; + template + inline auto Dispatch(Func func) const; + size_t GetNumRow() const override = 0; + size_t GetNumCol() const override = 0; + size_t GetNumElem() const override = 0; }; template class CSRDMatrixImpl : public CSRDMatrix { - private: + public: /*! \brief feature values */ std::vector data; /*! \brief feature indices. col_ind[i] indicates the feature index associated with data[i]. */ @@ -132,19 +110,42 @@ class CSRDMatrixImpl : public CSRDMatrix { /*! \brief number of columns (i.e. 
# of features used) */ size_t num_col; - public: CSRDMatrixImpl() = delete; CSRDMatrixImpl(std::vector data, std::vector col_ind, std::vector row_ptr, size_t num_row, size_t num_col); ~CSRDMatrixImpl() = default; CSRDMatrixImpl(const CSRDMatrixImpl&) = default; CSRDMatrixImpl(CSRDMatrixImpl&&) noexcept = default; + CSRDMatrixImpl& operator=(const CSRDMatrixImpl&) = default; + CSRDMatrixImpl& operator=(CSRDMatrixImpl&&) noexcept = default; + + size_t GetNumRow() const override; + size_t GetNumCol() const override; + size_t GetNumElem() const override; friend class CSRDMatrix; static_assert(std::is_same::value || std::is_same::value, "ElementType must be either float32 or float64"); }; +template +inline auto +CSRDMatrix::Dispatch(Func func) const { + switch (element_type_) { + case TypeInfo::kFloat32: + return func(*dynamic_cast*>(this)); + break; + case TypeInfo::kFloat64: + return func(*dynamic_cast*>(this)); + break; + case TypeInfo::kUInt32: + case TypeInfo::kInvalid: + default: + LOG(FATAL) << "Invalid element type for the matrix: " << TypeInfoToString(element_type_); + return func(*dynamic_cast*>(this)); // avoid missing return error + } +} + } // namespace treelite #endif // TREELITE_DATA_H_ diff --git a/src/annotator.cc b/src/annotator.cc index 0456a493..e1fe7656 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -13,14 +13,15 @@ namespace { +template union Entry { int missing; - float fvalue; + ThresholdType fvalue; }; template -void Traverse_(const treelite::Tree& tree, const Entry* data, - int nid, size_t* out_counts) { +void Traverse_(const treelite::Tree& tree, + const Entry* data, int nid, size_t* out_counts) { ++out_counts[nid]; if (!tree.IsLeaf(nid)) { const unsigned split_index = tree.SplitIndex(nid); @@ -51,16 +52,16 @@ void Traverse_(const treelite::Tree& tree, const } template -void Traverse(const treelite::Tree& tree, const Entry* data, - size_t* out_counts) { +void Traverse(const treelite::Tree& tree, + const Entry* data, size_t* out_counts) { Traverse_(tree, data, 0, out_counts); } template inline void ComputeBranchLoop(const treelite::ModelImpl& model, - const treelite::LegacyDMatrix* dmat, size_t rbegin, size_t rend, - int nthread, const size_t* count_row_ptr, size_t* counts_tloc, - Entry* inst) { + const treelite::CSRDMatrixImpl* dmat, size_t rbegin, + size_t rend,int nthread, const size_t* count_row_ptr, + size_t* counts_tloc, Entry* inst) { const size_t ntree = model.trees.size(); CHECK_LE(rbegin, rend); CHECK_LT(static_cast(rend), std::numeric_limits::max()); @@ -90,12 +91,18 @@ inline void ComputeBranchLoop(const treelite::ModelImpl -void -BranchAnnotator::AnnotateImpl(const treelite::ModelImpl& model, - const treelite::LegacyDMatrix* dmat, int nthread, int verbose) { +inline void +AnnotateImpl( + const treelite::ModelImpl& model, + const treelite::CSRDMatrix* dmat, int nthread, int verbose, + std::vector>* out_counts) { + auto* dmat_ = dynamic_cast*>(dmat); + CHECK(dmat_) << "BranchAnnotator: Dangling reference to CSRDMatrix detected"; + std::vector new_counts; std::vector counts_tloc; std::vector count_row_ptr; + count_row_ptr = {0}; const size_t ntree = model.trees.size(); const int max_thread = omp_get_max_threads(); @@ -106,15 +113,15 @@ BranchAnnotator::AnnotateImpl(const treelite::ModelImpl inst(nthread * dmat->num_col, {-1}); - const size_t pstep = (dmat->num_row + 19) / 20; + std::vector> inst(nthread * dmat_->num_col, {-1}); + const size_t pstep = (dmat_->num_row + 19) / 20; // interval to display progress - for (size_t rbegin = 0; rbegin < 
dmat->num_row; rbegin += pstep) { - const size_t rend = std::min(rbegin + pstep, dmat->num_row); - ComputeBranchLoop(model, dmat, rbegin, rend, nthread, + for (size_t rbegin = 0; rbegin < dmat_->num_row; rbegin += pstep) { + const size_t rend = std::min(rbegin + pstep, dmat_->num_row); + ComputeBranchLoop(model, dmat_, rbegin, rend, nthread, &count_row_ptr[0], &counts_tloc[0], &inst[0]); if (verbose > 0) { - LOG(INFO) << rend << " of " << dmat->num_row << " rows processed"; + LOG(INFO) << rend << " of " << dmat_->num_row << " rows processed"; } } @@ -127,15 +134,19 @@ BranchAnnotator::AnnotateImpl(const treelite::ModelImpl>& counts = *out_counts; for (size_t i = 0; i < ntree; ++i) { - this->counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]); + counts.emplace_back(&new_counts[count_row_ptr[i]], &new_counts[count_row_ptr[i + 1]]); } } void -BranchAnnotator::Annotate(const Model& model, const LegacyDMatrix* dmat, int nthread, int verbose) { - model.Dispatch([this, dmat, nthread, verbose](auto& handle) { - AnnotateImpl(handle, dmat, nthread, verbose); +BranchAnnotator::Annotate(const Model& model, const CSRDMatrix* dmat, int nthread, int verbose) { + TypeInfo threshold_type = model.GetThresholdType(); + model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) { + CHECK(dmat->GetMatrixType() == threshold_type) + << "BranchAnnotator: the matrix type must match the threshold type of the model"; + AnnotateImpl(handle, dmat, nthread, verbose, &this->counts); }); } diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 1d321e20..531a367c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -32,183 +32,18 @@ struct CompilerHandleImpl { ~CompilerHandleImpl() = default; }; -/*! \brief entry to to easily hold returning information */ -struct TreeliteAPIThreadLocalEntry { - /*! 
\brief result holder for returning string */ - std::string ret_str; -}; - -// define threadlocal store for returning information -using TreeliteAPIThreadLocalStore - = dmlc::ThreadLocalStore; - } // anonymous namespace -int TreeliteDMatrixCreateFromFile(const char* path, - const char* format, - int nthread, - int verbose, - DMatrixHandle* out) { - API_BEGIN(); - *out = static_cast(LegacyDMatrix::Create(path, format, - nthread, verbose)); - API_END(); -} - -int TreeliteDMatrixCreateFromCSR(const float* data, - const unsigned* col_ind, - const size_t* row_ptr, - size_t num_row, - size_t num_col, - DMatrixHandle* out) { - API_BEGIN(); - std::unique_ptr dmat{new LegacyDMatrix()}; - dmat->Clear(); - auto& data_ = dmat->data; - auto& col_ind_ = dmat->col_ind; - auto& row_ptr_ = dmat->row_ptr; - data_.reserve(row_ptr[num_row]); - col_ind_.reserve(row_ptr[num_row]); - row_ptr_.reserve(num_row + 1); - for (size_t i = 0; i < num_row; ++i) { - const size_t jbegin = row_ptr[i]; - const size_t jend = row_ptr[i + 1]; - for (size_t j = jbegin; j < jend; ++j) { - if (!math::CheckNAN(data[j])) { // skip NaN - data_.push_back(data[j]); - CHECK_LT(col_ind[j], std::numeric_limits::max()) - << "feature index too big to fit into uint32_t"; - col_ind_.push_back(static_cast(col_ind[j])); - } - } - row_ptr_.push_back(data_.size()); - } - data_.shrink_to_fit(); - col_ind_.shrink_to_fit(); - dmat->num_row = num_row; - dmat->num_col = num_col; - dmat->nelem = data_.size(); // some nonzeros may have been deleted as NAN - - *out = static_cast(dmat.release()); - API_END(); -} - -int TreeliteDMatrixCreateFromMat(const float* data, - size_t num_row, - size_t num_col, - float missing_value, - DMatrixHandle* out) { - const bool nan_missing = math::CheckNAN(missing_value); - API_BEGIN(); - CHECK_LT(num_col, std::numeric_limits::max()) - << "num_col argument is too big"; - std::unique_ptr dmat{new LegacyDMatrix()}; - dmat->Clear(); - auto& data_ = dmat->data; - auto& col_ind_ = dmat->col_ind; - auto& row_ptr_ = dmat->row_ptr; - // make an educated guess for initial sizes, - // so as to present initial wave of allocation - const size_t guess_size - = std::min(std::min(num_row * num_col, num_row * 1000), - static_cast(64 * 1024 * 1024)); - data_.reserve(guess_size); - col_ind_.reserve(guess_size); - row_ptr_.reserve(num_row + 1); - const float* row = &data[0]; // points to beginning of each row - for (size_t i = 0; i < num_row; ++i, row += num_col) { - for (size_t j = 0; j < num_col; ++j) { - if (math::CheckNAN(row[j])) { - CHECK(nan_missing) - << "The missing_value argument must be set to NaN if there is any " - << "NaN in the matrix."; - } else if (nan_missing || row[j] != missing_value) { - // row[j] is a valid entry - data_.push_back(row[j]); - col_ind_.push_back(static_cast(j)); - } - } - row_ptr_.push_back(data_.size()); - } - data_.shrink_to_fit(); - col_ind_.shrink_to_fit(); - dmat->num_row = num_row; - dmat->num_col = num_col; - dmat->nelem = data_.size(); // some nonzeros may have been deleted as NaN - - *out = static_cast(dmat.release()); - API_END(); -} - -int TreeliteDMatrixGetDimension(DMatrixHandle handle, - size_t* out_num_row, - size_t* out_num_col, - size_t* out_nelem) { - API_BEGIN(); - const LegacyDMatrix* dmat = static_cast(handle); - *out_num_row = dmat->num_row; - *out_num_col = dmat->num_col; - *out_nelem = dmat->nelem; - API_END(); -} - -int TreeliteDMatrixGetPreview(DMatrixHandle handle, - const char** out_preview) { - API_BEGIN(); - const LegacyDMatrix* dmat = static_cast(handle); - std::string& 
ret_str = TreeliteAPIThreadLocalStore::Get()->ret_str; - std::ostringstream oss; - const size_t iend = (dmat->nelem <= 50) ? dmat->nelem : 25; - for (size_t i = 0; i < iend; ++i) { - const size_t row_ind = - std::upper_bound(&dmat->row_ptr[0], &dmat->row_ptr[dmat->num_row + 1], i) - - &dmat->row_ptr[0] - 1; - oss << " (" << row_ind << ", " << dmat->col_ind[i] << ")\t" - << dmat->data[i] << "\n"; - } - if (dmat->nelem > 50) { - oss << " :\t:\n"; - for (size_t i = dmat->nelem - 25; i < dmat->nelem; ++i) { - const size_t row_ind = - std::upper_bound(&dmat->row_ptr[0], &dmat->row_ptr[dmat->num_row + 1], i) - - &dmat->row_ptr[0] - 1; - oss << " (" << row_ind << ", " << dmat->col_ind[i] << ")\t" - << dmat->data[i] << "\n"; - } - } - ret_str = oss.str(); - *out_preview = ret_str.c_str(); - API_END(); -} - -int TreeliteDMatrixGetArrays(DMatrixHandle handle, - const float** out_data, - const uint32_t** out_col_ind, - const size_t** out_row_ptr) { - API_BEGIN(); - const LegacyDMatrix* dmat_ = static_cast(handle); - *out_data = &dmat_->data[0]; - *out_col_ind = &dmat_->col_ind[0]; - *out_row_ptr = &dmat_->row_ptr[0]; - API_END(); -} - -int TreeliteDMatrixFree(DMatrixHandle handle) { - API_BEGIN(); - delete static_cast(handle); - API_END(); -} - -int TreeliteAnnotateBranch(ModelHandle model, - DMatrixHandle dmat, - int nthread, - int verbose, - AnnotationHandle* out) { +int TreeliteAnnotateBranch( + ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle* out) { API_BEGIN(); std::unique_ptr annotator{new BranchAnnotator()}; const Model* model_ = static_cast(model); - const LegacyDMatrix* dmat_ = static_cast(dmat); - annotator->Annotate(*model_, dmat_, nthread, verbose); + const auto* dmat_ = static_cast(dmat); + CHECK(dmat_) << "Found a dangling reference to DMatrix"; + const auto* csr_dmat_ = dynamic_cast(dmat_); + CHECK(csr_dmat_) << "Annotator supports a sparse DMatrix for now"; + annotator->Annotate(*model_, csr_dmat_, nthread, verbose); *out = static_cast(annotator.release()); API_END(); } diff --git a/src/c_api/c_api_common.cc b/src/c_api/c_api_common.cc index 040d02ef..1a8fa925 100644 --- a/src/c_api/c_api_common.cc +++ b/src/c_api/c_api_common.cc @@ -5,15 +5,75 @@ * \brief C API of treelite (this file is used by both runtime and main package) */ +#include #include +#include #include #include "./c_api_error.h" using namespace treelite; +/*! \brief entry to to easily hold returning information */ +struct TreeliteAPIThreadLocalEntry { + /*! 
\brief result holder for returning string */ + std::string ret_str; +}; + +// define threadlocal store for returning information +using TreeliteAPIThreadLocalStore + = dmlc::ThreadLocalStore; + int TreeliteRegisterLogCallback(void (*callback)(const char*)) { API_BEGIN(); LogCallbackRegistry* registry = LogCallbackRegistryStore::Get(); registry->Register(callback); API_END(); } + +int TreeliteDMatrixCreateFromFile( + const char* path, const char* format, int nthread, int verbose, DMatrixHandle* out) { + API_BEGIN(); + std::unique_ptr mat = CSRDMatrix::Create(path, format, nthread, verbose); + *out = static_cast(mat.release()); + API_END(); +} + +int TreeliteDMatrixCreateFromCSR( + const void* data, const char* data_type_str, const uint32_t* col_ind, const size_t* row_ptr, + size_t num_row, size_t num_col, DMatrixHandle* out) { + API_BEGIN(); + TypeInfo data_type = typeinfo_table.at(data_type_str); + std::unique_ptr matrix + = CSRDMatrix::Create(data_type, data, col_ind, row_ptr, num_row, num_col); + *out = static_cast(matrix.release()); + API_END(); +} + +int TreeliteDMatrixCreateFromMat( + const void* data, const char* data_type_str, size_t num_row, size_t num_col, + const void* missing_value, DMatrixHandle* out) { + API_BEGIN(); + TypeInfo data_type = typeinfo_table.at(data_type_str); + std::unique_ptr matrix + = DenseDMatrix::Create(data_type, data, missing_value, num_row, num_col); + *out = static_cast(matrix.release()); + API_END(); +} + +int TreeliteDMatrixGetDimension(DMatrixHandle handle, + size_t* out_num_row, + size_t* out_num_col, + size_t* out_nelem) { + API_BEGIN(); + const DMatrix* dmat = static_cast(handle); + *out_num_row = dmat->GetNumRow(); + *out_num_col = dmat->GetNumCol(); + *out_nelem = dmat->GetNumElem(); + API_END(); +} + +int TreeliteDMatrixFree(DMatrixHandle handle) { + API_BEGIN(); + delete static_cast(handle); + API_END(); +} diff --git a/src/data/data.cc b/src/data/data.cc index 9784c472..268d6521 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -11,71 +11,62 @@ #include #include -namespace treelite { - -LegacyDMatrix* -LegacyDMatrix::Create(const char* filename, const char* format, - int nthread, int verbose) { - std::unique_ptr> parser( - dmlc::Parser::Create(filename, 0, 1, format)); - return Create(parser.get(), nthread, verbose); -} +namespace { -LegacyDMatrix* -LegacyDMatrix::Create(dmlc::Parser* parser, int nthread, int verbose) { +std::unique_ptr +CreateFromParser(dmlc::Parser* parser, int nthread, int verbose) { const int max_thread = omp_get_max_threads(); nthread = (nthread == 0) ? 
max_thread : std::min(nthread, max_thread); - LegacyDMatrix* dmat = new LegacyDMatrix(); - dmat->Clear(); - auto& data_ = dmat->data; - auto& col_ind_ = dmat->col_ind; - auto& row_ptr_ = dmat->row_ptr; - auto& num_row_ = dmat->num_row; - auto& num_col_ = dmat->num_col; - auto& nelem_ = dmat->nelem; + std::vector data; + std::vector col_ind; + std::vector row_ptr; + size_t num_row = 0; + size_t num_col = 0; + size_t num_elem = 0; std::vector max_col_ind(nthread, 0); parser->BeforeFirst(); while (parser->Next()) { - const dmlc::RowBlock& batch = parser->Value(); - num_row_ += batch.size; - nelem_ += batch.offset[batch.size]; - const size_t top = data_.size(); - data_.resize(top + batch.offset[batch.size] - batch.offset[0]); - col_ind_.resize(top + batch.offset[batch.size] - batch.offset[0]); + const dmlc::RowBlock& batch = parser->Value(); + num_row += batch.size; + num_elem += batch.offset[batch.size]; + const size_t top = data.size(); + data.resize(top + batch.offset[batch.size] - batch.offset[0]); + col_ind.resize(top + batch.offset[batch.size] - batch.offset[0]); CHECK_LT(static_cast(batch.offset[batch.size]), std::numeric_limits::max()); #pragma omp parallel for schedule(static) num_threads(nthread) for (int64_t i = static_cast(batch.offset[0]); - i < static_cast(batch.offset[batch.size]); ++i) { + i < static_cast(batch.offset[batch.size]); ++i) { const int tid = omp_get_thread_num(); const uint32_t index = batch.index[i]; - const float fvalue = (batch.value == nullptr) ? 1.0f : - static_cast(batch.value[i]); + const float fvalue = (batch.value == nullptr) ? 1.0f : static_cast(batch.value[i]); const size_t offset = top + i - batch.offset[0]; - data_[offset] = fvalue; - col_ind_[offset] = index; - max_col_ind[tid] = std::max(max_col_ind[tid], - static_cast(index)); + data[offset] = fvalue; + col_ind[offset] = index; + max_col_ind[tid] = std::max(max_col_ind[tid], static_cast(index)); } - const size_t rtop = row_ptr_.size(); - row_ptr_.resize(rtop + batch.size); - CHECK_LT(static_cast(batch.size), - std::numeric_limits::max()); + const size_t rtop = row_ptr.size(); + row_ptr.resize(rtop + batch.size); + CHECK_LT(static_cast(batch.size), std::numeric_limits::max()); #pragma omp parallel for schedule(static) num_threads(nthread) for (int64_t i = 0; i < static_cast(batch.size); ++i) { - row_ptr_[rtop + i] - = row_ptr_[rtop - 1] + batch.offset[i + 1] - batch.offset[0]; + row_ptr[rtop + i] = row_ptr[rtop - 1] + batch.offset[i + 1] - batch.offset[0]; } if (verbose > 0) { - LOG(INFO) << num_row_ << " rows read into memory"; + LOG(INFO) << num_row << " rows read into memory"; } } - num_col_ = *std::max_element(max_col_ind.begin(), max_col_ind.end()) + 1; - return dmat; + num_col = *std::max_element(max_col_ind.begin(), max_col_ind.end()) + 1; + return treelite::CSRDMatrix::Create(std::move(data), std::move(col_ind), std::move(row_ptr), + num_row, num_col); } +} // anonymous namespace + +namespace treelite { + template std::unique_ptr DenseDMatrix::Create( @@ -83,7 +74,7 @@ DenseDMatrix::Create( std::unique_ptr matrix = std::make_unique>( std::move(data), missing_value, num_row, num_col ); - matrix->type_ = InferTypeInfoOf(); + matrix->element_type_ = InferTypeInfoOf(); return matrix; } @@ -119,6 +110,24 @@ DenseDMatrixImpl::DenseDMatrixImpl( : DenseDMatrix(), data(std::move(data)), missing_value(missing_value), num_row(num_row), num_col(num_col) {} +template +size_t +DenseDMatrixImpl::GetNumRow() const { + return num_row; +} + +template +size_t +DenseDMatrixImpl::GetNumCol() const { + return 
num_col; +} + +template +size_t +DenseDMatrixImpl::GetNumElem() const { + return num_row * num_col; +} + template std::unique_ptr CSRDMatrix::Create(std::vector data, std::vector col_ind, @@ -126,15 +135,16 @@ CSRDMatrix::Create(std::vector data, std::vector col_ind, std::unique_ptr matrix = std::make_unique>( std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col ); - matrix->type_ = InferTypeInfoOf(); + matrix->element_type_ = InferTypeInfoOf(); return matrix; } template std::unique_ptr CSRDMatrix::Create(const void* data, const uint32_t* col_ind, - const size_t* row_ptr, size_t num_row, size_t num_col, size_t num_elem) { + const size_t* row_ptr, size_t num_row, size_t num_col) { auto* data_ptr = static_cast(data); + const size_t num_elem = row_ptr[num_row]; return CSRDMatrix::Create( std::vector(data_ptr, data_ptr + num_elem), std::vector(col_ind, col_ind + num_elem), @@ -146,13 +156,13 @@ CSRDMatrix::Create(const void* data, const uint32_t* col_ind, std::unique_ptr CSRDMatrix::Create(TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr, - size_t num_row, size_t num_col, size_t num_elem) { + size_t num_row, size_t num_col) { CHECK(type != TypeInfo::kInvalid) << "ElementType cannot be invalid"; switch (type) { case TypeInfo::kFloat32: - return Create(data, col_ind, row_ptr, num_row, num_col, num_elem); + return Create(data, col_ind, row_ptr, num_row, num_col); case TypeInfo::kFloat64: - return Create(data, col_ind, row_ptr, num_row, num_col, num_elem); + return Create(data, col_ind, row_ptr, num_row, num_col); case TypeInfo::kInvalid: case TypeInfo::kUInt32: default: @@ -161,6 +171,18 @@ CSRDMatrix::Create(TypeInfo type, const void* data, const uint32_t* col_ind, con return std::unique_ptr(nullptr); } +std::unique_ptr +CSRDMatrix::Create(const char* filename, const char* format, int nthread, int verbose) { + std::unique_ptr> parser( + dmlc::Parser::Create(filename, 0, 1, format)); + return CreateFromParser(parser.get(), nthread, verbose); +} + +TypeInfo +CSRDMatrix::GetMatrixType() const { + return element_type_; +} + template CSRDMatrixImpl::CSRDMatrixImpl( std::vector data, std::vector col_ind, std::vector row_ptr, @@ -169,4 +191,22 @@ CSRDMatrixImpl::CSRDMatrixImpl( num_row(num_col), num_col(num_col) {} +template +size_t +CSRDMatrixImpl::GetNumRow() const { + return num_row; +} + +template +size_t +CSRDMatrixImpl::GetNumCol() const { + return num_col; +} + +template +size_t +CSRDMatrixImpl::GetNumElem() const { + return row_ptr.at(num_row); +} + } // namespace treelite From 1c74cc51407dd6f1911909944d660653665ff177 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 1 Sep 2020 21:38:52 -0700 Subject: [PATCH 13/38] Make ModelImpl a derived class of Model, simplifying dispatch --- include/treelite/frontend.h | 16 +-- include/treelite/tree.h | 103 ++++++++------- include/treelite/tree_impl.h | 242 +++++++++++++++++------------------ src/c_api/c_api.cc | 33 ++--- src/compiler/failsafe.cc | 2 +- src/frontend/builder.cc | 10 +- src/frontend/lightgbm.cc | 62 ++++----- src/frontend/xgboost.cc | 56 ++++---- tests/cpp/test_serializer.cc | 19 +-- 9 files changed, 262 insertions(+), 281 deletions(-) diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index d703cb5d..50326c79 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -25,23 +25,23 @@ namespace frontend { * \brief load a model file generated by LightGBM (Microsoft/LightGBM). The * model file must contain a decision tree ensemble. 
* \param filename name of model file - * \param out reference to loaded model + * \return loaded model */ -void LoadLightGBMModel(const char *filename, Model* out); +std::unique_ptr LoadLightGBMModel(const char *filename); /*! * \brief load a model file generated by XGBoost (dmlc/xgboost). The model file * must contain a decision tree ensemble. * \param filename name of model file - * \param out reference to loaded model + * \return loaded model */ -void LoadXGBoostModel(const char* filename, Model* out); +std::unique_ptr LoadXGBoostModel(const char* filename); /*! * \brief load an XGBoost model from a memory buffer. * \param buf memory buffer * \param len size of memory buffer - * \param out reference to loaded model + * \return loaded model */ -void LoadXGBoostModel(const void* buf, size_t len, Model* out); +std::unique_ptr LoadXGBoostModel(const void* buf, size_t len); //-------------------------------------------------------------------------- // model builder interface: build trees incrementally @@ -220,9 +220,9 @@ class ModelBuilder { void DeleteTree(int index); /*! * \brief finalize the model and produce the in-memory representation - * \param out_model place to store in-memory representation of the finished model + * \return the finished model */ - void CommitModel(Model* out_model); + std::unique_ptr CommitModel(); private: std::unique_ptr pimpl_; // Pimpl pattern diff --git a/include/treelite/tree.h b/include/treelite/tree.h index e8171f1a..3660929b 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -419,9 +419,54 @@ static_assert(std::is_standard_layout::value, inline void InitParamAndCheck(ModelParam* param, const std::vector>& cfg); +enum class ModelType : uint16_t { + // Threshold type, + kInvalid = 0, + kFloat32ThresholdUInt32LeafOutput = 1, + kFloat32ThresholdFloat32LeafOutput = 2, + kFloat64ThresholdUInt32LeafOutput = 3, + kFloat64ThresholdFloat64LeafOutput = 4 +}; + +class Model { + private: + ModelType type_; + TypeInfo threshold_type_; + TypeInfo leaf_output_type_; + virtual void GetPyBuffer(std::vector* dest) = 0; + virtual void InitFromPyBuffer(std::vector::iterator begin, + std::vector::iterator end) = 0; + + public: + Model() = default; + virtual ~Model() = default; + template + inline static ModelType InferModelTypeOf(); + template + inline static std::unique_ptr Create(); + inline static std::unique_ptr Create(TypeInfo threshold_type, TypeInfo leaf_output_type); + inline ModelType GetModelType() const; + inline TypeInfo GetThresholdType() const; + inline TypeInfo GetLeafOutputType() const; + template + inline auto Dispatch(Func func); + template + inline auto Dispatch(Func func) const; + virtual ModelParam GetParam() const = 0; + virtual int GetNumFeature() const = 0; + virtual int GetNumOutputGroup() const = 0; + virtual bool GetRandomForestFlag() const = 0; + virtual size_t GetNumTree() const = 0; + virtual void SetTreeLimit(size_t limit) = 0; + virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0; + inline std::vector GetPyBuffer(); + inline static std::unique_ptr CreateFromPyBuffer(std::vector frames); +}; + /*! \brief thin wrapper for tree ensemble model */ template -struct ModelImpl { +class ModelImpl : public Model { + public: /*! \brief member trees */ std::vector> trees; /*! 
@@ -446,58 +491,20 @@ struct ModelImpl { ModelImpl(ModelImpl&&) noexcept = default; ModelImpl& operator=(ModelImpl&&) noexcept = default; - void ReferenceSerialize(dmlc::Stream* fo) const; + inline ModelParam GetParam() const override; + inline int GetNumFeature() const override; + inline int GetNumOutputGroup() const override; + inline bool GetRandomForestFlag() const override; + inline size_t GetNumTree() const override; + inline void SetTreeLimit(size_t limit) override; + void ReferenceSerialize(dmlc::Stream* fo) const override; - inline void GetPyBuffer(std::vector* dest); + inline void GetPyBuffer(std::vector* dest) override; inline void InitFromPyBuffer(std::vector::iterator begin, - std::vector::iterator end); + std::vector::iterator end) override; inline ModelImpl Clone() const; }; -enum class ModelType : uint16_t { - // Threshold type, - kInvalid = 0, - kFloat32ThresholdUInt32LeafOutput = 1, - kFloat32ThresholdFloat32LeafOutput = 2, - kFloat64ThresholdUInt32LeafOutput = 3, - kFloat64ThresholdFloat64LeafOutput = 4 -}; - -struct Model { - private: - std::shared_ptr handle_; - ModelType type_; - TypeInfo threshold_type_; - TypeInfo leaf_output_type_; - - public: - template - inline static ModelType InferModelTypeOf(); - template - inline static Model Create(); - inline static Model Create(TypeInfo threshold_type, TypeInfo leaf_output_type); - template - inline ModelImpl& GetImpl(); - template - inline const ModelImpl& GetImpl() const; - inline ModelType GetModelType() const; - inline TypeInfo GetThresholdType() const; - inline TypeInfo GetLeafOutputType() const; - template - inline auto Dispatch(Func func); - template - inline auto Dispatch(Func func) const; - inline ModelParam GetParam() const; - inline int GetNumFeature() const; - inline int GetNumOutputGroup() const; - inline bool GetRandomForestFlag() const; - inline size_t GetNumTree() const; - inline void SetTreeLimit(size_t limit); - inline void ReferenceSerialize(dmlc::Stream* fo) const; - inline std::vector GetPyBuffer(); - inline static Model CreateFromPyBuffer(std::vector frames); -}; - } // namespace treelite #include "tree_impl.h" diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index b2f421dd..122e06c0 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -419,47 +419,6 @@ Tree::InitFromPyBuffer( InitArrayFromPyBuffer(&left_categories_offset_, *begin++); } -template -inline void -ModelImpl::GetPyBuffer(std::vector* dest) { - /* Header */ - dest->push_back(GetPyBufferFromScalar(&num_feature)); - dest->push_back(GetPyBufferFromScalar(&num_output_group)); - dest->push_back(GetPyBufferFromScalar(&random_forest_flag)); - dest->push_back(GetPyBufferFromScalar( - ¶m, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}")); - - /* Body */ - for (Tree& tree : trees) { - tree.GetPyBuffer(dest); - } -} - -template -inline void -ModelImpl::InitFromPyBuffer( - std::vector::iterator begin, std::vector::iterator end) { - const size_t num_frame = std::distance(begin, end); - /* Header */ - constexpr size_t kNumFrameInHeader = 4; - if (num_frame < kNumFrameInHeader) { - throw std::runtime_error("Wrong number of frames"); - } - InitScalarFromPyBuffer(&num_feature, *begin++); - InitScalarFromPyBuffer(&num_output_group, *begin++); - InitScalarFromPyBuffer(&random_forest_flag, *begin++); - InitScalarFromPyBuffer(¶m, *begin++); - /* Body */ - if ((num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) { - throw std::runtime_error("Wrong number of frames"); - } - trees.clear(); - 
for (; begin < end; begin += kNumFramePerTree) { - trees.emplace_back(); - trees.back().InitFromPyBuffer(begin, begin + kNumFramePerTree); - } -} - template inline void Tree::Node::Init() { cleft_ = cright_ = -1; @@ -795,32 +754,6 @@ Tree::SetGain(int nid, double gain) { node.gain_present_ = true; } -template -inline ModelImpl -ModelImpl::Clone() const { - ModelImpl model; - for (const Tree& t : trees) { - model.trees.push_back(t.Clone()); - } - model.num_feature = num_feature; - model.num_output_group = num_output_group; - model.random_forest_flag = random_forest_flag; - model.param = param; - return model; -} - -template -inline ModelImpl& -Model::GetImpl() { - return *static_cast*>(handle_.get()); -} - -template -inline const ModelImpl& -Model::GetImpl() const { - return *static_cast*>(handle_.get()); -} - inline ModelType Model::GetModelType() const { return type_; @@ -872,17 +805,16 @@ Model::InferModelTypeOf() { } template -inline Model +inline std::unique_ptr Model::Create() { - Model model; - model.handle_.reset(new ModelImpl()); - model.type_ = InferModelTypeOf(); - model.threshold_type_ = InferTypeInfoOf(); - model.leaf_output_type_ = InferTypeInfoOf(); + std::unique_ptr model = std::make_unique>(); + model->type_ = InferModelTypeOf(); + model->threshold_type_ = InferTypeInfoOf(); + model->leaf_output_type_ = InferTypeInfoOf(); return model; } -inline Model +inline std::unique_ptr Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { auto error_threshold_type = [threshold_type]() { std::ostringstream oss; @@ -922,7 +854,7 @@ Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { throw std::runtime_error(error_threshold_type()); break; } - return treelite::Model(); // avoid missing return value warning + return std::unique_ptr(nullptr); // avoid missing return value warning } template @@ -930,17 +862,18 @@ inline auto Model::Dispatch(Func func) { switch (type_) { case ModelType::kFloat32ThresholdUInt32LeafOutput: - return func(GetImpl()); + return func(*dynamic_cast*>(this)); case ModelType::kFloat32ThresholdFloat32LeafOutput: - return func(GetImpl()); + return func(*dynamic_cast*>(this)); case ModelType::kFloat64ThresholdUInt32LeafOutput: - return func(GetImpl()); + return func(*dynamic_cast*>(this)); case ModelType::kFloat64ThresholdFloat64LeafOutput: - return func(GetImpl()); + return func(*dynamic_cast*>(this)); default: throw std::runtime_error(std::string("Unknown type detected: ") + std::to_string(static_cast(type_))); - return func(GetImpl()); // avoid "missing return" warning + return func(*dynamic_cast*>(this)); + // avoid "missing return" warning } } @@ -949,65 +882,31 @@ inline auto Model::Dispatch(Func func) const { switch (type_) { case ModelType::kFloat32ThresholdUInt32LeafOutput: - return func(GetImpl()); + return func(*dynamic_cast*>(this)); case ModelType::kFloat32ThresholdFloat32LeafOutput: - return func(GetImpl()); + return func(*dynamic_cast*>(this)); case ModelType::kFloat64ThresholdUInt32LeafOutput: - return func(GetImpl()); + return func(*dynamic_cast*>(this)); case ModelType::kFloat64ThresholdFloat64LeafOutput: - return func(GetImpl()); + return func(*dynamic_cast*>(this)); default: throw std::runtime_error(std::string("Unknown type detected: ") + std::to_string(static_cast(type_))); - return func(GetImpl()); // avoid "missing return" warning + return func(*dynamic_cast*>(this)); + // avoid "missing return" warning } } -inline ModelParam -Model::GetParam() const { - return Dispatch([](const auto& handle) { return handle.param; 
}); -} - -inline int -Model::GetNumFeature() const { - return Dispatch([](const auto& handle) { return handle.num_feature; }); -} - -inline int -Model::GetNumOutputGroup() const { - return Dispatch([](const auto& handle) { return handle.num_output_group; }); -} - -inline bool -Model::GetRandomForestFlag() const { - return Dispatch([](const auto& handle) { return handle.random_forest_flag; }); -} - -inline size_t -Model::GetNumTree() const { - return Dispatch([](const auto& handle) { return handle.trees.size(); }); -} - -inline void -Model::SetTreeLimit(size_t limit) { - Dispatch([limit](auto& handle) { handle.trees.resize(limit); }); -} - -inline void -Model::ReferenceSerialize(dmlc::Stream* fo) const { - Dispatch([fo](const auto& handle) { handle.ReferenceSerialize(fo); }); -} - inline std::vector Model::GetPyBuffer() { std::vector buffer; buffer.push_back(GetPyBufferFromScalar(&threshold_type_)); buffer.push_back(GetPyBufferFromScalar(&leaf_output_type_)); - Dispatch([&buffer](auto& handle) { return handle.GetPyBuffer(&buffer); }); + this->GetPyBuffer(&buffer); return buffer; } -inline Model +inline std::unique_ptr Model::CreateFromPyBuffer(std::vector frames) { using TypeInfoInt = std::underlying_type::type; TypeInfo threshold_type, leaf_output_type; @@ -1017,13 +916,104 @@ Model::CreateFromPyBuffer(std::vector frames) { InitScalarFromPyBuffer(&threshold_type, frames[0]); InitScalarFromPyBuffer(&leaf_output_type, frames[1]); - Model model = Model::Create(threshold_type, leaf_output_type); - model.Dispatch([&frames](auto& handle) { - handle.InitFromPyBuffer(frames.begin() + 2, frames.end()); - }); + std::unique_ptr model = Model::Create(threshold_type, leaf_output_type); + model->InitFromPyBuffer(frames.begin() + 2, frames.end()); return model; } + +template +inline void +ModelImpl::GetPyBuffer(std::vector* dest) { + /* Header */ + dest->push_back(GetPyBufferFromScalar(&num_feature)); + dest->push_back(GetPyBufferFromScalar(&num_output_group)); + dest->push_back(GetPyBufferFromScalar(&random_forest_flag)); + dest->push_back(GetPyBufferFromScalar( + ¶m, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}")); + + /* Body */ + for (Tree& tree : trees) { + tree.GetPyBuffer(dest); + } +} + +template +inline void +ModelImpl::InitFromPyBuffer( + std::vector::iterator begin, std::vector::iterator end) { + const size_t num_frame = std::distance(begin, end); + /* Header */ + constexpr size_t kNumFrameInHeader = 4; + if (num_frame < kNumFrameInHeader) { + throw std::runtime_error("Wrong number of frames"); + } + InitScalarFromPyBuffer(&num_feature, *begin++); + InitScalarFromPyBuffer(&num_output_group, *begin++); + InitScalarFromPyBuffer(&random_forest_flag, *begin++); + InitScalarFromPyBuffer(¶m, *begin++); + /* Body */ + if ((num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) { + throw std::runtime_error("Wrong number of frames"); + } + trees.clear(); + for (; begin < end; begin += kNumFramePerTree) { + trees.emplace_back(); + trees.back().InitFromPyBuffer(begin, begin + kNumFramePerTree); + } +} + + +template +inline ModelImpl +ModelImpl::Clone() const { + ModelImpl model; + for (const Tree& t : trees) { + model.trees.push_back(t.Clone()); + } + model.num_feature = num_feature; + model.num_output_group = num_output_group; + model.random_forest_flag = random_forest_flag; + model.param = param; + return model; +} + +template +inline ModelParam +ModelImpl::GetParam() const { + return param; +} + +template +inline int +ModelImpl::GetNumFeature() const { + return num_feature; +} + 
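// Illustrative sketch (not part of this patch): with ModelImpl now a derived class of
// Model, callers that previously went through Model::GetImpl<T, L>() can hold a
// std::unique_ptr<Model> and either use the virtual accessors directly or dispatch a
// generic lambda. The lambda body and the value assigned below are hypothetical.
//
//   std::unique_ptr<treelite::Model> model
//       = treelite::Model::Create<float, float>();   // concrete type: ModelImpl<float, float>
//   model->Dispatch([](auto& handle) {
//     // handle is the concrete ModelImpl<ThresholdType, LeafOutputType>&
//     handle.num_feature = 2;
//     handle.trees.emplace_back();
//     handle.trees.back().Init();
//   });
//   size_t num_tree = model->GetNumTree();           // virtual call; no dispatch needed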
+template +inline int +ModelImpl::GetNumOutputGroup() const { + return num_output_group; +} + +template +inline bool +ModelImpl::GetRandomForestFlag() const { + return random_forest_flag; +} + +template +inline size_t +ModelImpl::GetNumTree() const { + return trees.size(); +} + +template +inline void +ModelImpl::SetTreeLimit(size_t limit) { + trees.resize(limit); +} + inline void InitParamAndCheck(ModelParam* param, const std::vector>& cfg) { auto unknown = param->InitAllowUnknown(cfg); diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 531a367c..3f6033bd 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -142,29 +142,23 @@ int TreeliteCompilerFree(CompilerHandle handle) { API_END(); } -int TreeliteLoadLightGBMModel(const char* filename, - ModelHandle* out) { +int TreeliteLoadLightGBMModel(const char* filename, ModelHandle* out) { API_BEGIN(); - std::unique_ptr model{new Model()}; - frontend::LoadLightGBMModel(filename, model.get()); + std::unique_ptr model = frontend::LoadLightGBMModel(filename); *out = static_cast(model.release()); API_END(); } -int TreeliteLoadXGBoostModel(const char* filename, - ModelHandle* out) { +int TreeliteLoadXGBoostModel(const char* filename, ModelHandle* out) { API_BEGIN(); - std::unique_ptr model{new Model()}; - frontend::LoadXGBoostModel(filename, model.get()); + std::unique_ptr model = frontend::LoadXGBoostModel(filename); *out = static_cast(model.release()); API_END(); } -int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, - ModelHandle* out) { +int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelHandle* out) { API_BEGIN(); - std::unique_ptr model{new Model()}; - frontend::LoadXGBoostModel(buf, len, model.get()); + std::unique_ptr model = frontend::LoadXGBoostModel(buf, len); *out = static_cast(model.release()); API_END(); } @@ -322,8 +316,7 @@ int TreeliteCreateModelBuilder( API_END(); } -int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, - const char* name, +int TreeliteModelBuilderSetModelParam(ModelBuilderHandle handle, const char* name, const char* value) { API_BEGIN(); auto* builder = static_cast(handle); @@ -338,8 +331,7 @@ int TreeliteDeleteModelBuilder(ModelBuilderHandle handle) { API_END(); } -int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, - TreeBuilderHandle tree_builder_handle, +int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, TreeBuilderHandle tree_builder_handle, int index) { API_BEGIN(); auto* model_builder = static_cast(handle); @@ -350,8 +342,7 @@ int TreeliteModelBuilderInsertTree(ModelBuilderHandle handle, API_END(); } -int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, - TreeBuilderHandle *out) { +int TreeliteModelBuilderGetTree(ModelBuilderHandle handle, int index, TreeBuilderHandle *out) { API_BEGIN(); auto* model_builder = static_cast(handle); CHECK(model_builder) << "Detected dangling reference to deleted ModelBuilder object"; @@ -369,13 +360,11 @@ int TreeliteModelBuilderDeleteTree(ModelBuilderHandle handle, int index) { API_END(); } -int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, - ModelHandle* out) { +int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, ModelHandle* out) { API_BEGIN(); auto* builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object"; - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); *out = static_cast(model.release()); API_END(); } diff 
--git a/src/compiler/failsafe.cc b/src/compiler/failsafe.cc index c09b88e1..8d301885 100644 --- a/src/compiler/failsafe.cc +++ b/src/compiler/failsafe.cc @@ -267,7 +267,7 @@ class FailSafeCompiler : public Compiler { CompiledModel Compile(const Model& model) override { CHECK(model.GetModelType() == ModelType::kFloat32ThresholdFloat32LeafOutput) << "Failsafe compiler only supports models with float32 thresholds and float32 leaf outputs"; - const ModelImpl& model_handle = model.GetImpl(); + const auto& model_handle = dynamic_cast&>(model); CompiledModel cm; cm.backend = "native"; diff --git a/src/frontend/builder.cc b/src/frontend/builder.cc index fa838093..bb0e724e 100644 --- a/src/frontend/builder.cc +++ b/src/frontend/builder.cc @@ -401,13 +401,13 @@ ModelBuilder::DeleteTree(int index) { trees.erase(trees.begin() + index); } -void -ModelBuilder::CommitModel(Model* out_model) { - Model& model = *out_model; - model = Model::Create(pimpl_->threshold_type, pimpl_->leaf_output_type); - model.Dispatch([this](auto& model_handle) { +std::unique_ptr +ModelBuilder::CommitModel() { + std::unique_ptr model = Model::Create(pimpl_->threshold_type, pimpl_->leaf_output_type); + model->Dispatch([this](auto& model_handle) { this->pimpl_->CommitModelImpl(&model_handle); }); + return model; } template diff --git a/src/frontend/lightgbm.cc b/src/frontend/lightgbm.cc index d94e3775..145a8a3b 100644 --- a/src/frontend/lightgbm.cc +++ b/src/frontend/lightgbm.cc @@ -14,7 +14,7 @@ namespace { -treelite::Model ParseStream(dmlc::Stream* fi); +inline std::unique_ptr ParseStream(dmlc::Stream* fi); } // anonymous namespace @@ -23,9 +23,9 @@ namespace frontend { DMLC_REGISTRY_FILE_TAG(lightgbm); -void LoadLightGBMModel(const char *filename, Model* out) { +std::unique_ptr LoadLightGBMModel(const char *filename) { std::unique_ptr fi(dmlc::Stream::Create(filename, "r")); - *out = std::move(ParseStream(fi.get())); + return ParseStream(fi.get()); } } // namespace frontend @@ -253,7 +253,7 @@ inline std::vector LoadText(dmlc::Stream* fi) { return lines; } -inline treelite::Model ParseStream(dmlc::Stream* fi) { +inline std::unique_ptr ParseStream(dmlc::Stream* fi) { std::vector lgb_trees_; int max_feature_idx_; int num_tree_per_iteration_; @@ -436,18 +436,18 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } /* 2. 
Export model */ - treelite::Model model_wrapper = treelite::Model::Create(); - treelite::ModelImpl& model_handle = model_wrapper.GetImpl(); - model_handle.num_feature = max_feature_idx_ + 1; - model_handle.num_output_group = num_tree_per_iteration_; - if (model_handle.num_output_group > 1) { + std::unique_ptr model = treelite::Model::Create(); + auto* model_handle = dynamic_cast*>(model.get()); + model_handle->num_feature = max_feature_idx_ + 1; + model_handle->num_output_group = num_tree_per_iteration_; + if (model_handle->num_output_group > 1) { // multiclass classification with gradient boosted trees CHECK(!average_output_) << "Ill-formed LightGBM model file: cannot use random forest mode " << "for multi-class classification"; - model_handle.random_forest_flag = false; + model_handle->random_forest_flag = false; } else { - model_handle.random_forest_flag = average_output_; + model_handle->random_forest_flag = average_output_; } // set correct prediction transform function, depending on objective function @@ -463,11 +463,11 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { break; } } - CHECK(num_class >= 0 && num_class == model_handle.num_output_group) + CHECK(num_class >= 0 && num_class == model_handle->num_output_group) << "Ill-formed LightGBM model file: not a valid multiclass objective"; - std::strncpy(model_handle.param.pred_transform, "softmax", - sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle->param.pred_transform, "softmax", + sizeof(model_handle->param.pred_transform)); } else if (obj_name_ == "multiclassova") { // validate num_class and alpha parameters int num_class = -1; @@ -486,13 +486,13 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } } } - CHECK(num_class >= 0 && num_class == model_handle.num_output_group + CHECK(num_class >= 0 && num_class == model_handle->num_output_group && alpha > 0.0f) << "Ill-formed LightGBM model file: not a valid multiclassova objective"; - std::strncpy(model_handle.param.pred_transform, "multiclass_ova", - sizeof(model_handle.param.pred_transform)); - model_handle.param.sigmoid_alpha = alpha; + std::strncpy(model_handle->param.pred_transform, "multiclass_ova", + sizeof(model_handle->param.pred_transform)); + model_handle->param.sigmoid_alpha = alpha; } else if (obj_name_ == "binary") { // validate alpha parameter float alpha = -1.0f; @@ -508,25 +508,25 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { CHECK_GT(alpha, 0.0f) << "Ill-formed LightGBM model file: not a valid binary objective"; - std::strncpy(model_handle.param.pred_transform, "sigmoid", - sizeof(model_handle.param.pred_transform)); - model_handle.param.sigmoid_alpha = alpha; + std::strncpy(model_handle->param.pred_transform, "sigmoid", + sizeof(model_handle->param.pred_transform)); + model_handle->param.sigmoid_alpha = alpha; } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") { - std::strncpy(model_handle.param.pred_transform, "sigmoid", - sizeof(model_handle.param.pred_transform)); - model_handle.param.sigmoid_alpha = 1.0f; + std::strncpy(model_handle->param.pred_transform, "sigmoid", + sizeof(model_handle->param.pred_transform)); + model_handle->param.sigmoid_alpha = 1.0f; } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") { - std::strncpy(model_handle.param.pred_transform, "logarithm_one_plus_exp", - sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle->param.pred_transform, "logarithm_one_plus_exp", + sizeof(model_handle->param.pred_transform)); } else { - 
std::strncpy(model_handle.param.pred_transform, "identity", - sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle->param.pred_transform, "identity", + sizeof(model_handle->param.pred_transform)); } // traverse trees for (const auto& lgb_tree : lgb_trees_) { - model_handle.trees.emplace_back(); - treelite::Tree& tree = model_handle.trees.back(); + model_handle->trees.emplace_back(); + treelite::Tree& tree = model_handle->trees.back(); tree.Init(); // assign node ID's so that a breadth-wise traversal would yield @@ -588,7 +588,7 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } } } - return model_wrapper; + return model; } } // anonymous namespace diff --git a/src/frontend/xgboost.cc b/src/frontend/xgboost.cc index 2b957f07..9c31d5fb 100644 --- a/src/frontend/xgboost.cc +++ b/src/frontend/xgboost.cc @@ -16,7 +16,7 @@ namespace { -treelite::Model ParseStream(dmlc::Stream* fi); +inline std::unique_ptr ParseStream(dmlc::Stream* fi); } // anonymous namespace @@ -25,14 +25,14 @@ namespace frontend { DMLC_REGISTRY_FILE_TAG(xgboost); -void LoadXGBoostModel(const char* filename, Model* out) { +std::unique_ptr LoadXGBoostModel(const char* filename) { std::unique_ptr fi(dmlc::Stream::Create(filename, "r")); - *out = std::move(ParseStream(fi.get())); + return ParseStream(fi.get()); } -void LoadXGBoostModel(const void* buf, size_t len, Model* out) { +std::unique_ptr LoadXGBoostModel(const void* buf, size_t len) { dmlc::MemoryFixedSizeStream fs(const_cast(buf), len); - *out = std::move(ParseStream(&fs)); + return ParseStream(&fs); } } // namespace frontend @@ -330,7 +330,7 @@ class XGBTree { } }; -inline treelite::Model ParseStream(dmlc::Stream* fi) { +inline std::unique_ptr ParseStream(dmlc::Stream* fi) { std::vector xgb_trees_; LearnerModelParam mparam_; // model parameter GBTreeModelParam gbm_param_; // GBTree training parameter @@ -392,50 +392,50 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { bool need_transform_to_margin = mparam_.major_version >= 1; /* 2. 
Export model */ - treelite::Model model_wrapper = treelite::Model::Create(); - treelite::ModelImpl& model_handle = model_wrapper.GetImpl(); - model_handle.num_feature = static_cast(mparam_.num_feature); - model_handle.num_output_group = std::max(mparam_.num_class, 1); - model_handle.random_forest_flag = false; + std::unique_ptr model = treelite::Model::Create(); + auto* model_handle = dynamic_cast*>(model.get()); + model_handle->num_feature = static_cast(mparam_.num_feature); + model_handle->num_output_group = std::max(mparam_.num_class, 1); + model_handle->random_forest_flag = false; // set global bias - model_handle.param.global_bias = static_cast(mparam_.base_score); + model_handle->param.global_bias = static_cast(mparam_.base_score); std::vector exponential_family { "count:poisson", "reg:gamma", "reg:tweedie" }; if (need_transform_to_margin) { if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") { - model_handle.param.global_bias = ProbToMargin::Sigmoid(model_handle.param.global_bias); + model_handle->param.global_bias = ProbToMargin::Sigmoid(model_handle->param.global_bias); } else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_) != exponential_family.cend()) { - model_handle.param.global_bias = ProbToMargin::Exponential(model_handle.param.global_bias); + model_handle->param.global_bias = ProbToMargin::Exponential(model_handle->param.global_bias); } } // set correct prediction transform function, depending on objective function if (name_obj_ == "multi:softmax") { - std::strncpy(model_handle.param.pred_transform, "max_index", - sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle->param.pred_transform, "max_index", + sizeof(model_handle->param.pred_transform)); } else if (name_obj_ == "multi:softprob") { - std::strncpy(model_handle.param.pred_transform, "softmax", - sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle->param.pred_transform, "softmax", + sizeof(model_handle->param.pred_transform)); } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") { - std::strncpy(model_handle.param.pred_transform, "sigmoid", - sizeof(model_handle.param.pred_transform)); - model_handle.param.sigmoid_alpha = 1.0f; + std::strncpy(model_handle->param.pred_transform, "sigmoid", + sizeof(model_handle->param.pred_transform)); + model_handle->param.sigmoid_alpha = 1.0f; } else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_) != exponential_family.cend()) { - std::strncpy(model_handle.param.pred_transform, "exponential", - sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle->param.pred_transform, "exponential", + sizeof(model_handle->param.pred_transform)); } else { - std::strncpy(model_handle.param.pred_transform, "identity", - sizeof(model_handle.param.pred_transform)); + std::strncpy(model_handle->param.pred_transform, "identity", + sizeof(model_handle->param.pred_transform)); } // traverse trees for (const auto& xgb_tree : xgb_trees_) { - model_handle.trees.emplace_back(); - treelite::Tree& tree = model_handle.trees.back(); + model_handle->trees.emplace_back(); + treelite::Tree& tree = model_handle->trees.back(); tree.Init(); // assign node ID's so that a breadth-wise traversal would yield @@ -463,7 +463,7 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { tree.SetSumHess(new_id, stat.sum_hess); } } - return model_wrapper; + return model; } } // anonymous namespace diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc 
index 662b981e..60b9735b 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -24,9 +24,9 @@ inline std::string TreeliteToBytes(treelite::Model* model) { inline void TestRoundTrip(treelite::Model* model) { auto buffer = model->GetPyBuffer(); - treelite::Model received_model = treelite::Model::CreateFromPyBuffer(buffer); + std::unique_ptr received_model = treelite::Model::CreateFromPyBuffer(buffer); - ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(&received_model)); + ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(received_model.get())); } } // anonymous namespace @@ -52,8 +52,7 @@ void PyBufferInterfaceRoundTrip_TreeStump() { tree->SetLeafNode(2, frontend::Value::Create(1)); builder->InsertTree(tree.get()); - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } @@ -90,8 +89,7 @@ void PyBufferInterfaceRoundTrip_TreeStumpLeafVec() { frontend::Value::Create(-1)}); builder->InsertTree(tree.get()); - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } @@ -131,8 +129,7 @@ void PyBufferInterfaceRoundTrip_TreeStumpCategoricalSplit() { tree->SetLeafNode(2, frontend::Value::Create(1)); builder->InsertTree(tree.get()); - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } @@ -180,8 +177,7 @@ void PyBufferInterfaceRoundTrip_TreeDepth2() { builder->InsertTree(tree.get()); } - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } @@ -229,8 +225,7 @@ void PyBufferInterfaceRoundTrip_DeepFullTree() { tree->SetRootNode(0); builder->InsertTree(tree.get()); - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } From e8f874f2456d91f4eef7818e5bae4ce4133b7cdd Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 2 Sep 2020 01:10:04 -0700 Subject: [PATCH 14/38] Add functions to query data types used in a compiled predictor --- src/compiler/ast_native.cc | 8 ++++++++ src/compiler/native/header_template.h | 2 ++ src/compiler/native/main_template.h | 8 ++++++++ 3 files changed, 18 insertions(+) diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index dcae797b..d5a704ad 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -208,6 +208,8 @@ class ASTNativeCompiler : public Compiler { leaf_output_type) : fmt::format("{} predict(union Entry* data, int pred_margin)", leaf_output_type); + const char* get_threshold_type_signature = "const char* get_threshold_type(void)"; + const char* get_leaf_output_type_signature = "const char* get_leaf_output_type(void)"; if (!array_is_categorical_.empty()) { array_is_categorical_ @@ -230,6 +232,10 @@ class ASTNativeCompiler : public Compiler { = get_global_bias_function_signature, "pred_transform_function"_a = pred_tranform_func_, "predict_function_signature"_a = predict_function_signature, + "get_threshold_type_signature"_a = get_threshold_type_signature, + "threshold_type_str"_a = TypeInfoToString(InferTypeInfoOf()), + "get_leaf_output_type_signature"_a = get_leaf_output_type_signature, + "leaf_output_type_str"_a = TypeInfoToString(InferTypeInfoOf()), "num_output_group"_a = num_output_group_, "num_feature"_a = 
num_feature_, "pred_transform"_a = pred_transform_, @@ -250,6 +256,8 @@ class ASTNativeCompiler : public Compiler { "get_global_bias_function_signature"_a = get_global_bias_function_signature, "predict_function_signature"_a = predict_function_signature, + "get_threshold_type_signature"_a = get_threshold_type_signature, + "get_leaf_output_type_signature"_a = get_leaf_output_type_signature, "threshold_type"_a = threshold_type, "threshold_type_Node"_a = (param.quantize > 0 ? std::string("int") : threshold_type)), indent); diff --git a/src/compiler/native/header_template.h b/src/compiler/native/header_template.h index a822f54f..7d9158aa 100644 --- a/src/compiler/native/header_template.h +++ b/src/compiler/native/header_template.h @@ -49,6 +49,8 @@ extern const unsigned char is_categorical[]; {dllexport}{get_sigmoid_alpha_function_signature}; {dllexport}{get_global_bias_function_signature}; {dllexport}{predict_function_signature}; +{dllexport}{get_threshold_type_signature}; +{dllexport}{get_leaf_output_type_signature}; )TREELITETEMPLATE"; } // namespace native diff --git a/src/compiler/native/main_template.h b/src/compiler/native/main_template.h index a3176e65..fdb747d4 100644 --- a/src/compiler/native/main_template.h +++ b/src/compiler/native/main_template.h @@ -38,6 +38,14 @@ R"TREELITETEMPLATE( return {global_bias}; }} +{get_threshold_type_signature} {{ + return "{threshold_type_str}"; +}} + +{get_leaf_output_type_signature} {{ + return "{leaf_output_type_str}"; +}} + {pred_transform_function} {predict_function_signature} {{ )TREELITETEMPLATE"; From 71d43aa956c1fe34074db1850085ab653cb8b3e1 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 2 Sep 2020 01:59:35 -0700 Subject: [PATCH 15/38] Remove single-instance prediction feature from the runtime --- include/treelite/c_api_runtime.h | 19 ----- include/treelite/predictor.h | 13 ---- runtime/python/treelite_runtime/predictor.py | 74 -------------------- src/c_api/c_api_runtime.cc | 11 --- src/predictor/predictor.cc | 33 +-------- src/predictor/thread_pool/thread_pool.h | 3 +- 6 files changed, 3 insertions(+), 150 deletions(-) diff --git a/include/treelite/c_api_runtime.h b/include/treelite/c_api_runtime.h index cbd268ad..066e03b0 100644 --- a/include/treelite/c_api_runtime.h +++ b/include/treelite/c_api_runtime.h @@ -123,25 +123,6 @@ TREELITE_DLL int TreelitePredictorPredictBatch(PredictorHandle handle, int pred_margin, float* out_result, size_t* out_result_size); - -/*! - * \brief Make predictions on a single data row (synchronously). The work - * will be scheduled to the calling thread. - * \param handle predictor - * \param inst single data row - * \param pred_margin whether to produce raw margin scores instead of - * transformed probabilities - * \param out_result resulting output vector; use - * TreelitePredictorQueryResultSizeSingleInst() to allocate sufficient space - * \param out_result_size used to save length of the output vector, which is - * guaranteed to be at most TreelitePredictorQueryResultSizeSingleInst() - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreelitePredictorPredictInst(PredictorHandle handle, - union TreelitePredictorEntry* inst, - int pred_margin, float* out_result, - size_t* out_result_size); - /*! * \brief Given a batch of data rows, query the necessary size of array to * hold predictions for all data points. 
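Illustrative sketch (not part of the patch): with TreelitePredictorPredictInst() removed, a
single data row is scored by assembling a one-row batch and calling the batch API that still
exists at this point in the series (the batch types are replaced by DMatrix in the next patch).
The feature values below are made up, and the variables predictor and result_size are assumed
to come from TreelitePredictorLoad() and TreelitePredictorQueryResultSize() respectively.

    const float data[] = {0.5f, 1.2f};     /* two nonzero feature values */
    const uint32_t col_ind[] = {0, 3};
    const size_t row_ptr[] = {0, 2};       /* a single row */
    CSRBatchHandle batch;
    TreeliteAssembleSparseBatch(data, col_ind, row_ptr, /*num_row=*/1, /*num_col=*/4, &batch);
    std::vector<float> out_result(result_size);
    size_t out_result_size = 0;
    TreelitePredictorPredictBatch(predictor, batch, /*batch_sparse=*/1, /*verbose=*/0,
                                  /*pred_margin=*/0, out_result.data(), &out_result_size);
    TreeliteDeleteSparseBatch(batch);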
diff --git a/include/treelite/predictor.h b/include/treelite/predictor.h index 9f668cb8..40652a1b 100644 --- a/include/treelite/predictor.h +++ b/include/treelite/predictor.h @@ -78,19 +78,6 @@ class Predictor { bool pred_margin, float* out_result); size_t PredictBatch(const DenseBatch* batch, int verbose, bool pred_margin, float* out_result); - /*! - * \brief Make predictions on a single data row (synchronously). The work - * will be scheduled to the calling thread. - * \param inst single data row - * \param pred_margin whether to produce raw margin scores instead of - * transformed probabilities - * \param out_result resulting output vector; use - * QueryResultSizeSingleInst() to allocate sufficient space - * \return length of the output vector, which is guaranteed to be less than - * or equal to QueryResultSizeSingleInst() - */ - size_t PredictInst(TreelitePredictorEntry* inst, bool pred_margin, - float* out_result); /*! * \brief Given a batch of data rows, query the necessary size of array to diff --git a/runtime/python/treelite_runtime/predictor.py b/runtime/python/treelite_runtime/predictor.py index 4d15a644..ff24e6af 100644 --- a/runtime/python/treelite_runtime/predictor.py +++ b/runtime/python/treelite_runtime/predictor.py @@ -309,80 +309,6 @@ def __init__(self, libpath, nthread=None, verbose=False): log_info(__file__, lineno(), f'Dynamic shared library {path} has been successfully loaded into memory') - def predict_instance(self, inst, missing=None, pred_margin=False): - """ - Perform single-instance prediction. Prediction is run by the calling thread. - - Parameters - ---------- - inst: :py:class:`numpy.ndarray` / :py:class:`scipy.sparse.csr_matrix` /\ - :py:class:`dict ` - Data instance for which a prediction will be made. If ``inst`` is of - type :py:class:`scipy.sparse.csr_matrix`, its first dimension must be 1 - (``shape[0]==1``). If ``inst`` is of type :py:class:`numpy.ndarray`, - it must be one-dimensional. If ``inst`` is of type - :py:class:`dict `, it must be a dictionary where the keys - indicate feature indices (0-based) and the values corresponding - feature values. - missing : :py:class:`float `, optional - Value in the data instance that represents a missing value. If set to - ``None``, ``numpy.nan`` will be used. Only applicable if ``inst`` is - of type :py:class:`numpy.ndarray`. - pred_margin: :py:class:`bool `, optional - Whether to produce raw margins rather than transformed probabilities - """ - entry = (PredictorEntry * self.num_feature_)() - for i in range(self.num_feature_): - entry[i].missing = -1 - - if isinstance(inst, scipy.sparse.csr_matrix): - if inst.shape[0] != 1: - raise ValueError('inst cannot have more than one row') - if inst.shape[1] > self.num_feature_: - raise ValueError('Too many features. This model was trained with only ' + - f'{self.num_feature_} features') - for i in range(inst.nnz): - entry[inst.indices[i]].fvalue = inst.data[i] - elif isinstance(inst, scipy.sparse.csc_matrix): - raise TypeError('inst must be csr_matrix') - elif isinstance(inst, np.ndarray): - if len(inst.shape) > 1: - raise ValueError('inst must be 1D') - if inst.shape[0] > self.num_feature_: - raise ValueError('Too many features. 
This model was trained with only ' + - f'{self.num_feature_} features') - if missing is None or np.isnan(missing): - for i in range(inst.shape[0]): - if not np.isnan(inst[i]): - entry[i].fvalue = inst[i] - else: - for i in range(inst.shape[0]): - if inst[i] != missing: - entry[i].fvalue = inst[i] - elif isinstance(inst, dict): - for k, v in inst.items(): - entry[k].fvalue = v - else: - raise TypeError('inst must be NumPy array, SciPy CSR matrix, or a dictionary') - - result_size = ctypes.c_size_t() - _check_call(_LIB.TreelitePredictorQueryResultSizeSingleInst( - self.handle, - ctypes.byref(result_size))) - out_result = np.zeros(result_size.value, dtype=np.float32, order='C') - out_result_size = ctypes.c_size_t() - _check_call(_LIB.TreelitePredictorPredictInst( - self.handle, - ctypes.byref(entry), - ctypes.c_int(1 if pred_margin else 0), - out_result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - ctypes.byref(out_result_size))) - idx = int(out_result_size.value) - res = out_result[0:idx].reshape((1, -1)).squeeze() - if self.num_output_group_ > 1: - res = res.reshape((-1, self.num_output_group_)) - return res - def predict(self, batch, verbose=False, pred_margin=False): """ Perform batch prediction with a 2D sparse data matrix. Worker threads will diff --git a/src/c_api/c_api_runtime.cc b/src/c_api/c_api_runtime.cc index 0cb55ac6..0834dc10 100644 --- a/src/c_api/c_api_runtime.cc +++ b/src/c_api/c_api_runtime.cc @@ -124,17 +124,6 @@ int TreelitePredictorPredictBatch(PredictorHandle handle, API_END(); } -int TreelitePredictorPredictInst(PredictorHandle handle, - union TreelitePredictorEntry* inst, - int pred_margin, - float* out_result, size_t* out_result_size) { - API_BEGIN(); - Predictor* predictor_ = static_cast(handle); - *out_result_size - = predictor_->PredictInst(inst, (pred_margin != 0), out_result); - API_END(); -} - int TreelitePredictorQueryResultSize(PredictorHandle handle, void* batch, int batch_sparse, diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index b8739386..bc49e395 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -189,26 +189,6 @@ inline size_t PredictBatch_(const BatchType* batch, bool pred_margin, return query_result_size; } -inline size_t PredictInst_(TreelitePredictorEntry* inst, - bool pred_margin, size_t num_output_group, - treelite::Predictor::PredFuncHandle pred_func_handle, - size_t expected_query_result_size, float* out_pred) { - CHECK(pred_func_handle != nullptr) - << "A shared library needs to be loaded first using Load()"; - size_t query_result_size; // Dimention of output vector - if (num_output_group > 1) { // multi-class classification task - using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*); - PredFunc pred_func = reinterpret_cast(pred_func_handle); - query_result_size = pred_func(inst, static_cast(pred_margin), out_pred); - } else { // every other task - using PredFunc = float (*)(TreelitePredictorEntry*, int); - PredFunc pred_func = reinterpret_cast(pred_func_handle); - out_pred[0] = pred_func(inst, static_cast(pred_margin)); - query_result_size = 1; - } - return query_result_size; -} - } // anonymous namespace namespace treelite { @@ -437,8 +417,7 @@ Predictor::PredictBatchBase_(const BatchType* batch, int verbose, } const double tend = dmlc::GetTime(); if (verbose > 0) { - LOG(INFO) << "Treelite: Finished prediction in " - << tend - tstart << " sec"; + LOG(INFO) << "Treelite: Finished prediction in " << tend - tstart << " sec"; } return total_size; } @@ -455,14 +434,4 @@ 
Predictor::PredictBatch(const DenseBatch* batch, int verbose, return PredictBatchBase_(batch, verbose, pred_margin, out_result); } -size_t -Predictor::PredictInst(TreelitePredictorEntry* inst, bool pred_margin, - float* out_result) { - size_t total_size; - total_size = PredictInst_(inst, pred_margin, num_output_group_, - pred_func_handle_, - QueryResultSizeSingleInst(), out_result); - return total_size; -} - } // namespace treelite diff --git a/src/predictor/thread_pool/thread_pool.h b/src/predictor/thread_pool/thread_pool.h index c58ed4a7..a356a441 100644 --- a/src/predictor/thread_pool/thread_pool.h +++ b/src/predictor/thread_pool/thread_pool.h @@ -7,6 +7,7 @@ #ifndef TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_ #define TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_ +#include #include #include #include @@ -42,7 +43,7 @@ class ThreadPool { } /* bind threads to cores */ const char* bind_flag = getenv("TREELITE_BIND_THREADS"); - if (bind_flag == nullptr || std::atoi(bind_flag) == 1) { + if (bind_flag == nullptr || std::stoi(bind_flag) == 1) { SetAffinity(); } } From 914469bff77aa04136101eed208eb3521a89a266 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Wed, 2 Sep 2020 19:36:51 -0700 Subject: [PATCH 16/38] Refactor the runtime API to use DMatrix --- include/treelite/c_api_runtime.h | 130 ++---- include/treelite/data.h | 39 +- include/treelite/entry.h | 20 - include/treelite/predictor.h | 207 +++++---- include/treelite/tree_impl.h | 48 +- include/treelite/typeinfo.h | 57 +++ src/annotator.cc | 2 +- src/c_api/c_api_runtime.cc | 151 ++----- src/data/data.cc | 25 +- src/predictor/predictor.cc | 568 +++++++++++++----------- src/predictor/thread_pool/spsc_queue.h | 6 + src/predictor/thread_pool/thread_pool.h | 2 + 12 files changed, 598 insertions(+), 657 deletions(-) delete mode 100644 include/treelite/entry.h diff --git a/include/treelite/c_api_runtime.h b/include/treelite/c_api_runtime.h index 066e03b0..56d7de93 100644 --- a/include/treelite/c_api_runtime.h +++ b/include/treelite/c_api_runtime.h @@ -13,7 +13,6 @@ #define TREELITE_C_API_RUNTIME_H_ #include "c_api_common.h" -#include "entry.h" /*! * \addtogroup opaque_handles @@ -22,10 +21,9 @@ */ /*! \brief handle to predictor class */ typedef void* PredictorHandle; -/*! \brief handle to batch of sparse data rows */ -typedef void* CSRBatchHandle; -/*! \brief handle to batch of dense data rows */ -typedef void* DenseBatchHandle; +/*! \brief handle to output from predictor */ +typedef void* PredictorOutputHandle; + /*! \} */ /*! @@ -33,59 +31,6 @@ typedef void* DenseBatchHandle; * Predictor interface * \{ */ -/*! - * \brief assemble a sparse batch - * \param data feature values - * \param col_ind feature indices - * \param row_ptr pointer to row headers - * \param num_row number of data rows in the batch - * \param num_col number of columns (features) in the batch - * \param out handle to sparse batch - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteAssembleSparseBatch(const float* data, - const uint32_t* col_ind, - const size_t* row_ptr, - size_t num_row, size_t num_col, - CSRBatchHandle* out); -/*! - * \brief delete a sparse batch from memory - * \param handle sparse batch - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDeleteSparseBatch(CSRBatchHandle handle); -/*! 
- * \brief assemble a dense batch - * \param data feature values - * \param missing_value value to represent the missing value - * \param num_row number of data rows in the batch - * \param num_col number of columns (features) in the batch - * \param out handle to sparse batch - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteAssembleDenseBatch(const float* data, - float missing_value, - size_t num_row, size_t num_col, - DenseBatchHandle* out); -/*! - * \brief delete a dense batch from memory - * \param handle dense batch - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDeleteDenseBatch(DenseBatchHandle handle); - -/*! - * \brief get dimensions of a batch - * \param handle a batch of rows (must be of type SparseBatch or DenseBatch) - * \param batch_sparse whether the batch is sparse (true) or dense (false) - * \param out_num_row used to set number of rows - * \param out_num_col used to set number of columns - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteBatchGetDimension(void* handle, - int batch_sparse, - size_t* out_num_row, - size_t* out_num_col); /*! * \brief load prediction code into memory. @@ -96,56 +41,52 @@ TREELITE_DLL int TreeliteBatchGetDimension(void* handle, * \param out handle to predictor * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorLoad(const char* library_path, - int num_worker_thread, - PredictorHandle* out); +TREELITE_DLL int TreelitePredictorLoad( + const char* library_path, int num_worker_thread, PredictorHandle* out); /*! * \brief Make predictions on a batch of data rows (synchronously). This * function internally divides the workload among all worker threads. * \param handle predictor - * \param batch a batch of rows (must be of type SparseBatch or DenseBatch) - * \param batch_sparse whether batch is sparse (1) or dense (0) + * \param batch the data matrix containing a batch of rows * \param verbose whether to produce extra messages * \param pred_margin whether to produce raw margin scores instead of * transformed probabilities - * \param out_result resulting output vector; use - * TreelitePredictorQueryResultSize() to allocate sufficient - * space + * \param output_buffer resulting output vector; use + * TreelitePredictorQueryResultSize() to allocate sufficient space * \param out_result_size used to save length of the output vector, * which is guaranteed to be less than or equal to * TreelitePredictorQueryResultSize() * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorPredictBatch(PredictorHandle handle, - void* batch, - int batch_sparse, - int verbose, - int pred_margin, - float* out_result, - size_t* out_result_size); +TREELITE_DLL int TreelitePredictorPredictBatch( + PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, + PredictorOutputHandle output_buffer, size_t* out_result_size); /*! - * \brief Given a batch of data rows, query the necessary size of array to - * hold predictions for all data points. + * \brief Allocate a buffer space that's sufficient to hold predictions for a given data matrix. + * The size of the buffer is given by TreelitePredictorQueryResultSize(). 
* \param handle predictor - * \param batch a batch of rows (must be of type SparseBatch or DenseBatch) - * \param batch_sparse whether batch is sparse (1) or dense (0) - * \param out used to store the length of prediction array + * \param batch the data matrix containing a batch of rows + * \param out_output_buffer Newly allocated buffer space + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreelitePredictorAllocateOutputBuffer( + PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle* out_output_buffer); +/*! + * \brief Delete a buffer space from memory + * \param handle the buffer space to delete * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryResultSize(PredictorHandle handle, - void* batch, - int batch_sparse, - size_t* out); +TREELITE_DLL int TreelitePredictorDeleteOutputBuffer(PredictorOutputHandle handle); /*! - * \brief Query the necessary size of array to hold the prediction for a - * single data row + * \brief Given a batch of data rows, query the necessary size of array to + * hold predictions for all data points. * \param handle predictor + * \param batch the data matrix containing a batch of rows * \param out used to store the length of prediction array * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryResultSizeSingleInst( - PredictorHandle handle, - size_t* out); +TREELITE_DLL int TreelitePredictorQueryResultSize( + PredictorHandle handle, DMatrixHandle batch, size_t* out); /*! * \brief Get the number of output groups in the loaded model * The number is 1 for most tasks; @@ -154,8 +95,7 @@ TREELITE_DLL int TreelitePredictorQueryResultSizeSingleInst( * \param out length of prediction array * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, - size_t* out); +TREELITE_DLL int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, size_t* out); /*! * \brief Get the width (number of features) of each instance used to train * the loaded model @@ -163,8 +103,7 @@ TREELITE_DLL int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, * \param out number of features * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryNumFeature(PredictorHandle handle, - size_t* out); +TREELITE_DLL int TreelitePredictorQueryNumFeature(PredictorHandle handle, size_t* out); /*! * \brief Get name of post prediction transformation used to train @@ -173,8 +112,7 @@ TREELITE_DLL int TreelitePredictorQueryNumFeature(PredictorHandle handle, * \param out name of post prediction transformation * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryPredTransform(PredictorHandle handle, - const char** out); +TREELITE_DLL int TreelitePredictorQueryPredTransform(PredictorHandle handle, const char** out); /*! * \brief Get alpha value of sigmoid transformation used to train * the loaded model @@ -182,8 +120,7 @@ TREELITE_DLL int TreelitePredictorQueryPredTransform(PredictorHandle handle, * \param out alpha value of sigmoid transformation * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, - float* out); +TREELITE_DLL int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, float* out); /*! 
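// Illustrative usage sketch of the refactored runtime C API declared above; it is
// not part of the patch. The library path is a placeholder, `dmat` is assumed to be
// a DMatrixHandle created through the DMatrix C API (outside this header), and
// reading predictions back out of the opaque PredictorOutputHandle is not shown
// here. Error checks on the int return codes are omitted for brevity.
#include <treelite/c_api_runtime.h>

void RunBatchPrediction(DMatrixHandle dmat) {
  PredictorHandle predictor = nullptr;
  TreelitePredictorLoad("./compiled_model.so", /*num_worker_thread=*/-1, &predictor);

  // The predictor sizes and owns the output buffer.
  PredictorOutputHandle output = nullptr;
  TreelitePredictorAllocateOutputBuffer(predictor, dmat, &output);

  size_t expected_size = 0;
  TreelitePredictorQueryResultSize(predictor, dmat, &expected_size);

  size_t result_size = 0;
  TreelitePredictorPredictBatch(predictor, dmat, /*verbose=*/0, /*pred_margin=*/0,
                                output, &result_size);
  // result_size is guaranteed to be <= expected_size.

  TreelitePredictorDeleteOutputBuffer(output);
  TreelitePredictorFree(predictor);
}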
* \brief Get global bias which adjusting predicted margin scores @@ -191,8 +128,7 @@ TREELITE_DLL int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, * \param out global bias value * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryGlobalBias(PredictorHandle handle, - float* out); +TREELITE_DLL int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float* out); /*! * \brief delete predictor from memory * \param handle predictor to remove diff --git a/include/treelite/data.h b/include/treelite/data.h index d32078c7..e5f34409 100644 --- a/include/treelite/data.h +++ b/include/treelite/data.h @@ -15,11 +15,18 @@ namespace treelite { +enum class DMatrixType : uint8_t { + kDense = 0, + kSparseCSR = 1 +}; + class DMatrix { public: virtual size_t GetNumRow() const = 0; virtual size_t GetNumCol() const = 0; virtual size_t GetNumElem() const = 0; + virtual DMatrixType GetType() const = 0; + virtual TypeInfo GetElementType() const = 0; DMatrix() = default; virtual ~DMatrix() = default; }; @@ -39,11 +46,13 @@ class DenseDMatrix : public DMatrix { size_t GetNumRow() const override = 0; size_t GetNumCol() const override = 0; size_t GetNumElem() const override = 0; + DMatrixType GetType() const override = 0; + TypeInfo GetElementType() const override; }; template class DenseDMatrixImpl : public DenseDMatrix { - private: + public: /*! \brief feature values */ std::vector data; /*! \brief value representing the missing value (usually NaN) */ @@ -52,7 +61,7 @@ class DenseDMatrixImpl : public DenseDMatrix { size_t num_row; /*! \brief number of columns (i.e. # of features used) */ size_t num_col; - public: + DenseDMatrixImpl() = delete; DenseDMatrixImpl(std::vector data, ElementType missing_value, size_t num_row, size_t num_col); @@ -65,6 +74,7 @@ class DenseDMatrixImpl : public DenseDMatrix { size_t GetNumRow() const override; size_t GetNumCol() const override; size_t GetNumElem() const override; + DMatrixType GetType() const override; friend class DenseDMatrix; static_assert(std::is_same::value || std::is_same::value, @@ -88,12 +98,11 @@ class CSRDMatrix : public DMatrix { size_t num_row, size_t num_col); static std::unique_ptr Create( const char* filename, const char* format, int nthread, int verbose); - TypeInfo GetMatrixType() const; - template - inline auto Dispatch(Func func) const; size_t GetNumRow() const override = 0; size_t GetNumCol() const override = 0; size_t GetNumElem() const override = 0; + DMatrixType GetType() const override = 0; + TypeInfo GetElementType() const override; }; template @@ -113,7 +122,6 @@ class CSRDMatrixImpl : public CSRDMatrix { CSRDMatrixImpl() = delete; CSRDMatrixImpl(std::vector data, std::vector col_ind, std::vector row_ptr, size_t num_row, size_t num_col); - ~CSRDMatrixImpl() = default; CSRDMatrixImpl(const CSRDMatrixImpl&) = default; CSRDMatrixImpl(CSRDMatrixImpl&&) noexcept = default; CSRDMatrixImpl& operator=(const CSRDMatrixImpl&) = default; @@ -122,30 +130,13 @@ class CSRDMatrixImpl : public CSRDMatrix { size_t GetNumRow() const override; size_t GetNumCol() const override; size_t GetNumElem() const override; + DMatrixType GetType() const override; friend class CSRDMatrix; static_assert(std::is_same::value || std::is_same::value, "ElementType must be either float32 or float64"); }; -template -inline auto -CSRDMatrix::Dispatch(Func func) const { - switch (element_type_) { - case TypeInfo::kFloat32: - return func(*dynamic_cast*>(this)); - break; - case TypeInfo::kFloat64: - return func(*dynamic_cast*>(this)); - 
break; - case TypeInfo::kUInt32: - case TypeInfo::kInvalid: - default: - LOG(FATAL) << "Invalid element type for the matrix: " << TypeInfoToString(element_type_); - return func(*dynamic_cast*>(this)); // avoid missing return error - } -} - } // namespace treelite #endif // TREELITE_DATA_H_ diff --git a/include/treelite/entry.h b/include/treelite/entry.h deleted file mode 100644 index 8a3fb0a5..00000000 --- a/include/treelite/entry.h +++ /dev/null @@ -1,20 +0,0 @@ -/*! - * Copyright (c) 2017-2020 by Contributors - * \file entry.h - * \author Hyunsu Cho - * \brief Entry type for Treelite predictor - */ -#ifndef TREELITE_ENTRY_H_ -#define TREELITE_ENTRY_H_ - -/*! \brief data layout. The value -1 signifies the missing value. - When the "missing" field is set to -1, the "fvalue" field is set to - NaN (Not a Number), so there is no danger for mistaking between - missing values and non-missing values. */ -union TreelitePredictorEntry { - int missing; - float fvalue; - // may contain extra fields later, such as qvalue -}; - -#endif // TREELITE_ENTRY_H_ diff --git a/include/treelite/predictor.h b/include/treelite/predictor.h index 40652a1b..4ca0b122 100644 --- a/include/treelite/predictor.h +++ b/include/treelite/predictor.h @@ -8,139 +8,162 @@ #define TREELITE_PREDICTOR_H_ #include -#include #include +#include #include #include namespace treelite { +namespace predictor { + +/*! \brief data layout. The value -1 signifies the missing value. + When the "missing" field is set to -1, the "fvalue" field is set to + NaN (Not a Number), so there is no danger for mistaking between + missing values and non-missing values. */ +template +union Entry { + int missing; + ElementType fvalue; + // may contain extra fields later, such as qvalue +}; + +class PredictorOutput { + public: + virtual size_t GetNumRow() const = 0; + virtual size_t GetNumOutputGroup() const = 0; + + PredictorOutput() = default; + virtual ~PredictorOutput() = default; + + static std::unique_ptr Create( + TypeInfo leaf_output_type, size_t num_row, size_t num_output_group); +}; + +template +class PredictorOutputImpl : public PredictorOutput { + private: + std::vector preds_; + size_t num_row_; + size_t num_output_group_; + + friend class PredictorOutput; + + public: + size_t GetNumRow() const override; + size_t GetNumOutputGroup() const override; + std::vector& GetPreds(); + const std::vector& GetPreds() const; -/*! \brief sparse batch in Compressed Sparse Row (CSR) format */ -struct CSRBatch { - /*! \brief feature values */ - const float* data; - /*! \brief feature indices */ - const uint32_t* col_ind; - /*! \brief pointer to row headers; length of [num_row] + 1 */ - const size_t* row_ptr; - /*! \brief number of rows */ - size_t num_row; - /*! \brief number of columns (i.e. # of features used) */ - size_t num_col; + PredictorOutputImpl(size_t num_row, size_t num_output_group); }; -/*! \brief dense batch */ -struct DenseBatch { - /*! \brief feature values */ - const float* data; - /*! \brief value representing the missing value (usually nan) */ - float missing_value; - /*! \brief number of rows */ - size_t num_row; - /*! \brief number of columns (i.e. 
# of features used) */ - size_t num_col; +class SharedLibrary { + public: + using LibraryHandle = void*; + using FunctionHandle = void*; + SharedLibrary(); + ~SharedLibrary(); + void Load(const char* libpath); + FunctionHandle LoadFunction(const char* name) const; + template + HandleType LoadFunctionWithSignature(const char* name) const; + + private: + LibraryHandle handle_; + std::string libpath_; +}; + +class PredFunction { + public: + static std::unique_ptr Create(TypeInfo threshold_type, TypeInfo leaf_output_type, + const SharedLibrary& library, int num_feature, + int num_output_group); + PredFunction() = default; + virtual ~PredFunction() = default; + virtual TypeInfo GetThresholdType() const = 0; + virtual TypeInfo GetLeafOutputType() const = 0; + virtual size_t PredictBatch( + const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, + PredictorOutput* out_pred) const = 0; +}; + +template +class PredFunctionImpl : public PredFunction { + public: + using PredFuncHandle = void*; + PredFunctionImpl(const SharedLibrary& library, int num_feature, int num_output_group); + TypeInfo GetThresholdType() const override; + TypeInfo GetLeafOutputType() const override; + size_t PredictBatch( + const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, + PredictorOutput* out_pred) const override; + + private: + PredFuncHandle handle_; + int num_feature_; + int num_output_group_; }; /*! \brief predictor class: wrapper for optimized prediction code */ class Predictor { public: /*! \brief opaque handle types */ - typedef void* QueryFuncHandle; - typedef void* PredFuncHandle; - typedef void* LibraryHandle; typedef void* ThreadPoolHandle; explicit Predictor(int num_worker_thread = -1); ~Predictor(); /*! * \brief load the prediction function from dynamic shared library. - * \param name name of dynamic shared library (.so/.dll/.dylib). + * \param libpath path of dynamic shared library (.so/.dll/.dylib). */ - void Load(const char* name); + void Load(const char* libpath); /*! * \brief unload the prediction function */ void Free(); - /*! * \brief Make predictions on a batch of data rows (synchronously). This * function internally divides the workload among all worker threads. - * \param batch a batch of rows + * \param dmat a batch of rows * \param verbose whether to produce extra messages * \param pred_margin whether to produce raw margin scores instead of * transformed probabilities - * \param out_result resulting output vector; use - * QueryResultSize() to allocate sufficient space + * \param out_result resulting output vector * \return length of the output vector, which is guaranteed to be less than * or equal to QueryResultSize() */ - size_t PredictBatch(const CSRBatch* batch, int verbose, - bool pred_margin, float* out_result); - size_t PredictBatch(const DenseBatch* batch, int verbose, - bool pred_margin, float* out_result); - - /*! - * \brief Given a batch of data rows, query the necessary size of array to - * hold predictions for all data points. - * \param batch a batch of rows - * \return length of prediction array - */ - inline size_t QueryResultSize(const CSRBatch* batch) const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - return batch->num_row * num_output_group_; - } + size_t PredictBatch( + const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutput* out_result) const; /*! - * \brief Given a batch of data rows, query the necessary size of array to - * hold predictions for all data points. 
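// A minimal sketch of the Entry<> convention introduced in this header above; it is
// not part of the patch. Each feature slot starts out as missing (missing == -1) and
// writing fvalue marks the feature as present -- the same convention the PredLoop
// helpers in src/predictor/predictor.cc use to feed rows to the generated
// prediction function.
#include <treelite/predictor.h>

inline void FillExampleRow() {
  treelite::predictor::Entry<float> row[3];
  for (auto& e : row) {
    e.missing = -1;        // every slot begins as "missing"
  }
  row[0].fvalue = 1.5f;    // feature 0 is present
  row[2].fvalue = -0.25f;  // feature 2 is present; feature 1 stays missing
}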
- * \param batch a batch of rows - * \return length of prediction array + * \brief Allocate a buffer space that's sufficient to hold predictions for a given data matrix. + * The size of the buffer is given by QueryResultSize(). + * \param dmat a batch of rows + * \return Newly allocated buffer space */ - inline size_t QueryResultSize(const DenseBatch* batch) const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - return batch->num_row * num_output_group_; - } + std::unique_ptr AllocateOutputBuffer(const DMatrix* dmat) const; /*! * \brief Given a batch of data rows, query the necessary size of array to * hold predictions for all data points. - * \param batch a batch of rows - * \param rbegin beginning of range of rows - * \param rend end of range of rows + * \param dmat a batch of rows * \return length of prediction array */ - inline size_t QueryResultSize(const CSRBatch* batch, - size_t rbegin, size_t rend) const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - CHECK(rbegin < rend && rend <= batch->num_row); - return (rend - rbegin) * num_output_group_; + inline size_t QueryResultSize(const DMatrix* dmat) const { + CHECK(pred_func_) << "A shared library needs to be loaded first using Load()"; + return dmat->GetNumRow() * num_output_group_; } /*! * \brief Given a batch of data rows, query the necessary size of array to * hold predictions for all data points. - * \param batch a batch of rows + * \param dmat a batch of rows * \param rbegin beginning of range of rows * \param rend end of range of rows * \return length of prediction array */ - inline size_t QueryResultSize(const DenseBatch* batch, - size_t rbegin, size_t rend) const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - CHECK(rbegin < rend && rend <= batch->num_row); + inline size_t QueryResultSize(const DMatrix* dmat, size_t rbegin, size_t rend) const { + CHECK(pred_func_) << "A shared library needs to be loaded first using Load()"; + CHECK(rbegin < rend && rend <= dmat->GetNumRow()); return (rend - rbegin) * num_output_group_; } - /*! - * \brief Query the necessary size of array to hold the prediction for a - * single data row - * \return length of prediction array - */ - inline size_t QueryResultSizeSingleInst() const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - return num_output_group_; - } /*! * \brief Get the number of output groups in the loaded model * The number is 1 for most tasks; * \return length of prediction array */ inline size_t QueryNumOutputGroup() const { return num_output_group_; } - /*! * \brief Get the width (number of features) of each instance used to train * the loaded model * \return number of features */ inline size_t QueryNumFeature() const { return num_feature_; } - /*! * \brief Get name of post prediction transformation used to train the loaded model * \return name of prediction transformation */ inline std::string QueryPredTransform() const { return pred_transform_; } - /*! * \brief Get alpha value in sigmoid transformation used to train the loaded model * \return alpha value in sigmoid transformation */ inline float QuerySigmoidAlpha() const { return sigmoid_alpha_; } - /*! 
* \brief Get global bias which adjusting predicted margin scores * \return global bias @@ -185,13 +204,8 @@ class Predictor { } private: - LibraryHandle lib_handle_; - QueryFuncHandle num_output_group_query_func_handle_; - QueryFuncHandle num_feature_query_func_handle_; - QueryFuncHandle pred_transform_query_func_handle_; - QueryFuncHandle sigmoid_alpha_query_func_handle_; - QueryFuncHandle global_bias_query_func_handle_; - PredFuncHandle pred_func_handle_; + SharedLibrary lib_; + std::unique_ptr pred_func_; ThreadPoolHandle thread_pool_handle_; size_t num_output_group_; size_t num_feature_; @@ -199,12 +213,11 @@ class Predictor { float sigmoid_alpha_; float global_bias_; int num_worker_thread_; - - template - size_t PredictBatchBase_(const BatchType* batch, int verbose, - bool pred_margin, float* out_result); + TypeInfo threshold_type_; + TypeInfo leaf_output_type_; }; +} // namespace predictor } // namespace treelite #endif // TREELITE_PREDICTOR_H_ diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 122e06c0..8dc1dc02 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -814,47 +814,17 @@ Model::Create() { return model; } +template +class ModelCreateDispatcher { + public: + inline static std::unique_ptr Dispatch() { + return Model::Create(); + } +}; + inline std::unique_ptr Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { - auto error_threshold_type = [threshold_type]() { - std::ostringstream oss; - oss << "Invalid threshold type: " << treelite::TypeInfoToString(threshold_type); - return oss.str(); - }; - auto error_leaf_output_type = [threshold_type, leaf_output_type]() { - std::ostringstream oss; - oss << "Cannot use leaf output type " << treelite::TypeInfoToString(leaf_output_type) - << " with threshold type " << treelite::TypeInfoToString(threshold_type); - return oss.str(); - }; - switch (threshold_type) { - case treelite::TypeInfo::kFloat32: - switch (leaf_output_type) { - case treelite::TypeInfo::kUInt32: - return treelite::Model::Create(); - case treelite::TypeInfo::kFloat32: - return treelite::Model::Create(); - default: - throw std::runtime_error(error_leaf_output_type()); - break; - } - break; - case treelite::TypeInfo::kFloat64: - switch (leaf_output_type) { - case treelite::TypeInfo::kUInt32: - return treelite::Model::Create(); - case treelite::TypeInfo::kFloat64: - return treelite::Model::Create(); - default: - throw std::runtime_error(error_leaf_output_type()); - break; - } - break; - default: - throw std::runtime_error(error_threshold_type()); - break; - } - return std::unique_ptr(nullptr); // avoid missing return value warning + return DispatchWithModelTypes(threshold_type, leaf_output_type); } template diff --git a/include/treelite/typeinfo.h b/include/treelite/typeinfo.h index 5fb8b7c7..66a23455 100644 --- a/include/treelite/typeinfo.h +++ b/include/treelite/typeinfo.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace treelite { @@ -68,6 +69,62 @@ inline TypeInfo InferTypeInfoOf() { } } +/*! + * \brief Given the types for thresholds and leaf outputs, validate that they consist of a valid + * combination for a model and then dispatch a function with the corresponding template args + * \tparam Dispatcher Function object that takes in two template args. + * It must have a Dispatch() static function. 
+ * \tparam Parameter pack, to forward an arbitrary number of args to Dispatcher::Dispatch() + * \param threshold_type TypeInfo indicating the type of thresholds + * \param leaf_output_type TypeInfo indicating the type of leaf outputs + * \param args Other extra parameters to pass to Dispatcher::Dispatch() + * \return Whatever that's returned by the dispatcher + */ +template class Dispatcher, typename ...Args> +inline auto DispatchWithModelTypes( + TypeInfo threshold_type, TypeInfo leaf_output_type, Args&& ...args) { + auto error_threshold_type = [threshold_type]() { + std::ostringstream oss; + oss << "Invalid threshold type: " << treelite::TypeInfoToString(threshold_type); + return oss.str(); + }; + auto error_leaf_output_type = [threshold_type, leaf_output_type]() { + std::ostringstream oss; + oss << "Cannot use leaf output type " << treelite::TypeInfoToString(leaf_output_type) + << " with threshold type " << treelite::TypeInfoToString(threshold_type); + return oss.str(); + }; + switch (threshold_type) { + case treelite::TypeInfo::kFloat32: + switch (leaf_output_type) { + case treelite::TypeInfo::kUInt32: + return Dispatcher::Dispatch(std::forward(args)...); + case treelite::TypeInfo::kFloat32: + return Dispatcher::Dispatch(std::forward(args)...); + default: + throw std::runtime_error(error_leaf_output_type()); + break; + } + break; + case treelite::TypeInfo::kFloat64: + switch (leaf_output_type) { + case treelite::TypeInfo::kUInt32: + return Dispatcher::Dispatch(std::forward(args)...); + case treelite::TypeInfo::kFloat64: + return Dispatcher::Dispatch(std::forward(args)...); + default: + throw std::runtime_error(error_leaf_output_type()); + break; + } + break; + default: + throw std::runtime_error(error_threshold_type()); + break; + } + return Dispatcher::Dispatch(std::forward(args)...); + // avoid missing return value warning +} + } // namespace treelite #endif // TREELITE_TYPEINFO_H_ diff --git a/src/annotator.cc b/src/annotator.cc index e1fe7656..726d96c9 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -144,7 +144,7 @@ void BranchAnnotator::Annotate(const Model& model, const CSRDMatrix* dmat, int nthread, int verbose) { TypeInfo threshold_type = model.GetThresholdType(); model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) { - CHECK(dmat->GetMatrixType() == threshold_type) + CHECK(dmat->GetElementType() == threshold_type) << "BranchAnnotator: the matrix type must match the threshold type of the model"; AnnotateImpl(handle, dmat, nthread, verbose, &this->counts); }); diff --git a/src/c_api/c_api_runtime.cc b/src/c_api/c_api_runtime.cc index 0834dc10..f86cc0f2 100644 --- a/src/c_api/c_api_runtime.cc +++ b/src/c_api/c_api_runtime.cc @@ -28,144 +28,73 @@ using TreeliteRuntimeAPIThreadLocalStore } // anonymous namespace -int TreeliteAssembleSparseBatch(const float* data, - const uint32_t* col_ind, - const size_t* row_ptr, - size_t num_row, size_t num_col, - CSRBatchHandle* out) { +int TreelitePredictorLoad(const char* library_path, int num_worker_thread, PredictorHandle* out) { API_BEGIN(); - CSRBatch* batch = new CSRBatch(); - batch->data = data; - batch->col_ind = col_ind; - batch->row_ptr = row_ptr; - batch->num_row = num_row; - batch->num_col = num_col; - *out = static_cast(batch); - API_END(); -} - -int TreeliteDeleteSparseBatch(CSRBatchHandle handle) { - API_BEGIN(); - delete static_cast(handle); - API_END(); -} - -int TreeliteAssembleDenseBatch(const float* data, float missing_value, - size_t num_row, size_t num_col, - DenseBatchHandle* out) { - 
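// A sketch of how the DispatchWithModelTypes() helper added to typeinfo.h above is
// meant to be used; it is not part of the patch. The patch's own dispatchers are
// ModelCreateDispatcher (tree_impl.h) and PredFunctionInitDispatcher
// (src/predictor/predictor.cc); ThresholdSizeDispatcher below is hypothetical and
// exists only to illustrate the required shape: a class template over
// (ThresholdType, LeafOutputType) exposing a static Dispatch().
#include <cstddef>
#include <treelite/typeinfo.h>

template <typename ThresholdType, typename LeafOutputType>
class ThresholdSizeDispatcher {
 public:
  inline static std::size_t Dispatch() {
    return sizeof(ThresholdType);
  }
};

inline std::size_t ThresholdSizeOf(treelite::TypeInfo threshold_type,
                                   treelite::TypeInfo leaf_output_type) {
  // Resolves the two runtime TypeInfo values into template arguments; e.g.
  // (kFloat64, kFloat64) dispatches to ThresholdSizeDispatcher<double, double>.
  return treelite::DispatchWithModelTypes<ThresholdSizeDispatcher>(
      threshold_type, leaf_output_type);
}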
API_BEGIN(); - DenseBatch* batch = new DenseBatch(); - batch->data = data; - batch->missing_value = missing_value; - batch->num_row = num_row; - batch->num_col = num_col; - *out = static_cast(batch); - API_END(); -} - -int TreeliteDeleteDenseBatch(DenseBatchHandle handle) { - API_BEGIN(); - delete static_cast(handle); - API_END(); -} - -int TreeliteBatchGetDimension(void* handle, - int batch_sparse, - size_t* out_num_row, - size_t* out_num_col) { - API_BEGIN(); - if (batch_sparse) { - const CSRBatch* batch_ = static_cast(handle); - *out_num_row = batch_->num_row; - *out_num_col = batch_->num_col; - } else { - const DenseBatch* batch_ = static_cast(handle); - *out_num_row = batch_->num_row; - *out_num_col = batch_->num_col; - } + auto predictor = std::make_unique(num_worker_thread); + predictor->Load(library_path); + *out = static_cast(predictor.release()); API_END(); } -int TreelitePredictorLoad(const char* library_path, - int num_worker_thread, - PredictorHandle* out) { +int TreelitePredictorPredictBatch( + PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, + PredictorOutputHandle output_buffer, size_t* out_result_size) { API_BEGIN(); - Predictor* predictor = new Predictor(num_worker_thread); - predictor->Load(library_path); - *out = static_cast(predictor); + const auto* predictor = static_cast(handle); + const auto* dmat = static_cast(batch); + auto* out_result = static_cast(output_buffer); + const size_t num_feature = predictor->QueryNumFeature(); + const std::string err_msg + = std::string("Too many columns (features) in the given batch. " + "Number of features must not exceed ") + std::to_string(num_feature); + CHECK_LE(dmat->GetNumCol(), num_feature) << err_msg; + *out_result_size = predictor->PredictBatch(dmat, verbose, (pred_margin != 0), out_result); API_END(); } -int TreelitePredictorPredictBatch(PredictorHandle handle, - void* batch, - int batch_sparse, - int verbose, - int pred_margin, - float* out_result, - size_t* out_result_size) { +int TreelitePredictorAllocateOutputBuffer( + PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle* out_output_buffer) { API_BEGIN(); - Predictor* predictor_ = static_cast(handle); - const size_t num_feature = predictor_->QueryNumFeature(); - const std::string err_msg - = std::string("Too many columns (features) in the given batch. 
" - "Number of features must not exceed ") - + std::to_string(num_feature); - if (batch_sparse) { - const CSRBatch* batch_ = static_cast(batch); - CHECK_LE(batch_->num_col, num_feature) << err_msg; - *out_result_size = predictor_->PredictBatch(batch_, verbose, - (pred_margin != 0), out_result); - } else { - const DenseBatch* batch_ = static_cast(batch); - CHECK_LE(batch_->num_col, num_feature) << err_msg; - *out_result_size = predictor_->PredictBatch(batch_, verbose, - (pred_margin != 0), out_result); - } + const auto* predictor = static_cast(handle); + const auto* dmat = static_cast(batch); + std::unique_ptr output_buffer + = predictor->AllocateOutputBuffer(dmat); + *out_output_buffer = static_cast(output_buffer.release()); API_END(); } -int TreelitePredictorQueryResultSize(PredictorHandle handle, - void* batch, - int batch_sparse, - size_t* out) { +int TreelitePredictorDeleteOutputBuffer(PredictorOutputHandle handle) { API_BEGIN(); - const Predictor* predictor_ = static_cast(handle); - if (batch_sparse) { - const CSRBatch* batch_ = static_cast(batch); - *out = predictor_->QueryResultSize(batch_); - } else { - const DenseBatch* batch_ = static_cast(batch); - *out = predictor_->QueryResultSize(batch_); - } + delete static_cast(handle); API_END(); } -int TreelitePredictorQueryResultSizeSingleInst(PredictorHandle handle, - size_t* out) { +int TreelitePredictorQueryResultSize(PredictorHandle handle, DMatrixHandle batch, size_t* out) { API_BEGIN(); - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QueryResultSizeSingleInst(); + const auto* predictor = static_cast(handle); + const auto* dmat = static_cast(batch); + *out = predictor->QueryResultSize(dmat); API_END(); } int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, size_t* out) { API_BEGIN(); - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QueryNumOutputGroup(); + const auto* predictor = static_cast(handle); + *out = predictor->QueryNumOutputGroup(); API_END(); } int TreelitePredictorQueryNumFeature(PredictorHandle handle, size_t* out) { API_BEGIN(); - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QueryNumFeature(); + const auto* predictor = static_cast(handle); + *out = predictor->QueryNumFeature(); API_END(); } int TreelitePredictorQueryPredTransform(PredictorHandle handle, const char** out) { API_BEGIN() - const Predictor* predictor_ = static_cast(handle); - auto pred_transform = predictor_->QueryPredTransform(); + const auto* predictor = static_cast(handle); + auto pred_transform = predictor->QueryPredTransform(); std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str; ret_str = pred_transform; *out = ret_str.c_str(); @@ -174,20 +103,20 @@ int TreelitePredictorQueryPredTransform(PredictorHandle handle, const char** out int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, float* out) { API_BEGIN() - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QuerySigmoidAlpha(); + const auto* predictor = static_cast(handle); + *out = predictor->QuerySigmoidAlpha(); API_END(); } int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float* out) { API_BEGIN() - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QueryGlobalBias(); + const auto* predictor = static_cast(handle); + *out = predictor->QueryGlobalBias(); API_END(); } int TreelitePredictorFree(PredictorHandle handle) { API_BEGIN(); - delete static_cast(handle); + delete static_cast(handle); API_END(); } diff --git 
a/src/data/data.cc b/src/data/data.cc index 268d6521..e853b1be 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -104,6 +104,11 @@ DenseDMatrix::Create( return std::unique_ptr(nullptr); } +TypeInfo +DenseDMatrix::GetElementType() const { + return element_type_; +} + template DenseDMatrixImpl::DenseDMatrixImpl( std::vector data, ElementType missing_value, size_t num_row, size_t num_col) @@ -128,6 +133,12 @@ DenseDMatrixImpl::GetNumElem() const { return num_row * num_col; } +template +DMatrixType +DenseDMatrixImpl::GetType() const { + return DMatrixType::kDense; +} + template std::unique_ptr CSRDMatrix::Create(std::vector data, std::vector col_ind, @@ -179,7 +190,7 @@ CSRDMatrix::Create(const char* filename, const char* format, int nthread, int ve } TypeInfo -CSRDMatrix::GetMatrixType() const { +CSRDMatrix::GetElementType() const { return element_type_; } @@ -191,22 +202,28 @@ CSRDMatrixImpl::CSRDMatrixImpl( num_row(num_col), num_col(num_col) {} -template +template size_t CSRDMatrixImpl::GetNumRow() const { return num_row; } -template +template size_t CSRDMatrixImpl::GetNumCol() const { return num_col; } -template +template size_t CSRDMatrixImpl::GetNumElem() const { return row_ptr.at(num_row); } +template +DMatrixType +CSRDMatrixImpl::GetType() const { + return DMatrixType::kSparseCSR; +} + } // namespace treelite diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index bc49e395..fb28d4f4 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -7,6 +7,8 @@ #include #include +#include +#include #include #include #include @@ -27,22 +29,12 @@ namespace { -enum class InputType : uint8_t { - kSparseBatch = 0, kDenseBatch = 1 -}; - struct InputToken { - InputType input_type; - const void* data; // pointer to input data + const treelite::DMatrix* dmat; // input data bool pred_margin; // whether to store raw margin or transformed scores - size_t num_feature; - // # features (columns) accepted by the tree ensemble model - size_t num_output_group; - // size of output per instance (row) - treelite::Predictor::PredFuncHandle pred_func_handle; - size_t rbegin, rend; - // range of instances (rows) assigned to each worker - float* out_pred; + const treelite::predictor::PredFunction* pred_func_; + size_t rbegin, rend; // range of instances (rows) assigned to each worker + treelite::predictor::PredictorOutput* out_pred; // buffer to store output from each worker }; @@ -50,55 +42,22 @@ struct OutputToken { size_t query_result_size; }; -using PredThreadPool = treelite::ThreadPool; - -inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) { -#ifdef _WIN32 - HMODULE handle = LoadLibraryA(name); -#else - void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL); -#endif - return static_cast(handle); -} - -inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) { -#ifdef _WIN32 - FreeLibrary(static_cast(handle)); -#else - dlclose(static_cast(handle)); -#endif -} - -template -inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle, - const char* name) { -#ifdef _WIN32 - FARPROC func_handle = GetProcAddress(static_cast(lib_handle), name); -#else - void* func_handle = dlsym(static_cast(lib_handle), name); -#endif - return static_cast(func_handle); -} +using PredThreadPool + = treelite::predictor::ThreadPool; -template -inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature, - size_t rbegin, size_t rend, - float* out_pred, PredFunc func) { - CHECK_LE(batch->num_col, num_feature); - std::vector inst( - 
std::max(batch->num_col, num_feature), {-1}); - CHECK(rbegin < rend && rend <= batch->num_row); - CHECK(sizeof(size_t) < sizeof(int64_t) - || (rbegin <= static_cast(std::numeric_limits::max()) - && rend <= static_cast(std::numeric_limits::max()))); - const int64_t rbegin_ = static_cast(rbegin); - const int64_t rend_ = static_cast(rend); - const size_t num_col = batch->num_col; - const float* data = batch->data; - const uint32_t* col_ind = batch->col_ind; - const size_t* row_ptr = batch->row_ptr; +template +inline size_t PredLoop(const treelite::CSRDMatrixImpl* dmat, int num_feature, + size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { + CHECK_LE(dmat->num_col, static_cast(num_feature)); + std::vector> inst( + std::max(dmat->num_col, static_cast(num_feature)), {-1}); + CHECK(rbegin < rend && rend <= dmat->num_row); + const size_t num_col = dmat->num_col; + const ElementType* data = dmat->data.data(); + const uint32_t* col_ind = dmat->col_ind.data(); + const size_t* row_ptr = dmat->row_ptr.data(); size_t total_output_size = 0; - for (int64_t rid = rbegin_; rid < rend_; ++rid) { + for (size_t rid = rbegin; rid < rend; ++rid) { const size_t ibegin = row_ptr[rid]; const size_t iend = row_ptr[rid + 1]; for (size_t i = ibegin; i < iend; ++i) { @@ -112,32 +71,25 @@ inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature, return total_output_size; } -template -inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature, - size_t rbegin, size_t rend, - float* out_pred, PredFunc func) { - const bool nan_missing = treelite::math::CheckNAN(batch->missing_value); - CHECK_LE(batch->num_col, num_feature); - std::vector inst( - std::max(batch->num_col, num_feature), {-1}); - CHECK(rbegin < rend && rend <= batch->num_row); - CHECK(sizeof(size_t) < sizeof(int64_t) - || (rbegin <= static_cast(std::numeric_limits::max()) - && rend <= static_cast(std::numeric_limits::max()))); - const int64_t rbegin_ = static_cast(rbegin); - const int64_t rend_ = static_cast(rend); - const size_t num_col = batch->num_col; - const float missing_value = batch->missing_value; - const float* data = batch->data; - const float* row; +template +inline size_t PredLoop(const treelite::DenseDMatrixImpl* dmat, int num_feature, + size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { + const bool nan_missing = treelite::math::CheckNAN(dmat->missing_value); + CHECK_LE(dmat->num_col, static_cast(num_feature)); + std::vector> inst( + std::max(dmat->num_col, static_cast(num_feature)), {-1}); + CHECK(rbegin < rend && rend <= dmat->num_row); + const size_t num_col = dmat->num_col; + const ElementType missing_value = dmat->missing_value; + const ElementType* data = dmat->data.data(); + const ElementType* row = nullptr; size_t total_output_size = 0; - for (int64_t rid = rbegin_; rid < rend_; ++rid) { + for (size_t rid = rbegin; rid < rend; ++rid) { row = &data[rid * num_col]; for (size_t j = 0; j < num_col; ++j) { if (treelite::math::CheckNAN(row[j])) { CHECK(nan_missing) - << "The missing_value argument must be set to NaN if there is any " - << "NaN in the matrix."; + << "The missing_value argument must be set to NaN if there is any NaN in the matrix."; } else if (nan_missing || row[j] != missing_value) { inst[j].fvalue = row[j]; } @@ -150,149 +102,278 @@ inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature, return total_output_size; } -template -inline size_t PredictBatch_(const BatchType* batch, bool pred_margin, - size_t num_feature, size_t 
num_output_group, - treelite::Predictor::PredFuncHandle pred_func_handle, - size_t rbegin, size_t rend, - size_t expected_query_result_size, float* out_pred) { - CHECK(pred_func_handle != nullptr) - << "A shared library needs to be loaded first using Load()"; - /* Pass the correct prediction function to PredLoop. - We also need to specify how the function should be called. */ - size_t query_result_size; - // Dimension of output vector: - // can be either [num_data] or [num_class]*[num_data]. - // Note that size of prediction may be smaller than out_pred (this occurs - // when pred_function is set to "max_index"). - if (num_output_group > 1) { // multi-class classification task - using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*); - PredFunc pred_func = reinterpret_cast(pred_func_handle); - query_result_size = - PredLoop(batch, num_feature, rbegin, rend, out_pred, - [pred_func, num_output_group, pred_margin] - (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t { - return pred_func(inst, static_cast(pred_margin), - &out_pred[rid * num_output_group]); - }); - } else { // every other task - using PredFunc = float (*)(TreelitePredictorEntry*, int); - PredFunc pred_func = reinterpret_cast(pred_func_handle); - query_result_size = - PredLoop(batch, num_feature, rbegin, rend, out_pred, - [pred_func, pred_margin] - (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t { - out_pred[rid] = pred_func(inst, static_cast(pred_margin)); - return 1; - }); +template +inline size_t PredLoop(const treelite::DMatrix* dmat, int num_feature, + size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { + treelite::DMatrixType dmat_type = dmat->GetType(); + switch (dmat_type) { + case treelite::DMatrixType::kDense: { + const auto* dmat_ = static_cast*>(dmat); + return PredLoop(dmat_, num_feature, rbegin, rend, out_pred, func); + } + case treelite::DMatrixType::kSparseCSR: { + const auto* dmat_ = static_cast*>(dmat); + return PredLoop(dmat_, num_feature, rbegin, rend, out_pred, func); + } + default: + LOG(FATAL) << "Unrecognized data matrix type: " << static_cast(dmat_type); + return 0; } - return query_result_size; + } } // anonymous namespace namespace treelite { +namespace predictor { + +std::unique_ptr +PredictorOutput::Create(TypeInfo leaf_output_type, size_t num_row, size_t num_output_group) { + switch (leaf_output_type) { + case TypeInfo::kFloat32: + return std::make_unique>(num_row, num_output_group); + case TypeInfo::kFloat64: + return std::make_unique>(num_row, num_output_group); + case TypeInfo::kUInt32: + return std::make_unique>(num_row, num_output_group); + case TypeInfo::kInvalid: + default: + LOG(FATAL) << "Invalid leaf_output_type: " << TypeInfoToString(leaf_output_type); + return std::unique_ptr(nullptr); + } +} + +template +PredictorOutputImpl::PredictorOutputImpl(size_t num_row, size_t num_output_group) + : preds_(num_row * num_output_group, static_cast(0)), + num_row_(num_row), num_output_group_(num_output_group) {} + +template +size_t +PredictorOutputImpl::GetNumRow() const { + return num_row_; +} + +template +size_t +PredictorOutputImpl::GetNumOutputGroup() const { + return num_output_group_; +} + +template +std::vector& +PredictorOutputImpl::GetPreds() { + return preds_; +} + +template +const std::vector& +PredictorOutputImpl::GetPreds() const { + return preds_; +} + +SharedLibrary::SharedLibrary() : handle_(nullptr), libpath_() {} + +SharedLibrary::~SharedLibrary() { + if (handle_) { +#ifdef _WIN32 + 
FreeLibrary(static_cast(handle_)); +#else + dlclose(static_cast(handle_)); +#endif + } +} + +void +SharedLibrary::Load(const char* libpath) { +#ifdef _WIN32 + HMODULE handle = LoadLibraryA(libpath); +#else + void* handle = dlopen(libpath, RTLD_LAZY | RTLD_LOCAL); +#endif + CHECK(handle) << "Failed to load dynamic shared library `" << libpath << "'"; + handle_ = static_cast(handle); + libpath_ = std::string(libpath); +} + +SharedLibrary::FunctionHandle +SharedLibrary::LoadFunction(const char* name) const { +#ifdef _WIN32 + FARPROC func_handle = GetProcAddress(static_cast(handle_), name); +#else + void* func_handle = dlsym(static_cast(handle_), name); +#endif + CHECK(func_handle) + << "Dynamic shared library `" << libpath_ << "' does not contain a function " << name << "()."; + return static_cast(func_handle); +} + +template +HandleType +SharedLibrary::LoadFunctionWithSignature(const char* name) const { + auto func_handle = reinterpret_cast(LoadFunction(name)); + CHECK(func_handle) << "Dynamic shared library `" << libpath_ << "' does not contain a function " + << name << "() with the requested signature"; + return func_handle; +} + +template +class PredFunctionInitDispatcher { + public: + inline static std::unique_ptr Dispatch( + const SharedLibrary& library, int num_feature, int num_output_group) { + return std::make_unique>( + library, num_feature, num_output_group); + } +}; + +std::unique_ptr +PredFunction::Create( + TypeInfo threshold_type, TypeInfo leaf_output_type, const SharedLibrary& library, + int num_feature, int num_output_group) { + return DispatchWithModelTypes( + threshold_type, leaf_output_type, library, num_feature, num_output_group); +} + +template +PredFunctionImpl::PredFunctionImpl( + const SharedLibrary& library, int num_feature, int num_output_group) { + CHECK_GT(num_output_group, 0) << "num_output_group cannot be zero"; + if (num_output_group > 1) { // multi-class classification + handle_ = library.LoadFunction("predict_multiclass"); + } else { // everything else + handle_ = library.LoadFunction("predict"); + } + num_feature_ = num_feature; + num_output_group_ = num_output_group; +} + +template +TypeInfo +PredFunctionImpl::GetThresholdType() const { + return InferTypeInfoOf(); +} + +template +TypeInfo +PredFunctionImpl::GetLeafOutputType() const { + return InferTypeInfoOf(); +} + +template +size_t +PredFunctionImpl::PredictBatch( + const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, + PredictorOutput* out_pred) const { + /* Pass the correct prediction function to PredLoop. + We also need to specify how the function should be called. */ + size_t result_size; + // Dimension of output vector: + // can be either [num_data] or [num_class]*[num_data]. + // Note that size of prediction may be smaller than out_pred (this occurs + // when pred_function is set to "max_index"). + CHECK(dmat->GetElementType() == GetThresholdType()) + << "Mismatched data type in the data matrix. 
Expected: " << TypeInfoToString(GetThresholdType()) + << ", Given: " << TypeInfoToString(dmat->GetElementType()); + CHECK(rbegin < rend && rend <= dmat->GetNumRow()); + size_t num_row = rend - rbegin; + auto* out_pred_ = dynamic_cast*>(out_pred); + CHECK(out_pred_); + if (num_output_group_ > 1) { // multi-class classification + using PredFunc = size_t (*)(Entry*, int, LeafOutputType*); + auto pred_func = reinterpret_cast(handle_); + CHECK(pred_func) << "The predict_multiclass() function has incorrect signature."; + auto pred_func_wrapper + = [pred_func, num_output_group=num_output_group_, pred_margin] + (int64_t rid, Entry* inst, LeafOutputType* out_pred) -> size_t { + return pred_func(inst, static_cast(pred_margin), + &out_pred[rid * num_output_group]); + }; + result_size = + PredLoop( + dmat, num_feature_, rbegin, rend, out_pred_->GetPreds().data(), pred_func_wrapper); + } else { // everything else + using PredFunc = LeafOutputType (*)(Entry*, int); + auto pred_func = reinterpret_cast(handle_); + CHECK(pred_func) << "The predict() function has incorrect signature."; + auto pred_func_wrapper + = [pred_func, pred_margin] + (int64_t rid, Entry* inst, LeafOutputType* out_pred) -> size_t { + out_pred[rid] = pred_func(inst, static_cast(pred_margin)); + return 1; + }; + result_size = + PredLoop( + dmat, num_feature_, rbegin, rend, out_pred_->GetPreds().data(), pred_func_wrapper); + } + return result_size; +} Predictor::Predictor(int num_worker_thread) - : lib_handle_(nullptr), - num_output_group_query_func_handle_(nullptr), - num_feature_query_func_handle_(nullptr), - pred_func_handle_(nullptr), + : pred_func_(nullptr), thread_pool_handle_(nullptr), - num_worker_thread_(num_worker_thread) {} + num_output_group_(0), + num_feature_(0), + sigmoid_alpha_(std::numeric_limits::quiet_NaN()), + global_bias_(std::numeric_limits::quiet_NaN()), + num_worker_thread_(num_worker_thread), + threshold_type_(TypeInfo::kInvalid), + leaf_output_type_(TypeInfo::kInvalid) {} Predictor::~Predictor() { - Free(); + if (thread_pool_handle_) { + Free(); + } } void -Predictor::Load(const char* name) { - lib_handle_ = OpenLibrary(name); - if (lib_handle_ == nullptr) { - LOG(FATAL) << "Failed to load dynamic shared library `" << name << "'"; - } +Predictor::Load(const char* libpath) { + lib_.Load(libpath); + + using UnsignedQueryFunc = size_t (*)(); + using StringQueryFunc = const char* (*)(); + using FloatQueryFunc = float (*)(); /* 1. query # of output groups */ - num_output_group_query_func_handle_ - = LoadFunction(lib_handle_, "get_num_output_group"); - using UnsignedQueryFunc = size_t (*)(void); - auto uint_query_func - = reinterpret_cast(num_output_group_query_func_handle_); - CHECK(uint_query_func != nullptr) - << "Dynamic shared library `" << name - << "' does not contain valid get_num_output_group() function"; - num_output_group_ = uint_query_func(); + auto* num_output_group_query_func + = lib_.LoadFunctionWithSignature("get_num_output_group"); + num_output_group_ = num_output_group_query_func(); /* 2. 
query # of features */ - num_feature_query_func_handle_ - = LoadFunction(lib_handle_, "get_num_feature"); - uint_query_func = reinterpret_cast(num_feature_query_func_handle_); - CHECK(uint_query_func != nullptr) - << "Dynamic shared library `" << name - << "' does not contain valid get_num_feature() function"; - num_feature_ = uint_query_func(); + auto* num_feature_query_func + = lib_.LoadFunctionWithSignature("get_num_feature"); + num_feature_ = num_feature_query_func(); CHECK_GT(num_feature_, 0) << "num_feature cannot be zero"; /* 3. query # of pred_transform name */ - pred_transform_query_func_handle_ - = LoadFunction(lib_handle_, "get_pred_transform"); - using StringQueryFunc = const char* (*)(void); - auto str_query_func = - reinterpret_cast(pred_transform_query_func_handle_); - if (str_query_func == nullptr) { - LOG(INFO) << "Dynamic shared library `" << name - << "' does not contain valid get_pred_transform() function"; - pred_transform_ = "unknown"; - } else { - pred_transform_ = str_query_func(); - } + auto* pred_transform_query_func + = lib_.LoadFunctionWithSignature("get_pred_transform"); + pred_transform_ = pred_transform_query_func(); /* 4. query # of sigmoid_alpha */ - sigmoid_alpha_query_func_handle_ - = LoadFunction(lib_handle_, "get_sigmoid_alpha"); - using FloatQueryFunc = float (*)(void); - auto float_query_func = - reinterpret_cast(sigmoid_alpha_query_func_handle_); - if (float_query_func == nullptr) { - LOG(INFO) << "Dynamic shared library `" << name - << "' does not contain valid get_sigmoid_alpha() function"; - sigmoid_alpha_ = NAN; - } else { - sigmoid_alpha_ = float_query_func(); - } + auto* sigmoid_alpha_query_func + = lib_.LoadFunctionWithSignature("get_sigmoid_alpha"); + sigmoid_alpha_ = sigmoid_alpha_query_func(); /* 5. query # of global_bias */ - global_bias_query_func_handle_ - = LoadFunction(lib_handle_, "get_global_bias"); - float_query_func = reinterpret_cast(global_bias_query_func_handle_); - if (float_query_func == nullptr) { - LOG(INFO) << "Dynamic shared library `" << name - << "' does not contain valid get_global_bias() function"; - global_bias_ = NAN; - } else { - global_bias_ = float_query_func(); - } + auto* global_bias_query_func = lib_.LoadFunctionWithSignature("get_global_bias"); + global_bias_ = global_bias_query_func(); - /* 6. load appropriate function for margin prediction */ + /* 6. Query the data type for thresholds and leaf outputs */ + auto* threshold_type_query_func + = lib_.LoadFunctionWithSignature("get_threshold_type"); + threshold_type_ = typeinfo_table.at(threshold_type_query_func()); + auto* leaf_output_type_query_func + = lib_.LoadFunctionWithSignature("get_leaf_output_type"); + leaf_output_type_ = typeinfo_table.at(leaf_output_type_query_func()); + + /* 7. 
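// A sketch of the symbol table Predictor::Load() above expects a compiled model
// library to export, written out for a model whose threshold and leaf output types
// are both float32; it is not part of the patch. The query functions match the
// UnsignedQueryFunc / StringQueryFunc / FloatQueryFunc signatures used above; that
// get_threshold_type() / get_leaf_output_type() return a string key looked up in
// typeinfo_table is an assumption. The predict entry points mirror the
// function-pointer types cast in PredFunctionImpl::PredictBatch().
#include <stddef.h>

extern "C" {
  // Plain-C view of the Entry<float> union used by the runtime.
  union Entry {
    int missing;
    float fvalue;
  };

  size_t get_num_output_group(void);
  size_t get_num_feature(void);
  const char* get_pred_transform(void);
  float get_sigmoid_alpha(void);
  float get_global_bias(void);
  const char* get_threshold_type(void);    // e.g. "float32"
  const char* get_leaf_output_type(void);  // e.g. "float32"

  // Exactly one of the two entry points below is looked up, depending on
  // get_num_output_group():
  float predict(union Entry* inst, int pred_margin);
  size_t predict_multiclass(union Entry* inst, int pred_margin, float* out_pred);
}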
load appropriate function for margin prediction */ CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero"; - if (num_output_group_ > 1) { // multi-class classification - pred_func_handle_ = LoadFunction(lib_handle_, - "predict_multiclass"); - using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*); - PredFunc pred_func = reinterpret_cast(pred_func_handle_); - CHECK(pred_func != nullptr) - << "Dynamic shared library `" << name - << "' does not contain valid predict_multiclass() function"; - } else { // everything else - pred_func_handle_ = LoadFunction(lib_handle_, "predict"); - using PredFunc = float (*)(TreelitePredictorEntry*, int); - PredFunc pred_func = reinterpret_cast(pred_func_handle_); - CHECK(pred_func != nullptr) - << "Dynamic shared library `" << name - << "' does not contain valid predict() function"; - } + pred_func_ = PredFunction::Create( + threshold_type_, leaf_output_type_, lib_, + static_cast(num_feature_), static_cast(num_output_group_)); if (num_worker_thread_ == -1) { - num_worker_thread_ = std::thread::hardware_concurrency(); + num_worker_thread_ = static_cast(std::thread::hardware_concurrency()); } thread_pool_handle_ = static_cast( new PredThreadPool(num_worker_thread_ - 1, this, @@ -301,33 +382,11 @@ Predictor::Load(const char* name) { const Predictor* predictor) { InputToken input; while (incoming_queue->Pop(&input)) { - size_t query_result_size; const size_t rbegin = input.rbegin; const size_t rend = input.rend; - switch (input.input_type) { - case InputType::kSparseBatch: - { - const CSRBatch* batch = static_cast(input.data); - query_result_size - = PredictBatch_(batch, input.pred_margin, input.num_feature, - input.num_output_group, input.pred_func_handle, - rbegin, rend, - predictor->QueryResultSize(batch, rbegin, rend), - input.out_pred); - } - break; - case InputType::kDenseBatch: - { - const DenseBatch* batch = static_cast(input.data); - query_result_size - = PredictBatch_(batch, input.pred_margin, input.num_feature, - input.num_output_group, input.pred_func_handle, - rbegin, rend, - predictor->QueryResultSize(batch, rbegin, rend), - input.out_pred); - } - break; - } + size_t query_result_size + = predictor->pred_func_->PredictBatch( + input.dmat, rbegin, rend, input.pred_margin, input.out_pred); outgoing_queue->Push(OutputToken{query_result_size}); } })); @@ -335,14 +394,12 @@ Predictor::Load(const char* name) { void Predictor::Free() { - CloseLibrary(lib_handle_); delete static_cast(thread_pool_handle_); } -template static inline -std::vector SplitBatch(const BatchType* batch, size_t split_factor) { - const size_t num_row = batch->num_row; +std::vector SplitBatch(const DMatrix* dmat, size_t split_factor) { + const size_t num_row = dmat->GetNumRow(); CHECK_LE(split_factor, num_row); const size_t portion = num_row / split_factor; const size_t remainder = num_row % split_factor; @@ -359,26 +416,18 @@ std::vector SplitBatch(const BatchType* batch, size_t split_factor) { return row_ptr; } -template -inline size_t -Predictor::PredictBatchBase_(const BatchType* batch, int verbose, - bool pred_margin, float* out_result) { - static_assert(std::is_same::value - || std::is_same::value, - "PredictBatchBase_: unrecognized batch type"); +size_t +Predictor::PredictBatch( + const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutput* out_result) const { const double tstart = dmlc::GetTime(); - PredThreadPool* pool = static_cast(thread_pool_handle_); - const InputType input_type - = std::is_same::value - ? 
InputType::kSparseBatch : InputType::kDenseBatch; - InputToken request{input_type, static_cast(batch), pred_margin, - num_feature_, num_output_group_, pred_func_handle_, - 0, batch->num_row, out_result}; + + const size_t num_row = dmat->GetNumRow(); + auto* pool = static_cast(thread_pool_handle_); + InputToken request{dmat, pred_margin, pred_func_.get(), 0, num_row, out_result}; OutputToken response; - CHECK_GT(batch->num_row, 0); - const int nthread = std::min(num_worker_thread_, - static_cast(batch->num_row)); - const std::vector row_ptr = SplitBatch(batch, nthread); + CHECK_GT(num_row, 0); + const int nthread = std::min(num_worker_thread_, static_cast(num_row)); + const std::vector row_ptr = SplitBatch(dmat, nthread); for (int tid = 0; tid < nthread - 1; ++tid) { request.rbegin = row_ptr[tid]; request.rend = row_ptr[tid + 1]; @@ -386,14 +435,11 @@ Predictor::PredictBatchBase_(const BatchType* batch, int verbose, } size_t total_size = 0; { - // assign work to master + // assign work to the main thread const size_t rbegin = row_ptr[nthread - 1]; const size_t rend = row_ptr[nthread]; const size_t query_result_size - = PredictBatch_(batch, pred_margin, num_feature_, num_output_group_, - pred_func_handle_, - rbegin, rend, QueryResultSize(batch, rbegin, rend), - out_result); + = pred_func_->PredictBatch(dmat, rbegin, rend, pred_margin, out_result); total_size += query_result_size; } for (int tid = 0; tid < nthread - 1; ++tid) { @@ -402,16 +448,16 @@ Predictor::PredictBatchBase_(const BatchType* batch, int verbose, } } // re-shape output if total_size < dimension of out_result - if (total_size < QueryResultSize(batch, 0, batch->num_row)) { + if (total_size < QueryResultSize(dmat, 0, num_row)) { CHECK_GT(num_output_group_, 1); - CHECK_EQ(total_size % batch->num_row, 0); - const size_t query_size_per_instance = total_size / batch->num_row; + CHECK_EQ(total_size % num_row, 0); + const size_t query_size_per_instance = total_size / num_row; CHECK_GT(query_size_per_instance, 0); CHECK_LT(query_size_per_instance, num_output_group_); - for (size_t rid = 0; rid < batch->num_row; ++rid) { + for (size_t rid = 0; rid < num_row; ++rid) { for (size_t k = 0; k < query_size_per_instance; ++k) { out_result[rid * query_size_per_instance + k] - = out_result[rid * num_output_group_ + k]; + = out_result[rid * num_output_group_ + k]; } } } @@ -422,16 +468,10 @@ Predictor::PredictBatchBase_(const BatchType* batch, int verbose, return total_size; } -size_t -Predictor::PredictBatch(const CSRBatch* batch, int verbose, - bool pred_margin, float* out_result) { - return PredictBatchBase_(batch, verbose, pred_margin, out_result); -} - -size_t -Predictor::PredictBatch(const DenseBatch* batch, int verbose, - bool pred_margin, float* out_result) { - return PredictBatchBase_(batch, verbose, pred_margin, out_result); +std::unique_ptr +Predictor::AllocateOutputBuffer(const DMatrix* dmat) const { + return PredictorOutput::Create(leaf_output_type_, dmat->GetNumRow(), num_output_group_); } +} // namespace predictor } // namespace treelite diff --git a/src/predictor/thread_pool/spsc_queue.h b/src/predictor/thread_pool/spsc_queue.h index 1fa10416..539c4bd8 100644 --- a/src/predictor/thread_pool/spsc_queue.h +++ b/src/predictor/thread_pool/spsc_queue.h @@ -14,6 +14,9 @@ #include #include +namespace treelite { +namespace predictor { + const constexpr int kL1CacheBytes = 64; /*! 
\brief Lock-free single-producer-single-consumer queue for each thread */ @@ -117,4 +120,7 @@ class SpscQueue { std::condition_variable cv_; }; +} // namespace predictor +} // namespace treelite + #endif // TREELITE_PREDICTOR_THREAD_POOL_SPSC_QUEUE_H_ diff --git a/src/predictor/thread_pool/thread_pool.h b/src/predictor/thread_pool/thread_pool.h index a356a441..cf56b8aa 100644 --- a/src/predictor/thread_pool/thread_pool.h +++ b/src/predictor/thread_pool/thread_pool.h @@ -19,6 +19,7 @@ #include "spsc_queue.h" namespace treelite { +namespace predictor { template class ThreadPool { @@ -123,6 +124,7 @@ class ThreadPool { } }; +} // namespace predictor } // namespace treelite #endif // TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_ From 13ee2b623379fde2c8b93fdda6003e0171a2f314 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 3 Sep 2020 07:42:19 -0700 Subject: [PATCH 17/38] Get moder builder and runtime API working end-to-end --- include/treelite/annotator.h | 2 +- include/treelite/c_api_common.h | 3 +- include/treelite/c_api_runtime.h | 26 +- include/treelite/data.h | 4 +- include/treelite/predictor.h | 60 +-- include/treelite/typeinfo.h | 33 +- python/treelite/__init__.py | 3 +- python/treelite/annotator.py | 3 +- python/treelite/core.py | 196 +-------- python/treelite/frontend.py | 86 +++- python/treelite/sklearn/common.py | 6 +- python/treelite/sklearn/gbm_classifier.py | 8 +- .../treelite/sklearn/gbm_multi_classifier.py | 10 +- python/treelite/sklearn/gbm_regressor.py | 7 +- python/treelite/sklearn/rf_classifier.py | 7 +- .../treelite/sklearn/rf_multi_classifier.py | 10 +- python/treelite/sklearn/rf_regressor.py | 7 +- runtime/python/treelite_runtime/__init__.py | 4 +- runtime/python/treelite_runtime/predictor.py | 388 +++++++++--------- runtime/python/treelite_runtime/util.py | 34 ++ src/CMakeLists.txt | 1 - src/annotator.cc | 99 ++++- src/c_api/c_api.cc | 4 +- src/c_api/c_api_common.cc | 5 +- src/c_api/c_api_runtime.cc | 40 +- src/compiler/ast_native.cc | 60 +-- src/compiler/failsafe.cc | 41 +- src/compiler/native/header_template.h | 19 +- src/compiler/native/main_template.h | 29 +- src/data/data.cc | 44 +- src/frontend/builder.cc | 27 +- src/predictor/predictor.cc | 109 ++--- tests/python/conftest.py | 3 +- tests/python/metadata.py | 15 +- tests/python/test_basic.py | 9 +- tests/python/test_lightgbm_integration.py | 10 +- tests/python/test_model_builder.py | 26 +- tests/python/test_single_inst.py | 64 --- tests/python/test_xgboost_integration.py | 12 +- tests/python/util.py | 12 +- 40 files changed, 703 insertions(+), 823 deletions(-) delete mode 100644 tests/python/test_single_inst.py diff --git a/include/treelite/annotator.h b/include/treelite/annotator.h index a0b4ab2d..2e5972c1 100644 --- a/include/treelite/annotator.h +++ b/include/treelite/annotator.h @@ -24,7 +24,7 @@ class BranchAnnotator { * \param nthread number of threads to use * \param verbose whether to produce extra messages */ - void Annotate(const Model& model, const CSRDMatrix* dmat, int nthread, int verbose); + void Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose); /*! 
* \brief load branch annotation from a JSON file * \param fi input stream diff --git a/include/treelite/c_api_common.h b/include/treelite/c_api_common.h index b3c21364..61ae5bdb 100644 --- a/include/treelite/c_api_common.h +++ b/include/treelite/c_api_common.h @@ -60,7 +60,8 @@ TREELITE_DLL int TreeliteRegisterLogCallback(void (*callback)(const char*)); * \return 0 for success, -1 for failure */ TREELITE_DLL int TreeliteDMatrixCreateFromFile( - const char* path, const char* format, int nthread, int verbose, DMatrixHandle* out); + const char* path, const char* format, const char* data_type, int nthread, int verbose, + DMatrixHandle* out); /*! * \brief create DMatrix from a (in-memory) CSR matrix * \param data feature values diff --git a/include/treelite/c_api_runtime.h b/include/treelite/c_api_runtime.h index 56d7de93..17a8da4d 100644 --- a/include/treelite/c_api_runtime.h +++ b/include/treelite/c_api_runtime.h @@ -51,8 +51,9 @@ TREELITE_DLL int TreelitePredictorLoad( * \param verbose whether to produce extra messages * \param pred_margin whether to produce raw margin scores instead of * transformed probabilities - * \param output_buffer resulting output vector; use - * TreelitePredictorQueryResultSize() to allocate sufficient space + * \param out_result Resulting output vector. This pointer must point to an array of length + * TreelitePredictorQueryResultSize() and of type + * TreelitePredictorQueryLeafOutputType(). * \param out_result_size used to save length of the output vector, * which is guaranteed to be less than or equal to * TreelitePredictorQueryResultSize() @@ -60,23 +61,7 @@ TREELITE_DLL int TreelitePredictorLoad( */ TREELITE_DLL int TreelitePredictorPredictBatch( PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, - PredictorOutputHandle output_buffer, size_t* out_result_size); -/*! - * \brief Allocate a buffer space that's sufficient to hold predicton for a given data matrix. - * The size of the buffer is given by TreelitePredictorQueryResultSize(). - * \param handle predictor - * \param batch the data matrix containing a batch of rows - * \param out_output_buffer Newly allocated buffer space - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreelitePredictorAllocateOutputBuffer( - PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle* out_output_buffer); -/*! - * \brief Delete a buffer space from memory - * \param handle the buffer space to delete - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreelitePredictorDeleteOutputBuffer(PredictorOutputHandle handle); + void* out_result, size_t* out_result_size); /*! * \brief Given a batch of data rows, query the necessary size of array to * hold predictions for all data points. @@ -129,6 +114,9 @@ TREELITE_DLL int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, floa * \return 0 for success, -1 for failure */ TREELITE_DLL int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float* out); + +TREELITE_DLL int TreelitePredictorQueryThresholdType(PredictorHandle handle, const char** out); +TREELITE_DLL int TreelitePredictorQueryLeafOutputType(PredictorHandle handle, const char** out); /*! 
* \brief delete predictor from memory * \param handle predictor to remove diff --git a/include/treelite/data.h b/include/treelite/data.h index e5f34409..0c51e355 100644 --- a/include/treelite/data.h +++ b/include/treelite/data.h @@ -97,7 +97,7 @@ class CSRDMatrix : public DMatrix { TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr, size_t num_row, size_t num_col); static std::unique_ptr Create( - const char* filename, const char* format, int nthread, int verbose); + const char* filename, const char* format, const char* data_type, int nthread, int verbose); size_t GetNumRow() const override = 0; size_t GetNumCol() const override = 0; size_t GetNumElem() const override = 0; @@ -133,8 +133,6 @@ class CSRDMatrixImpl : public CSRDMatrix { DMatrixType GetType() const override; friend class CSRDMatrix; - static_assert(std::is_same::value || std::is_same::value, - "ElementType must be either float32 or float64"); }; } // namespace treelite diff --git a/include/treelite/predictor.h b/include/treelite/predictor.h index 4ca0b122..52626d76 100644 --- a/include/treelite/predictor.h +++ b/include/treelite/predictor.h @@ -27,36 +27,6 @@ union Entry { // may contain extra fields later, such as qvalue }; -class PredictorOutput { - public: - virtual size_t GetNumRow() const = 0; - virtual size_t GetNumOutputGroup() const = 0; - - PredictorOutput() = default; - virtual ~PredictorOutput() = default; - - static std::unique_ptr Create( - TypeInfo leaf_output_type, size_t num_row, size_t num_output_group); -}; - -template -class PredictorOutputImpl : public PredictorOutput { - private: - std::vector preds_; - size_t num_row_; - size_t num_output_group_; - - friend class PredictorOutput; - - public: - size_t GetNumRow() const override; - size_t GetNumOutputGroup() const override; - std::vector& GetPreds(); - const std::vector& GetPreds() const; - - PredictorOutputImpl(size_t num_row, size_t num_output_group); -}; - class SharedLibrary { public: using LibraryHandle = void*; @@ -82,9 +52,8 @@ class PredFunction { virtual ~PredFunction() = default; virtual TypeInfo GetThresholdType() const = 0; virtual TypeInfo GetLeafOutputType() const = 0; - virtual size_t PredictBatch( - const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, - PredictorOutput* out_pred) const = 0; + virtual size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, + void* out_pred) const = 0; }; template @@ -94,9 +63,8 @@ class PredFunctionImpl : public PredFunction { PredFunctionImpl(const SharedLibrary& library, int num_feature, int num_output_group); TypeInfo GetThresholdType() const override; TypeInfo GetLeafOutputType() const override; - size_t PredictBatch( - const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, - PredictorOutput* out_pred) const override; + size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, + void* out_pred) const override; private: PredFuncHandle handle_; @@ -128,19 +96,13 @@ class Predictor { * \param verbose whether to produce extra messages * \param pred_margin whether to produce raw margin scores instead of * transformed probabilities - * \param out_result resulting output vector + * \param out_result Resulting output vector. This pointer must point to an array of length + * QueryResultSize() and of type QueryLeafOutputType(). 
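For reference, the calling convention spelled out above can be mirrored from Python with ctypes. The sketch below assumes the runtime library has already been loaded and that valid predictor and DMatrix handles were obtained elsewhere (e.g. via TreelitePredictorLoad() and one of the TreeliteDMatrixCreateFrom* functions); only the buffer sizing and typing logic is the point here.

import ctypes
import numpy as np

def predict_with_c_api(lib, predictor_handle, dmat_handle, verbose=0, pred_margin=0):
    # 1. The output array must have length TreelitePredictorQueryResultSize(...)
    result_size = ctypes.c_size_t()
    lib.TreelitePredictorQueryResultSize(predictor_handle, dmat_handle, ctypes.byref(result_size))
    # 2. ... and element type TreelitePredictorQueryLeafOutputType(...)
    leaf_type = ctypes.c_char_p()
    lib.TreelitePredictorQueryLeafOutputType(predictor_handle, ctypes.byref(leaf_type))
    dtype = {'uint32': np.uint32, 'float32': np.float32, 'float64': np.float64}[leaf_type.value.decode()]
    # 3. Allocate the caller-owned buffer and run the batch prediction
    out = np.zeros(result_size.value, dtype=dtype)
    out_size = ctypes.c_size_t()
    lib.TreelitePredictorPredictBatch(
        predictor_handle, dmat_handle, ctypes.c_int(verbose), ctypes.c_int(pred_margin),
        out.ctypes.data_as(ctypes.c_void_p), ctypes.byref(out_size))
    return out[:out_size.value]   # only the first out_size entries are valid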
* \return length of the output vector, which is guaranteed to be less than * or equal to QueryResultSize() */ size_t PredictBatch( - const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutput* out_result) const; - /*! - * \brief Allocate a buffer space that's sufficient to hold predicton for a given data matrix. - * The size of the buffer is given by QueryResultSize(). - * \param dmat a batch of rows - * \return Newly allocated buffer space - */ - std::unique_ptr AllocateOutputBuffer(const DMatrix* dmat) const; + const DMatrix* dmat, int verbose, bool pred_margin, void* out_result) const; /*! * \brief Given a batch of data rows, query the necessary size of array to * hold predictions for all data points. @@ -202,6 +164,12 @@ class Predictor { inline float QueryGlobalBias() const { return global_bias_; } + inline TypeInfo QueryThresholdType() const { + return threshold_type_; + } + inline TypeInfo QueryLeafOutputType() const { + return leaf_output_type_; + } private: SharedLibrary lib_; @@ -215,6 +183,8 @@ class Predictor { int num_worker_thread_; TypeInfo threshold_type_; TypeInfo leaf_output_type_; + + mutable dmlc::OMPException exception_catcher_; }; } // namespace predictor diff --git a/include/treelite/typeinfo.h b/include/treelite/typeinfo.h index 66a23455..c10428b6 100644 --- a/include/treelite/typeinfo.h +++ b/include/treelite/typeinfo.h @@ -69,9 +69,40 @@ inline TypeInfo InferTypeInfoOf() { } } +/*! + * \brief Given a TypeInfo, dispatch a function with the corresponding template arg. More precisely, + * we shall call Dispatcher::Dispatch() where the template arg T corresponds to the + * `type` parameter. + * \tparam Dispatcher Function object that takes in one template arg. + * It must have a Dispatch() static function. + * \tparam Parameter pack, to forward an arbitrary number of args to Dispatcher::Dispatch() + * \param type TypeInfo corresponding to the template arg T with which + * Dispatcher::Dispatch() is called. + * \param args Other extra parameters to pass to Dispatcher::Dispatch() + * \return Whatever that's returned by the dispatcher + */ +template class Dispatcher, typename ...Args> +inline auto DispatchWithTypeInfo(TypeInfo type, Args&& ...args) { + switch (type) { + case TypeInfo::kUInt32: + return Dispatcher::Dispatch(std::forward(args)...); + case TypeInfo::kFloat32: + return Dispatcher::Dispatch(std::forward(args)...); + case TypeInfo::kFloat64: + return Dispatcher::Dispatch(std::forward(args)...); + case TypeInfo::kInvalid: + default: + throw std::runtime_error(std::string("Invalid type: ") + TypeInfoToString(type)); + } + return Dispatcher::Dispatch(std::forward(args)...); // avoid missing return error +} + /*! * \brief Given the types for thresholds and leaf outputs, validate that they consist of a valid - * combination for a model and then dispatch a function with the corresponding template args + * combination for a model and then dispatch a function with the corresponding template args. + * More precisely, we shall call Dispatcher::Dispatch() where + * the template args ThresholdType and LeafOutputType correspond to the parameters + * `threshold_type` and `leaf_output_type`, respectively. * \tparam Dispatcher Function object that takes in two template args. * It must have a Dispatch() static function. 
* \tparam Parameter pack, to forward an arbitrary number of args to Dispatcher::Dispatch() diff --git a/python/treelite/__init__.py b/python/treelite/__init__.py index ca010e8b..22fdff8f 100644 --- a/python/treelite/__init__.py +++ b/python/treelite/__init__.py @@ -4,7 +4,6 @@ """ import os -from .core import DMatrix from .frontend import Model, ModelBuilder from .annotator import Annotator from .contrib import create_shared, generate_makefile, generate_cmakelists @@ -15,5 +14,5 @@ with open(VERSION_FILE, 'r') as _f: __version__ = _f.read().strip() -__all__ = ['DMatrix', 'Model', 'ModelBuilder', 'Annotator', 'create_shared', 'generate_makefile', +__all__ = ['Model', 'ModelBuilder', 'Annotator', 'create_shared', 'generate_makefile', 'generate_cmakelists', 'sklearn', 'TreeliteError'] diff --git a/python/treelite/annotator.py b/python/treelite/annotator.py index b3c29995..14ec8c2d 100644 --- a/python/treelite/annotator.py +++ b/python/treelite/annotator.py @@ -3,8 +3,9 @@ import ctypes from .util import c_str, TreeliteError -from .core import _LIB, DMatrix, _check_call +from .core import _LIB, _check_call from .frontend import Model +from treelite_runtime import DMatrix class Annotator(): diff --git a/python/treelite/core.py b/python/treelite/core.py index ea67f94d..ccd7fde2 100644 --- a/python/treelite/core.py +++ b/python/treelite/core.py @@ -8,8 +8,7 @@ import numpy as np import scipy.sparse -from .compat import DataFrame -from .util import buffer_from_memory, c_str, py_str, _log_callback, TreeliteError +from .util import py_str, _log_callback, TreeliteError from .libpath import find_lib_path, TreeliteLibraryNotFound @@ -61,196 +60,3 @@ def c_array(ctype, values): ndarray.ctypes.data_as(*) to expose underlying buffer as C pointer. """ return (ctype * len(values))(*values) - - -PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', - 'int64': 'int', 'uint8': 'int', 'uint16': 'int', - 'uint32': 'int', 'uint64': 'int', 'float16': 'float', - 'float32': 'float', 'float64': 'float', 'bool': 'i'} - - -def _maybe_pandas_data(data, feature_names, feature_types): - """Extract internal data from pd.DataFrame for DMatrix data""" - if not isinstance(data, DataFrame): - return data, feature_names, feature_types - data_dtypes = data.dtypes - if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes): - bad_fields = [data.columns[i] for i, dtype in enumerate(data_dtypes) \ - if dtype.name not in PANDAS_DTYPE_MAPPER] - msg = "DataFrame.dtypes for data must be in, float, or bool. Did not " \ - + "expect the data types in fields " - raise ValueError(msg + ', '.join(bad_fields)) - if feature_names is None: - feature_names = data.columns.format() - if feature_types is None: - feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes] - data = data.values.astype('float') - return data, feature_names, feature_types - - -class DMatrix(): - """Data matrix used in Treelite. - - Parameters - ---------- - data : :py:class:`str ` / :py:class:`numpy.ndarray` /\ - :py:class:`scipy.sparse.csr_matrix` / :py:class:`pandas.DataFrame` - Data source. When data is :py:class:`str ` type, it indicates - that data should be read from a file. - data_format : :py:class:`str `, optional - Format of input data file. Applicable only when data is read from a - file. If missing, the svmlight (.libsvm) format is assumed. - missing : :py:class:`float `, optional - Value in the data that represents a missing entry. If set to ``None``, - ``numpy.nan`` will be used. 
- verbose : :py:class:`bool `, optional - Whether to print extra messages during construction - feature_names : :py:class:`list `, optional - Human-readable names for features - feature_types : :py:class:`list `, optional - Types for features - nthread : :py:class:`int `, optional - Number of threads - """ - - # pylint: disable=R0902,R0903,R0913 - - def __init__(self, data, data_format=None, missing=None, - feature_names=None, feature_types=None, - verbose=False, nthread=None): - if data is None: - raise TreeliteError('\'data\' argument cannot be None') - - data, feature_names, feature_types = _maybe_pandas_data(data, - feature_names, - feature_types) - if isinstance(data, (str,)): - self.handle = ctypes.c_void_p() - nthread = nthread if nthread is not None else 0 - data_format = data_format if data_format is not None else "libsvm" - _check_call(_LIB.TreeliteDMatrixCreateFromFile( - c_str(data), - c_str(data_format), - ctypes.c_int(nthread), - ctypes.c_int(1 if verbose else 0), - ctypes.byref(self.handle))) - elif isinstance(data, scipy.sparse.csr_matrix): - self._init_from_csr(data) - elif isinstance(data, scipy.sparse.csc_matrix): - self._init_from_csr(data.tocsr()) - elif isinstance(data, np.ndarray): - self._init_from_npy2d(data, missing) - else: # any type that's convertible to CSR matrix is O.K. - try: - csr = scipy.sparse.csr_matrix(data) - self._init_from_csr(csr) - except Exception as e: - raise TypeError(f'Cannot initialize DMatrix from {type(data).__name__}') from e - self.feature_names = feature_names - self.feature_types = feature_types - self._get_internals() # save handles for internal arrays - - def _init_from_csr(self, csr): - """Initialize data from a CSR (Compressed Sparse Row) matrix""" - if len(csr.indices) != len(csr.data): - raise ValueError('indices and data not of same length: {} vs {}' - .format(len(csr.indices), len(csr.data))) - if len(csr.indptr) != csr.shape[0] + 1: - raise ValueError('len(indptr) must be equal to 1 + [number of rows]' \ - + 'len(indptr) = {} vs 1 + [number of rows] = {}' - .format(len(csr.indptr), 1 + csr.shape[0])) - if csr.indptr[-1] != len(csr.data): - raise ValueError('last entry of indptr must be equal to len(data)' \ - + 'indptr[-1] = {} vs len(data) = {}' - .format(csr.indptr[-1], len(csr.data))) - self.handle = ctypes.c_void_p() - data = np.array(csr.data, copy=False, dtype=np.float32, order='C') - indices = np.array(csr.indices, copy=False, dtype=np.uintc, order='C') - indptr = np.array(csr.indptr, copy=False, dtype=np.uintp, order='C') - _check_call(_LIB.TreeliteDMatrixCreateFromCSR( - data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - indices.ctypes.data_as(ctypes.POINTER(ctypes.c_uint)), - indptr.ctypes.data_as(ctypes.POINTER(ctypes.c_size_t)), - ctypes.c_size_t(csr.shape[0]), - ctypes.c_size_t(csr.shape[1]), - ctypes.byref(self.handle))) - - def _init_from_npy2d(self, mat, missing): - """ - Initialize data from a 2-D numpy matrix. - If ``mat`` does not have ``order='C'`` (also known as row-major) or is not - contiguous, a temporary copy will be made. - If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will be - made also. - Thus, as many as two temporary copies of data can be made. One should set - input layout and type judiciously to conserve memory. - """ - if len(mat.shape) != 2: - raise ValueError('Input numpy.ndarray must be two-dimensional') - # flatten the array by rows and ensure it is float32. 
- # we try to avoid data copies if possible - # (reshape returns a view when possible and we explicitly tell np.array to - # avoid copying) - data = np.array(mat.reshape(mat.size), copy=False, dtype=np.float32) - self.handle = ctypes.c_void_p() - missing = missing if missing is not None else np.nan - _check_call(_LIB.TreeliteDMatrixCreateFromMat( - data.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - ctypes.c_size_t(mat.shape[0]), - ctypes.c_size_t(mat.shape[1]), - ctypes.c_float(missing), - ctypes.byref(self.handle))) - - def _get_dims(self): - num_row = ctypes.c_size_t() - num_col = ctypes.c_size_t() - nelem = ctypes.c_size_t() - _check_call(_LIB.TreeliteDMatrixGetDimension(self.handle, - ctypes.byref(num_row), - ctypes.byref(num_col), - ctypes.byref(nelem))) - return (num_row.value, num_col.value, nelem.value) - - def _get_internals(self): - data = ctypes.POINTER(ctypes.c_float)() - col_ind = ctypes.POINTER(ctypes.c_uint32)() - row_ptr = ctypes.POINTER(ctypes.c_size_t)() - _check_call(_LIB.TreeliteDMatrixGetArrays(self.handle, - ctypes.byref(data), - ctypes.byref(col_ind), - ctypes.byref(row_ptr))) - num_row, num_col, nelem = self._get_dims() - - # DMatrix should mimick scipy.sparse.csr_matrix for - # proper duck typing in Predictor.from_csr() - self.data = np.frombuffer(buffer_from_memory(data, ctypes.sizeof(ctypes.c_float * nelem)), - dtype=np.float32) - self.indices = np.frombuffer(buffer_from_memory( - col_ind, - ctypes.sizeof(ctypes.c_uint32 * nelem)), dtype=np.uint32) - self.indptr = np.frombuffer(buffer_from_memory( - row_ptr, - ctypes.sizeof(ctypes.c_size_t * (num_row + 1))), dtype=np.uintp) - self.shape = (num_row, num_col) - self.size = nelem - - def __del__(self): - if self.handle is not None: - _check_call(_LIB.TreeliteDMatrixFree(self.handle)) - self.handle = None - - def __repr__(self): - return '<{}x{} sparse matrix of type treelite.DMatrix\n' \ - .format(self.shape[0], self.shape[1]) \ - + ' with {} stored elements in Compressed Sparse Row format>' \ - .format(self.size) - - def __str__(self): - # Print first and last 25 non-zero entries - preview = ctypes.c_char_p() - _check_call(_LIB.TreeliteDMatrixGetPreview(self.handle, - ctypes.byref(preview))) - return py_str(preview.value) - - -__all__ = ['DMatrix'] diff --git a/python/treelite/frontend.py b/python/treelite/frontend.py index c3087ee8..7413c6aa 100644 --- a/python/treelite/frontend.py +++ b/python/treelite/frontend.py @@ -7,6 +7,8 @@ import os from tempfile import TemporaryDirectory +import numpy as np + from .util import c_str, TreeliteError from .core import _LIB, c_array, _check_call from .contrib import create_shared, generate_makefile, generate_cmakelists, _toolchain_exist_check @@ -22,7 +24,7 @@ def _isascii(string): return False -class Model(): +class Model: """ Decision tree ensemble model @@ -371,7 +373,7 @@ def load(cls, filename, model_format): return Model(handle) -class ModelBuilder(): +class ModelBuilder: """ Builder class for tree ensemble model: provides tools to iteratively build an ensemble of decision trees @@ -392,7 +394,44 @@ class ModelBuilder(): parameters. 
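For illustration, the new typed arguments fit together as in the following minimal sketch, which builds a one-split tree whose thresholds and leaf outputs are both float64; the final commit() call is assumed to behave as in the existing builder workflow and is included only to round out the example.

import treelite

builder = treelite.ModelBuilder(
    num_feature=2, random_forest=False,
    threshold_type='float64', leaf_output_type='float64')
tree = treelite.ModelBuilder.Tree(threshold_type='float64', leaf_output_type='float64')
# Root node: test feature 0 against a float64 threshold
tree[0].set_numerical_test_node(
    feature_id=0, opname='<', threshold=0.5, default_left=True,
    left_child_key=1, right_child_key=2, threshold_type='float64')
# Leaf nodes: float64 outputs
tree[1].set_leaf_node(-1.0, leaf_value_type='float64')
tree[2].set_leaf_node(1.0, leaf_value_type='float64')
tree[0].set_root()
builder.append(tree)
model = builder.commit()   # assumed unchanged by this patch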
""" - class Node(): + class Value: + """ + Value whose type may be specified at runtime + + Parameters + ---------- + type : str + Initial value of model handle + """ + CTYPES_PTR = { + 'uint32': ctypes.c_uint32, + 'float32': ctypes.c_float, + 'float64': ctypes.c_double + } + NUMPY_TYPE = { + 'uint32': np.uint32, + 'float32': np.float32, + 'float64': np.float64 + } + + def __init__(self, init_value, type): + self.type = type + self.handle = ctypes.c_void_p() + val = np.array([init_value], dtype=self.NUMPY_TYPE[type], order='C') + _check_call(_LIB.TreeliteTreeBuilderCreateValue( + val.ctypes.data_as(ctypes.POINTER(self.CTYPES_PTR[type])), + c_str(type), + ctypes.byref(self.handle) + )) + + def __repr__(self): + return '' + + def __del__(self): + if self.handle is not None: + _check_call(_LIB.TreeliteTreeBuilderDeleteValue(self.handle)) + + class Node: """Handle to a node in a tree""" def __init__(self): @@ -413,7 +452,7 @@ def set_root(self): raise TreeliteError('This node has never been inserted into a tree; ' + 'a node must be inserted before it can be a root') from e - def set_leaf_node(self, leaf_value): + def set_leaf_node(self, leaf_value, leaf_value_type='float32'): """ Set the node as a leaf node @@ -424,6 +463,8 @@ def set_leaf_node(self, leaf_value): :py:class:`float ` Usually a single leaf value (weight) of the leaf node. For multiclass random forest classifier, leaf_value should be a list of leaf weights. + leaf_value_type : str + Data type used for leaf_value (e.g. 'float32') """ if not self.empty: @@ -446,9 +487,11 @@ def set_leaf_node(self, leaf_value): try: if is_list: - leaf_value = [float(i) for i in leaf_value] + leaf_value = [ModelBuilder.Value(i, leaf_value_type) for i in leaf_value] + leaf_value_handle = c_array(ctypes.c_void_p, [x.handle for x in leaf_value]) else: - leaf_value = float(leaf_value) + leaf_value = ModelBuilder.Value(leaf_value, leaf_value_type) + leaf_value_handle = leaf_value.handle except TypeError as e: raise TreeliteError('leaf_value parameter should be either a ' + 'single float or a list of floats') from e @@ -458,13 +501,13 @@ def set_leaf_node(self, leaf_value): _check_call(_LIB.TreeliteTreeBuilderSetLeafVectorNode( self.tree.handle, ctypes.c_int(self.node_key), - c_array(ctypes.c_float, leaf_value), + leaf_value_handle, ctypes.c_size_t(len(leaf_value)))) else: _check_call(_LIB.TreeliteTreeBuilderSetLeafNode( self.tree.handle, ctypes.c_int(self.node_key), - ctypes.c_float(leaf_value))) + leaf_value_handle)) self.empty = False except AttributeError as e: raise TreeliteError('This node has never been inserted into a tree; ' + @@ -472,7 +515,8 @@ def set_leaf_node(self, leaf_value): # pylint: disable=R0913 def set_numerical_test_node(self, feature_id, opname, threshold, - default_left, left_child_key, right_child_key): + default_left, left_child_key, right_child_key, + threshold_type='float32'): """ Set the node as a test node with numerical split. The test is in the form ``[feature value] OP [threshold]``. Depending on the result of the test, @@ -493,6 +537,8 @@ def set_numerical_test_node(self, feature_id, opname, threshold, unique integer key to identify the left child node right_child_key : :py:class:`int ` unique integer key to identify the right child node + threshold_type : str + data type for threshold value (e.g. 
'float32') """ if not self.empty: try: @@ -505,6 +551,7 @@ def set_numerical_test_node(self, feature_id, opname, threshold, 'delete it first and then add an empty node with ' + \ 'the same key.') try: + threshold_obj = ModelBuilder.Value(threshold, threshold_type) # automatically create child nodes that don't exist yet if left_child_key not in self.tree: self.tree[left_child_key] = ModelBuilder.Node() @@ -513,8 +560,9 @@ def set_numerical_test_node(self, feature_id, opname, threshold, _check_call(_LIB.TreeliteTreeBuilderSetNumericalTestNode( self.tree.handle, ctypes.c_int(self.node_key), - ctypes.c_uint(feature_id), c_str(opname), - ctypes.c_float(threshold), + ctypes.c_uint(feature_id), + c_str(opname), + threshold_obj.handle, ctypes.c_int(1 if default_left else 0), ctypes.c_int(left_child_key), ctypes.c_int(right_child_key))) @@ -578,12 +626,16 @@ def set_categorical_test_node(self, feature_id, left_categories, raise TreeliteError('This node has never been inserted into a tree; ' + 'a node must be inserted before it can be a test node') from e - class Tree(): + class Tree: """Handle to a decision tree in a tree ensemble Builder""" - def __init__(self): + def __init__(self, threshold_type='float32', leaf_output_type='float32'): self.handle = ctypes.c_void_p() - _check_call(_LIB.TreeliteCreateTreeBuilder(ctypes.byref(self.handle))) + _check_call(_LIB.TreeliteCreateTreeBuilder( + c_str(threshold_type), + c_str(leaf_output_type), + ctypes.byref(self.handle) + )) self.nodes = {} def __del__(self): @@ -640,8 +692,8 @@ def __repr__(self): return '\n' \ .format(len(self.nodes)) - def __init__(self, num_feature, num_output_group=1, - random_forest=False, **kwargs): + def __init__(self, num_feature, num_output_group=1, random_forest=False, + threshold_type='float32', leaf_output_type='float32', **kwargs): if not isinstance(num_feature, int): raise ValueError('num_feature must be of int type') if num_feature <= 0: @@ -655,6 +707,8 @@ def __init__(self, num_feature, num_output_group=1, ctypes.c_int(num_feature), ctypes.c_int(num_output_group), ctypes.c_int(1 if random_forest else 0), + c_str(threshold_type), + c_str(leaf_output_type), ctypes.byref(self.handle))) self._set_param(kwargs) self.trees = [] diff --git a/python/treelite/sklearn/common.py b/python/treelite/sklearn/common.py index 9f182275..9c7f3102 100644 --- a/python/treelite/sklearn/common.py +++ b/python/treelite/sklearn/common.py @@ -9,7 +9,8 @@ class SKLConverterBase: @classmethod def process_tree(cls, sklearn_tree, sklearn_model): """Process a scikit-learn Tree object""" - treelite_tree = treelite.ModelBuilder.Tree() + treelite_tree = treelite.ModelBuilder.Tree( + threshold_type='float64', leaf_output_type='float64') # Iterate over each node: node ID ranges from 0 to [node_count]-1 for node_id in range(sklearn_tree.node_count): @@ -38,9 +39,10 @@ def process_test_node(cls, treelite_tree, sklearn_tree, node_id, sklearn_model): feature_id=sklearn_tree.feature[node_id], opname='<=', threshold=sklearn_tree.threshold[node_id], + threshold_type='float64', default_left=True, left_child_key=sklearn_tree.children_left[node_id], - right_child_key=sklearn_tree.children_right[node_id]) + right_child_key=sklearn_tree.children_right[node_id],) @classmethod def process_leaf_node(cls, treelite_tree, sklearn_tree, node_id, sklearn_model): diff --git a/python/treelite/sklearn/gbm_classifier.py b/python/treelite/sklearn/gbm_classifier.py index 1ca0aab2..8a0ca986 100644 --- a/python/treelite/sklearn/gbm_classifier.py +++ 
b/python/treelite/sklearn/gbm_classifier.py @@ -17,9 +17,9 @@ def process_model(cls, sklearn_model): # Initialize Treelite model builder # Set random_forest=False for gradient boosted trees # Set pred_transform='sigmoid' to obtain probability predictions - builder = treelite.ModelBuilder(num_feature=sklearn_model.n_features_, - random_forest=False, - pred_transform='sigmoid') + builder = treelite.ModelBuilder( + num_feature=sklearn_model.n_features_, random_forest=False, pred_transform='sigmoid', + threshold_type='float64', leaf_output_type='float64') for i in range(sklearn_model.n_estimators): # Process i-th tree and add to the builder builder.append(cls.process_tree(sklearn_model.estimators_[i][0].tree_, @@ -34,4 +34,4 @@ def process_leaf_node(cls, treelite_tree, sklearn_tree, node_id, sklearn_model): # Need to shrink each leaf output by the learning rate leaf_value *= sklearn_model.learning_rate # Initialize the leaf node with given node ID - treelite_tree[node_id].set_leaf_node(leaf_value) + treelite_tree[node_id].set_leaf_node(leaf_value, leaf_value_type='float64') diff --git a/python/treelite/sklearn/gbm_multi_classifier.py b/python/treelite/sklearn/gbm_multi_classifier.py index c2111504..49422f53 100644 --- a/python/treelite/sklearn/gbm_multi_classifier.py +++ b/python/treelite/sklearn/gbm_multi_classifier.py @@ -19,10 +19,10 @@ def process_model(cls, sklearn_model): # Set random_forest=False for gradient boosted trees # Set num_output_group for multi-class classification # Set pred_transform='softmax' to obtain probability predictions - builder = treelite.ModelBuilder(num_feature=sklearn_model.n_features_, - num_output_group=sklearn_model.n_classes_, - random_forest=False, - pred_transform='softmax') + builder = treelite.ModelBuilder( + num_feature=sklearn_model.n_features_, num_output_group=sklearn_model.n_classes_, + random_forest=False, pred_transform='softmax', + threshold_type='float64', leaf_output_type='float64') # Process [number of iterations] * [number of classes] trees for i in range(sklearn_model.n_estimators): for k in range(sklearn_model.n_classes_): @@ -38,4 +38,4 @@ def process_leaf_node(cls, treelite_tree, sklearn_tree, node_id, sklearn_model): # Need to shrink each leaf output by the learning rate leaf_value *= sklearn_model.learning_rate # Initialize the leaf node with given node ID - treelite_tree[node_id].set_leaf_node(leaf_value) + treelite_tree[node_id].set_leaf_node(leaf_value, leaf_value_type='float64') diff --git a/python/treelite/sklearn/gbm_regressor.py b/python/treelite/sklearn/gbm_regressor.py index 89ceddca..2b5b8006 100644 --- a/python/treelite/sklearn/gbm_regressor.py +++ b/python/treelite/sklearn/gbm_regressor.py @@ -15,8 +15,9 @@ def process_model(cls, sklearn_model): "the option init='zero'") # Initialize Treelite model builder # Set random_forest=False for gradient boosted trees - builder = treelite.ModelBuilder(num_feature=sklearn_model.n_features_, - random_forest=False) + builder = treelite.ModelBuilder( + num_feature=sklearn_model.n_features_, random_forest=False, + threshold_type='float64', leaf_output_type='float64') for i in range(sklearn_model.n_estimators): # Process i-th tree and add to the builder builder.append(cls.process_tree(sklearn_model.estimators_[i][0].tree_, @@ -31,4 +32,4 @@ def process_leaf_node(cls, treelite_tree, sklearn_tree, node_id, sklearn_model): # Need to shrink each leaf output by the learning rate leaf_value *= sklearn_model.learning_rate # Initialize the leaf node with given node ID - 
treelite_tree[node_id].set_leaf_node(leaf_value) + treelite_tree[node_id].set_leaf_node(leaf_value, leaf_value_type='float64') diff --git a/python/treelite/sklearn/rf_classifier.py b/python/treelite/sklearn/rf_classifier.py index 410a6c0a..fc8323a5 100644 --- a/python/treelite/sklearn/rf_classifier.py +++ b/python/treelite/sklearn/rf_classifier.py @@ -10,8 +10,9 @@ class SKLRFClassifierMixin: def process_model(cls, sklearn_model): """Process a RandomForestClassifier (binary classifier) to convert it into a Treelite model""" - builder = treelite.ModelBuilder(num_feature=sklearn_model.n_features_, - random_forest=True) + builder = treelite.ModelBuilder( + num_feature=sklearn_model.n_features_, random_forest=True, + threshold_type='float64', leaf_output_type='float64') for i in range(sklearn_model.n_estimators): # Process i-th tree and add to the builder builder.append(cls.process_tree(sklearn_model.estimators_[i].tree_, @@ -28,4 +29,4 @@ def process_leaf_node(cls, treelite_tree, sklearn_tree, node_id, sklearn_model): # Compute the fraction of positive data points at this leaf node fraction_positive = float(leaf_count[1]) / leaf_count.sum() # The fraction above is now the leaf output - treelite_tree[node_id].set_leaf_node(fraction_positive) + treelite_tree[node_id].set_leaf_node(fraction_positive, leaf_value_type='float64') diff --git a/python/treelite/sklearn/rf_multi_classifier.py b/python/treelite/sklearn/rf_multi_classifier.py index ef7cace3..b3155271 100644 --- a/python/treelite/sklearn/rf_multi_classifier.py +++ b/python/treelite/sklearn/rf_multi_classifier.py @@ -11,10 +11,10 @@ def process_model(cls, sklearn_model): """Process a RandomForestClassifier (multi-class classifier) to convert it into a Treelite model""" # Must specify num_output_group and pred_transform - builder = treelite.ModelBuilder(num_feature=sklearn_model.n_features_, - num_output_group=sklearn_model.n_classes_, - random_forest=True, - pred_transform='identity_multiclass') + builder = treelite.ModelBuilder( + num_feature=sklearn_model.n_features_, num_output_group=sklearn_model.n_classes_, + random_forest=True, pred_transform='identity_multiclass', + threshold_type='float64', leaf_output_type='float64') for i in range(sklearn_model.n_estimators): # Process i-th tree and add to the builder builder.append(cls.process_tree(sklearn_model.estimators_[i].tree_, @@ -31,4 +31,4 @@ def process_leaf_node(cls, treelite_tree, sklearn_tree, node_id, sklearn_model): # Compute the probability distribution over label classes prob_distribution = leaf_count / leaf_count.sum() # The leaf output is the probability distribution - treelite_tree[node_id].set_leaf_node(prob_distribution) + treelite_tree[node_id].set_leaf_node(prob_distribution, leaf_value_type='float64') diff --git a/python/treelite/sklearn/rf_regressor.py b/python/treelite/sklearn/rf_regressor.py index c9f3b2cd..87550f76 100644 --- a/python/treelite/sklearn/rf_regressor.py +++ b/python/treelite/sklearn/rf_regressor.py @@ -11,8 +11,9 @@ def process_model(cls, sklearn_model): """Process a RandomForestRegressor to convert it into a Treelite model""" # Initialize Treelite model builder # Set random_forest=True for random forests - builder = treelite.ModelBuilder(num_feature=sklearn_model.n_features_, - random_forest=True) + builder = treelite.ModelBuilder( + num_feature=sklearn_model.n_features_, random_forest=True, + threshold_type='float64', leaf_output_type='float64') # Iterate over individual trees for i in range(sklearn_model.n_estimators): @@ -30,4 +31,4 @@ def 
process_leaf_node(cls, treelite_tree, sklearn_tree, node_id, sklearn_model): # The `value` attribute stores the output for every leaf node. leaf_value = sklearn_tree.value[node_id].squeeze() # Initialize the leaf node with given node ID - treelite_tree[node_id].set_leaf_node(leaf_value) + treelite_tree[node_id].set_leaf_node(leaf_value, leaf_value_type='float64') diff --git a/runtime/python/treelite_runtime/__init__.py b/runtime/python/treelite_runtime/__init__.py index 75d68cd2..2e2ff027 100644 --- a/runtime/python/treelite_runtime/__init__.py +++ b/runtime/python/treelite_runtime/__init__.py @@ -1,11 +1,11 @@ # coding: utf-8 import os -from .predictor import Predictor, Batch +from .predictor import Predictor, DMatrix from .util import TreeliteRuntimeError VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION') with open(VERSION_FILE) as f: __version__ = f.read().strip() -__all__ = ['Predictor', 'Batch', 'TreeliteRuntimeError'] +__all__ = ['Predictor', 'DMatrix', 'TreeliteRuntimeError'] diff --git a/runtime/python/treelite_runtime/predictor.py b/runtime/python/treelite_runtime/predictor.py index ff24e6af..02da0143 100644 --- a/runtime/python/treelite_runtime/predictor.py +++ b/runtime/python/treelite_runtime/predictor.py @@ -9,7 +9,8 @@ import numpy as np import scipy.sparse from .util import c_str, py_str, _log_callback, TreeliteRuntimeError, lineno, log_info, \ - lib_extension_current_platform + lib_extension_current_platform, type_info_to_ctypes_type, type_info_to_numpy_type, \ + numpy_type_to_type_info from .libpath import TreeliteRuntimeLibraryNotFound, find_lib_path @@ -50,183 +51,7 @@ def _check_call(ret): raise TreeliteRuntimeError(py_str(_LIB.TreeliteGetLastError())) -class PredictorEntry(ctypes.Union): - _fields_ = [('missing', ctypes.c_int), ('fvalue', ctypes.c_float)] - - -class Batch(object): - """Batch of rows to be used for prediction""" - - def __init__(self): - self.handle = None - self.kind = None - - def __del__(self): - if self.handle is not None: - if self.kind == 'sparse': - _check_call(_LIB.TreeliteDeleteSparseBatch(self.handle)) - elif self.kind == 'dense': - _check_call(_LIB.TreeliteDeleteDenseBatch(self.handle)) - else: - raise TreeliteRuntimeError('this batch has wrong value for `kind` field') - self.handle = None - self.kind = None - - def shape(self): - """ - Get dimensions of the batch - - Returns - ------- - dims : :py:class:`tuple ` of length 2 - (number of rows, number of columns) - """ - num_row = ctypes.c_size_t() - num_col = ctypes.c_size_t() - _check_call(_LIB.TreeliteBatchGetDimension( - self.handle, - ctypes.c_int(1 if self.kind == 'sparse' else 0), - ctypes.byref(num_row), - ctypes.byref(num_col))) - return (num_row.value, num_col.value) - - @classmethod - def from_npy2d(cls, mat, rbegin=0, rend=None, missing=None): - """ - Get a dense batch from a 2D numpy matrix. - If ``mat`` does not have ``order='C'`` (also known as row-major) or is not - contiguous, a temporary copy will be made. - If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will be - made also. - Thus, as many as two temporary copies of data can be made. One should set - input layout and type judiciously to conserve memory. - - Parameters - ---------- - mat : object of type :py:class:`numpy.ndarray`, with dimension 2 - data matrix - rbegin : :py:class:`int `, optional - the index of the first row in the subset - rend : :py:class:`int `, optional - one past the index of the last row in the subset. If missing, set to - the end of the matrix. 
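Taken together, the converters above now emit float64 models end to end. A short sketch of how they are typically driven follows; treelite.sklearn.import_model is assumed here to be the public entry point that dispatches to these mixins.

import treelite.sklearn
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

# Fit a small scikit-learn ensemble, then convert it to a Treelite model
X, y = make_regression(n_samples=200, n_features=10, random_state=0)
clf = RandomForestRegressor(n_estimators=10, random_state=0).fit(X, y)
model = treelite.sklearn.import_model(clf)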
- missing : :py:class:`float `, optional - value indicating missing value. If missing, set to ``numpy.nan``. - - Returns - ------- - dense_batch : :py:class:`Batch` - a dense batch consisting of rows ``[rbegin, rend)`` - """ - if not isinstance(mat, np.ndarray): - raise ValueError('mat must be of type numpy.ndarray') - if len(mat.shape) != 2: - raise ValueError('Input numpy.ndarray must be two-dimensional') - num_row = mat.shape[0] - num_col = mat.shape[1] - rbegin = rbegin if rbegin is not None else 0 - rend = rend if rend is not None else num_row - if rbegin >= rend: - raise TreeliteRuntimeError('rbegin must be less than rend') - if rbegin < 0: - raise TreeliteRuntimeError('rbegin must be nonnegative') - if rend > num_row: - raise TreeliteRuntimeError('rend must be less than number of rows in mat') - # flatten the array by rows and ensure it is float32. - # we try to avoid data copies if possible - # (reshape returns a view when possible and we explicitly tell np.array to - # avoid copying) - data_subset = np.array(mat[rbegin:rend, :].reshape((rend - rbegin) * num_col), - copy=False, dtype=np.float32) - missing = missing if missing is not None else np.nan - - batch = Batch() - batch.handle = ctypes.c_void_p() - batch.kind = 'dense' - _check_call(_LIB.TreeliteAssembleDenseBatch( - data_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - ctypes.c_float(missing), - ctypes.c_size_t(rend - rbegin), - ctypes.c_size_t(num_col), - ctypes.byref(batch.handle))) - # save handles for internal arrays - batch.data = data_subset - # save pointer to mat so that it doesn't get garbage-collected prematurely - batch.mat = mat - return batch - - @classmethod - def from_csr(cls, csr, rbegin=None, rend=None): - """ - Get a sparse batch from a subset of rows in a CSR (Compressed Sparse Row) - matrix. The subset is given by the range ``[rbegin, rend)``. - - Parameters - ---------- - csr : object of class :py:class:`treelite.DMatrix` or \ - :py:class:`scipy.sparse.csr_matrix` - data matrix - rbegin : :py:class:`int `, optional - the index of the first row in the subset - rend : :py:class:`int `, optional - one past the index of the last row in the subset. If missing, set to - the end of the matrix. 
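With the Batch helpers going away, the same use cases are covered by constructing a DMatrix directly and handing it to Predictor.predict(); a minimal sketch, where the shared-library path is a placeholder for a model compiled elsewhere:

import numpy as np
import treelite_runtime

X = np.random.default_rng(0).standard_normal((100, 10)).astype(np.float32)
dmat = treelite_runtime.DMatrix(X)                       # replaces Batch.from_npy2d(X)
predictor = treelite_runtime.Predictor('./mymodel.so')   # placeholder path to a compiled model
pred = predictor.predict(dmat)                           # dtype follows the model's leaf output type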
- - Returns - ------- - sparse_batch : :py:class:`Batch` - a sparse batch consisting of rows ``[rbegin, rend)`` - """ - # use duck typing so as to accomodate both scipy.sparse.csr_matrix - # and DMatrix without explictly importing any of them - try: - num_row = csr.shape[0] - num_col = csr.shape[1] - except AttributeError: - raise ValueError('csr must contain shape attribute') - except TypeError: - raise ValueError('csr.shape must be of tuple type') - except IndexError: - raise ValueError('csr.shape must be of length 2 (indicating 2D matrix)') - rbegin = rbegin if rbegin is not None else 0 - rend = rend if rend is not None else num_row - if rbegin >= rend: - raise TreeliteRuntimeError('rbegin must be less than rend') - if rbegin < 0: - raise TreeliteRuntimeError('rbegin must be nonnegative') - if rend > num_row: - raise TreeliteRuntimeError('rend must be less than number of rows in csr') - - # compute submatrix with rows [rbegin, rend) - ibegin = csr.indptr[rbegin] - iend = csr.indptr[rend] - data_subset = np.array(csr.data[ibegin:iend], copy=False, - dtype=np.float32, order='C') - indices_subset = np.array(csr.indices[ibegin:iend], copy=False, - dtype=np.uint32, order='C') - indptr_subset = np.array(csr.indptr[rbegin:(rend + 1)] - ibegin, copy=False, - dtype=np.uintp, order='C') - - batch = Batch() - batch.handle = ctypes.c_void_p() - batch.kind = 'sparse' - _check_call(_LIB.TreeliteAssembleSparseBatch( - data_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), - indices_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32)), - indptr_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_size_t)), - ctypes.c_size_t(rend - rbegin), - ctypes.c_size_t(num_col), - ctypes.byref(batch.handle))) - # save handles for internal arrays - batch.data = data_subset - batch.indices = indices_subset - batch.indptr = indptr_subset - # save pointer to csr so that it doesn't get garbage-collected prematurely - batch.csr = csr - return batch - - -class Predictor(object): +class Predictor: """ Predictor class: loader for compiled shared libraries @@ -291,7 +116,7 @@ def __init__(self, libpath, nthread=None, verbose=False): _check_call(_LIB.TreelitePredictorQueryPredTransform( self.handle, ctypes.byref(pred_transform))) - self.pred_transform_ = bytes.decode(pred_transform.value) + self.pred_transform_ = py_str(pred_transform.value) # save # of sigmoid alpha sigmoid_alpha = ctypes.c_float() _check_call(_LIB.TreelitePredictorQuerySigmoidAlpha( @@ -304,51 +129,62 @@ def __init__(self, libpath, nthread=None, verbose=False): self.handle, ctypes.byref(global_bias))) self.global_bias_ = global_bias.value + threshold_type = ctypes.c_char_p() + _check_call(_LIB.TreelitePredictorQueryThresholdType( + self.handle, + ctypes.byref(threshold_type))) + self.threshold_type_ = py_str(threshold_type.value) + leaf_output_type = ctypes.c_char_p() + _check_call(_LIB.TreelitePredictorQueryLeafOutputType( + self.handle, + ctypes.byref(leaf_output_type))) + self.leaf_output_type_ = py_str(leaf_output_type.value) if verbose: log_info(__file__, lineno(), f'Dynamic shared library {path} has been successfully loaded into memory') - def predict(self, batch, verbose=False, pred_margin=False): + def predict(self, dmat, verbose=False, pred_margin=False): """ Perform batch prediction with a 2D sparse data matrix. Worker threads will internally divide up work for batch prediction. 
**Note that this function - may be called by only one thread at a time.** In order to use multiple - threads to process multiple prediction requests simultaneously, use - :py:meth:`predict_instance` instead. + may be called by only one thread at a time.** Parameters ---------- - batch: object of type :py:class:`Batch` + dmat: object of type :py:class:`DMatrix` batch of rows for which predictions will be made verbose : :py:class:`bool `, optional Whether to print extra messages during prediction pred_margin: :py:class:`bool `, optional whether to produce raw margins rather than transformed probabilities """ - if not isinstance(batch, Batch): - raise TreeliteRuntimeError('batch must be of type Batch') - if batch.handle is None or batch.kind is None: - raise TreeliteRuntimeError('batch cannot be empty') + if not isinstance(dmat, DMatrix): + raise TreeliteRuntimeError('dmat must be of type DMatrix') result_size = ctypes.c_size_t() _check_call(_LIB.TreelitePredictorQueryResultSize( self.handle, - batch.handle, - ctypes.c_int(1 if batch.kind == 'sparse' else 0), + dmat.handle, ctypes.byref(result_size))) - out_result = np.zeros(result_size.value, dtype=np.float32, order='C') + result_type = ctypes.c_char_p() + _check_call(_LIB.TreelitePredictorQueryLeafOutputType( + self.handle, + ctypes.byref(result_type))) + result_type = py_str(result_type.value) + out_result = np.zeros(result_size.value, + dtype=type_info_to_numpy_type(result_type), + order='C') out_result_size = ctypes.c_size_t() _check_call(_LIB.TreelitePredictorPredictBatch( self.handle, - batch.handle, - ctypes.c_int(1 if batch.kind == 'sparse' else 0), + dmat.handle, ctypes.c_int(1 if verbose else 0), ctypes.c_int(1 if pred_margin else 0), - out_result.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + out_result.ctypes.data_as(ctypes.POINTER(type_info_to_ctypes_type(result_type))), ctypes.byref(out_result_size))) idx = int(out_result_size.value) - res = out_result[0:idx].reshape((batch.shape()[0], -1)).squeeze() - if self.num_output_group_ > 1 and batch.shape()[0] != idx: + res = out_result[0:idx].reshape((dmat.shape[0], -1)).squeeze() + if self.num_output_group_ > 1 and dmat.shape[0] != idx: res = res.reshape((-1, self.num_output_group_)) return res @@ -381,3 +217,163 @@ def global_bias(self): def sigmoid_alpha(self): """Query sigmoid alpha of the model""" return self.sigmoid_alpha_ + +class DMatrix: + """Data matrix used in Treelite. + + Parameters + ---------- + data : :py:class:`str ` / :py:class:`numpy.ndarray` /\ + :py:class:`scipy.sparse.csr_matrix` / :py:class:`pandas.DataFrame` + Data source. When data is :py:class:`str ` type, it indicates + that data should be read from a file. + data_format : :py:class:`str `, optional + Format of input data file. Applicable only when data is read from a + file. If missing, the svmlight (.libsvm) format is assumed. + dtype : :py:class:`str `, optional + If specified, the data will be casted into the corresponding data type. + missing : :py:class:`float `, optional + Value in the data that represents a missing entry. If set to ``None``, + ``numpy.nan`` will be used. 
+ verbose : :py:class:`bool `, optional + Whether to print extra messages during construction + feature_names : :py:class:`list `, optional + Human-readable names for features + feature_types : :py:class:`list `, optional + Types for features + nthread : :py:class:`int `, optional + Number of threads + """ + + # pylint: disable=R0902,R0903,R0913 + + def __init__(self, data, data_format=None, dtype=None, missing=None, + feature_names=None, feature_types=None, + verbose=False, nthread=None): + if data is None: + raise TreeliteRuntimeError('\'data\' argument cannot be None') + + self.handle = ctypes.c_void_p() + + if isinstance(data, (str,)): + nthread = nthread if nthread is not None else 0 + data_format = data_format if data_format is not None else "libsvm" + data_type = ctypes.c_char_p(None) if dtype is None else c_str(dtype) + _check_call(_LIB.TreeliteDMatrixCreateFromFile( + c_str(data), + c_str(data_format), + data_type, + ctypes.c_int(nthread), + ctypes.c_int(1 if verbose else 0), + ctypes.byref(self.handle))) + elif isinstance(data, scipy.sparse.csr_matrix): + self._init_from_csr(data, dtype=dtype) + elif isinstance(data, scipy.sparse.csc_matrix): + self._init_from_csr(data.tocsr(), dtype=dtype) + elif isinstance(data, np.ndarray): + self._init_from_npy2d(data, missing, dtype=dtype) + else: # any type that's convertible to CSR matrix is O.K. + try: + csr = scipy.sparse.csr_matrix(data) + self._init_from_csr(csr, dtype=dtype) + except Exception as e: + raise TypeError(f'Cannot initialize DMatrix from {type(data).__name__}') from e + self.feature_names = feature_names + self.feature_types = feature_types + num_row, num_col, nelem = self._get_dims() + self.shape = (num_row, num_col) + self.size = nelem + + def _init_from_csr(self, csr, dtype=None): + """Initialize data from a CSR (Compressed Sparse Row) matrix""" + if len(csr.indices) != len(csr.data): + raise ValueError('indices and data not of same length: {} vs {}' + .format(len(csr.indices), len(csr.data))) + if len(csr.indptr) != csr.shape[0] + 1: + raise ValueError('len(indptr) must be equal to 1 + [number of rows]' \ + + 'len(indptr) = {} vs 1 + [number of rows] = {}' + .format(len(csr.indptr), 1 + csr.shape[0])) + if csr.indptr[-1] != len(csr.data): + raise ValueError('last entry of indptr must be equal to len(data)' \ + + 'indptr[-1] = {} vs len(data) = {}' + .format(csr.indptr[-1], len(csr.data))) + + if dtype is None: + data_type = csr.data.dtype + else: + data_type = type_info_to_numpy_type(dtype) + data_type_code = numpy_type_to_type_info(data_type) + data_ptr_type = ctypes.POINTER(type_info_to_ctypes_type(data_type_code)) + if data_type_code not in ['float32', 'float64']: + raise ValueError('data should be either float32 or float64 type') + + data = np.array(csr.data, copy=False, dtype=data_type, order='C') + indices = np.array(csr.indices, copy=False, dtype=np.uintc, order='C') + indptr = np.array(csr.indptr, copy=False, dtype=np.uintp, order='C') + _check_call(_LIB.TreeliteDMatrixCreateFromCSR( + data.ctypes.data_as(data_ptr_type), + c_str(data_type_code), + indices.ctypes.data_as(ctypes.POINTER(ctypes.c_uint)), + indptr.ctypes.data_as(ctypes.POINTER(ctypes.c_size_t)), + ctypes.c_size_t(csr.shape[0]), + ctypes.c_size_t(csr.shape[1]), + ctypes.byref(self.handle))) + + def _init_from_npy2d(self, mat, missing, dtype=None): + """ + Initialize data from a 2-D numpy matrix. + If ``mat`` does not have ``order='C'`` (also known as row-major) or is not + contiguous, a temporary copy will be made. 
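If the temporary copies described here are a concern, the input can be converted once up front before constructing the DMatrix; a small sketch:

import numpy as np
import treelite_runtime

rng = np.random.default_rng(0)
X = np.asfortranarray(rng.standard_normal((1000, 50)))   # column-major input would otherwise force a copy
X = np.ascontiguousarray(X, dtype=np.float32)            # fix layout and dtype in one explicit step
dmat = treelite_runtime.DMatrix(X)                       # no further temporary NumPy copies expected here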
+ If ``mat`` does not have ``dtype=numpy.float32``, a temporary copy will be + made also. + Thus, as many as two temporary copies of data can be made. One should set + input layout and type judiciously to conserve memory. + """ + if len(mat.shape) != 2: + raise ValueError('Input numpy.ndarray must be two-dimensional') + if dtype is None: + data_type = mat.dtype + else: + data_type = type_info_to_numpy_type(dtype) + data_type_code = numpy_type_to_type_info(data_type) + data_ptr_type = ctypes.POINTER(type_info_to_ctypes_type(data_type_code)) + if data_type_code not in ['float32', 'float64']: + raise ValueError('data should be either float32 or float64 type') + # flatten the array by rows and ensure it is float32. + # we try to avoid data copies if possible + # (reshape returns a view when possible and we explicitly tell np.array to + # avoid copying) + data = np.array(mat.reshape(mat.size), copy=False, dtype=data_type) + missing = missing if missing is not None else np.nan + missing = np.array([missing], dtype=data_type, order='C') + _check_call(_LIB.TreeliteDMatrixCreateFromMat( + data.ctypes.data_as(data_ptr_type), + c_str(data_type_code), + ctypes.c_size_t(mat.shape[0]), + ctypes.c_size_t(mat.shape[1]), + missing.ctypes.data_as(data_ptr_type), + ctypes.byref(self.handle))) + + def _get_dims(self): + num_row = ctypes.c_size_t() + num_col = ctypes.c_size_t() + nelem = ctypes.c_size_t() + _check_call(_LIB.TreeliteDMatrixGetDimension(self.handle, + ctypes.byref(num_row), + ctypes.byref(num_col), + ctypes.byref(nelem))) + return (num_row.value, num_col.value, nelem.value) + + def __del__(self): + if self.handle: + _check_call(_LIB.TreeliteDMatrixFree(self.handle)) + self.handle = None + + def __repr__(self): + return '<{}x{} sparse matrix of type treelite.DMatrix\n' \ + .format(self.shape[0], self.shape[1]) \ + + ' with {} stored elements in Compressed Sparse Row format>' \ + .format(self.size) + + +__all__ = ['Predictor', 'DMatrix'] diff --git a/runtime/python/treelite_runtime/util.py b/runtime/python/treelite_runtime/util.py index 8299d58f..d19631bb 100644 --- a/runtime/python/treelite_runtime/util.py +++ b/runtime/python/treelite_runtime/util.py @@ -6,6 +6,40 @@ import ctypes import time from sys import platform as _platform +import numpy as np + +_CTYPES_TYPE_TABLE = { + 'uint32': ctypes.c_uint32, + 'float32': ctypes.c_float, + 'float64': ctypes.c_double +} + +_NUMPY_TYPE_TABLE = { + 'uint32': np.uint32, + 'float32': np.float32, + 'float64': np.float64 +} + +_NUMPY_TYPE_TABLE_INV = { + np.uint32: 'unit32', + np.float32: 'float32', + np.float64: 'float64' +} + + +def type_info_to_ctypes_type(type_info): + """Obtain ctypes type corresponding to a given TypeInfo""" + return _CTYPES_TYPE_TABLE[type_info] + + +def type_info_to_numpy_type(type_info): + """Obtain ctypes type corresponding to a given TypeInfo""" + return _NUMPY_TYPE_TABLE[type_info] + + +def numpy_type_to_type_info(type_info): + """Obtain TypeInfo corresponding to a given NumPy type""" + return _NUMPY_TYPE_TABLE_INV[type_info] class TreeliteRuntimeError(Exception): diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 51d9d7ec..dd260151 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -109,7 +109,6 @@ target_sources(objtreelite_runtime predictor/thread_pool/thread_pool.h predictor/predictor.cc ${PROJECT_SOURCE_DIR}/include/treelite/c_api_runtime.h - ${PROJECT_SOURCE_DIR}/include/treelite/entry.h ${PROJECT_SOURCE_DIR}/include/treelite/predictor.h ) diff --git a/src/annotator.cc b/src/annotator.cc index 
726d96c9..49f30bb9 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -6,6 +6,7 @@ */ #include +#include #include #include #include @@ -58,10 +59,46 @@ void Traverse(const treelite::Tree& tree, } template -inline void ComputeBranchLoop(const treelite::ModelImpl& model, - const treelite::CSRDMatrixImpl* dmat, size_t rbegin, - size_t rend,int nthread, const size_t* count_row_ptr, - size_t* counts_tloc, Entry* inst) { +inline void ComputeBranchLoopImpl( + const treelite::ModelImpl& model, + const treelite::DenseDMatrixImpl* dmat, size_t rbegin, size_t rend, int nthread, + const size_t* count_row_ptr, size_t* counts_tloc, Entry* inst) { + const size_t ntree = model.trees.size(); + CHECK_LE(rbegin, rend); + CHECK_LT(static_cast(rend), std::numeric_limits::max()); + const size_t num_col = dmat->num_col; + const ThresholdType missing_value = dmat->missing_value; + const bool nan_missing = treelite::math::CheckNAN(missing_value); + const auto rbegin_i = static_cast(rbegin); + const auto rend_i = static_cast(rend); + #pragma omp parallel for schedule(static) num_threads(nthread) + for (int64_t rid = rbegin_i; rid < rend_i; ++rid) { + const int tid = omp_get_thread_num(); + const ThresholdType* row = &dmat->data[rid * num_col]; + const size_t off = dmat->num_col * tid; + const size_t off2 = count_row_ptr[ntree] * tid; + for (size_t j = 0; j < num_col; ++j) { + if (treelite::math::CheckNAN(row[j])) { + CHECK(nan_missing) + << "The missing_value argument must be set to NaN if there is any NaN in the matrix."; + } else if (nan_missing || row[j] != missing_value) { + inst[off + j].fvalue = row[j]; + } + } + for (size_t tree_id = 0; tree_id < ntree; ++tree_id) { + Traverse(model.trees[tree_id], &inst[off], &counts_tloc[off2 + count_row_ptr[tree_id]]); + } + for (size_t j = 0; j < num_col; ++j) { + inst[off + j].missing = -1; + } + } +} + +template +inline void ComputeBranchLoopImpl( + const treelite::ModelImpl& model, + const treelite::CSRDMatrixImpl* dmat, size_t rbegin, size_t rend, int nthread, + const size_t* count_row_ptr, size_t* counts_tloc, Entry* inst) { const size_t ntree = model.trees.size(); CHECK_LE(rbegin, rend); CHECK_LT(static_cast(rend), std::numeric_limits::max()); @@ -86,6 +123,35 @@ inline void ComputeBranchLoop(const treelite::ModelImpl +inline void ComputeBranchLoop(const treelite::ModelImpl& model, + const treelite::DMatrix* dmat, size_t rbegin, + size_t rend, int nthread, const size_t* count_row_ptr, + size_t* counts_tloc, Entry* inst) { + CHECK(dmat->GetElementType() == treelite::InferTypeInfoOf()) + << "DMatrix has a wrong type. 
DMatrix has " + << treelite::TypeInfoToString(dmat->GetElementType()) << " whereas the model expects " + << treelite::TypeInfoToString(treelite::InferTypeInfoOf()); + switch (dmat->GetType()) { + case treelite::DMatrixType::kDense: { + const auto* dmat_ = dynamic_cast*>(dmat); + CHECK(dmat_) << "Dangling data matrix reference detected"; + ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc, inst); + break; + } + case treelite::DMatrixType::kSparseCSR: { + const auto* dmat_ = dynamic_cast*>(dmat); + CHECK(dmat_) << "Dangling data matrix reference detected"; + ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc, inst); + break; + } + default: + LOG(FATAL) + << "Annotator does not support DMatrix of type " << static_cast(dmat->GetType()); + break; + } +} + } // anonymous namespace namespace treelite { @@ -94,11 +160,8 @@ template inline void AnnotateImpl( const treelite::ModelImpl& model, - const treelite::CSRDMatrix* dmat, int nthread, int verbose, + const treelite::DMatrix* dmat, int nthread, int verbose, std::vector>* out_counts) { - auto* dmat_ = dynamic_cast*>(dmat); - CHECK(dmat_) << "BranchAnnotator: Dangling reference to CSRDMatrix detected"; - std::vector new_counts; std::vector counts_tloc; std::vector count_row_ptr; @@ -113,15 +176,17 @@ AnnotateImpl( new_counts.resize(count_row_ptr[ntree], 0); counts_tloc.resize(count_row_ptr[ntree] * nthread, 0); - std::vector> inst(nthread * dmat_->num_col, {-1}); - const size_t pstep = (dmat_->num_row + 19) / 20; + const size_t num_row = dmat->GetNumRow(); + const size_t num_col = dmat->GetNumCol(); + std::vector> inst(nthread * num_col, {-1}); + const size_t pstep = (num_row + 19) / 20; // interval to display progress - for (size_t rbegin = 0; rbegin < dmat_->num_row; rbegin += pstep) { - const size_t rend = std::min(rbegin + pstep, dmat_->num_row); - ComputeBranchLoop(model, dmat_, rbegin, rend, nthread, + for (size_t rbegin = 0; rbegin < num_row; rbegin += pstep) { + const size_t rend = std::min(rbegin + pstep, num_row); + ComputeBranchLoop(model, dmat, rbegin, rend, nthread, &count_row_ptr[0], &counts_tloc[0], &inst[0]); if (verbose > 0) { - LOG(INFO) << rend << " of " << dmat_->num_row << " rows processed"; + LOG(INFO) << rend << " of " << num_row << " rows processed"; } } @@ -141,11 +206,13 @@ AnnotateImpl( } void -BranchAnnotator::Annotate(const Model& model, const CSRDMatrix* dmat, int nthread, int verbose) { +BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) { TypeInfo threshold_type = model.GetThresholdType(); model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) { CHECK(dmat->GetElementType() == threshold_type) - << "BranchAnnotator: the matrix type must match the threshold type of the model"; + << "BranchAnnotator: the matrix type must match the threshold type of the model." 
+ << "(current matrix type = " << TypeInfoToString(dmat->GetElementType()) + << " vs threshold type = " << TypeInfoToString(threshold_type) << ")"; AnnotateImpl(handle, dmat, nthread, verbose, &this->counts); }); } diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3f6033bd..1b2cc36b 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -41,9 +41,7 @@ int TreeliteAnnotateBranch( const Model* model_ = static_cast(model); const auto* dmat_ = static_cast(dmat); CHECK(dmat_) << "Found a dangling reference to DMatrix"; - const auto* csr_dmat_ = dynamic_cast(dmat_); - CHECK(csr_dmat_) << "Annotator supports a sparse DMatrix for now"; - annotator->Annotate(*model_, csr_dmat_, nthread, verbose); + annotator->Annotate(*model_, dmat_, nthread, verbose); *out = static_cast(annotator.release()); API_END(); } diff --git a/src/c_api/c_api_common.cc b/src/c_api/c_api_common.cc index 1a8fa925..77c1110f 100644 --- a/src/c_api/c_api_common.cc +++ b/src/c_api/c_api_common.cc @@ -31,9 +31,10 @@ int TreeliteRegisterLogCallback(void (*callback)(const char*)) { } int TreeliteDMatrixCreateFromFile( - const char* path, const char* format, int nthread, int verbose, DMatrixHandle* out) { + const char* path, const char* format, const char* data_type, int nthread, int verbose, + DMatrixHandle* out) { API_BEGIN(); - std::unique_ptr mat = CSRDMatrix::Create(path, format, nthread, verbose); + std::unique_ptr mat = CSRDMatrix::Create(path, format, data_type, nthread, verbose); *out = static_cast(mat.release()); API_END(); } diff --git a/src/c_api/c_api_runtime.cc b/src/c_api/c_api_runtime.cc index f86cc0f2..e5d5482a 100644 --- a/src/c_api/c_api_runtime.cc +++ b/src/c_api/c_api_runtime.cc @@ -37,12 +37,11 @@ int TreelitePredictorLoad(const char* library_path, int num_worker_thread, Predi } int TreelitePredictorPredictBatch( - PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, - PredictorOutputHandle output_buffer, size_t* out_result_size) { + PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, void* out_result, + size_t* out_result_size) { API_BEGIN(); const auto* predictor = static_cast(handle); const auto* dmat = static_cast(batch); - auto* out_result = static_cast(output_buffer); const size_t num_feature = predictor->QueryNumFeature(); const std::string err_msg = std::string("Too many columns (features) in the given batch. 
" @@ -52,23 +51,6 @@ int TreelitePredictorPredictBatch( API_END(); } -int TreelitePredictorAllocateOutputBuffer( - PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle* out_output_buffer) { - API_BEGIN(); - const auto* predictor = static_cast(handle); - const auto* dmat = static_cast(batch); - std::unique_ptr output_buffer - = predictor->AllocateOutputBuffer(dmat); - *out_output_buffer = static_cast(output_buffer.release()); - API_END(); -} - -int TreelitePredictorDeleteOutputBuffer(PredictorOutputHandle handle) { - API_BEGIN(); - delete static_cast(handle); - API_END(); -} - int TreelitePredictorQueryResultSize(PredictorHandle handle, DMatrixHandle batch, size_t* out) { API_BEGIN(); const auto* predictor = static_cast(handle); @@ -115,6 +97,24 @@ int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float* out) { API_END(); } +int TreelitePredictorQueryThresholdType(PredictorHandle handle, const char** out) { + API_BEGIN() + const auto* predictor = static_cast(handle); + std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str; + ret_str = TypeInfoToString(predictor->QueryThresholdType()); + *out = ret_str.c_str(); + API_END(); +} + +int TreelitePredictorQueryLeafOutputType(PredictorHandle handle, const char** out) { + API_BEGIN() + const auto* predictor = static_cast(handle); + std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str; + ret_str = TypeInfoToString(predictor->QueryLeafOutputType()); + *out = ret_str.c_str(); + API_END(); +} + int TreelitePredictorFree(PredictorHandle handle) { API_BEGIN(); delete static_cast(handle); diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index d5a704ad..e0adb997 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -192,24 +192,12 @@ class ASTNativeCompiler : public Compiler { = native::TypeInfoToCTypeString(InferTypeInfoOf()); const std::string leaf_output_type = native::TypeInfoToCTypeString(InferTypeInfoOf()); - const char* get_num_output_group_function_signature - = "size_t get_num_output_group(void)"; - const char* get_num_feature_function_signature - = "size_t get_num_feature(void)"; - const char* get_pred_transform_function_signature - = "const char* get_pred_transform(void)"; - const char* get_sigmoid_alpha_function_signature - = "float get_sigmoid_alpha(void)"; - const char* get_global_bias_function_signature - = "float get_global_bias(void)"; const std::string predict_function_signature = (num_output_group_ > 1) ? 
fmt::format("size_t predict_multiclass(union Entry* data, int pred_margin, {}* result)", leaf_output_type) : fmt::format("{} predict(union Entry* data, int pred_margin)", leaf_output_type); - const char* get_threshold_type_signature = "const char* get_threshold_type(void)"; - const char* get_leaf_output_type_signature = "const char* get_leaf_output_type(void)"; if (!array_is_categorical_.empty()) { array_is_categorical_ @@ -217,47 +205,31 @@ class ASTNativeCompiler : public Compiler { array_is_categorical_); } + const std::string query_functions_definition + = fmt::format(native::query_functions_definition_template, + "num_output_group"_a = num_output_group_, + "num_feature"_a = num_feature_, + "pred_transform"_a = pred_transform_, + "sigmoid_alpha"_a = sigmoid_alpha_, + "global_bias"_a = global_bias_, + "threshold_type_str"_a = TypeInfoToString(InferTypeInfoOf()), + "leaf_output_type_str"_a = TypeInfoToString(InferTypeInfoOf())); + AppendToBuffer(dest, fmt::format(native::main_start_template, "array_is_categorical"_a = array_is_categorical_, - "get_num_output_group_function_signature"_a - = get_num_output_group_function_signature, - "get_num_feature_function_signature"_a - = get_num_feature_function_signature, - "get_pred_transform_function_signature"_a - = get_pred_transform_function_signature, - "get_sigmoid_alpha_function_signature"_a - = get_sigmoid_alpha_function_signature, - "get_global_bias_function_signature"_a - = get_global_bias_function_signature, + "query_functions_definition"_a = query_functions_definition, "pred_transform_function"_a = pred_tranform_func_, - "predict_function_signature"_a = predict_function_signature, - "get_threshold_type_signature"_a = get_threshold_type_signature, - "threshold_type_str"_a = TypeInfoToString(InferTypeInfoOf()), - "get_leaf_output_type_signature"_a = get_leaf_output_type_signature, - "leaf_output_type_str"_a = TypeInfoToString(InferTypeInfoOf()), - "num_output_group"_a = num_output_group_, - "num_feature"_a = num_feature_, - "pred_transform"_a = pred_transform_, - "sigmoid_alpha"_a = sigmoid_alpha_, - "global_bias"_a = global_bias_), + "predict_function_signature"_a = predict_function_signature), indent); + const std::string query_functions_prototype + = fmt::format(native::query_functions_prototype_template, + "dllexport"_a = DLLEXPORT_KEYWORD); AppendToBuffer("header.h", fmt::format(native::header_template, "dllexport"_a = DLLEXPORT_KEYWORD, - "get_num_output_group_function_signature"_a - = get_num_output_group_function_signature, - "get_num_feature_function_signature"_a - = get_num_feature_function_signature, - "get_pred_transform_function_signature"_a - = get_pred_transform_function_signature, - "get_sigmoid_alpha_function_signature"_a - = get_sigmoid_alpha_function_signature, - "get_global_bias_function_signature"_a - = get_global_bias_function_signature, "predict_function_signature"_a = predict_function_signature, - "get_threshold_type_signature"_a = get_threshold_type_signature, - "get_leaf_output_type_signature"_a = get_leaf_output_type_signature, + "query_functions_prototype"_a = query_functions_prototype, "threshold_type"_a = threshold_type, "threshold_type_Node"_a = (param.quantize > 0 ? 
std::string("int") : threshold_type)), indent); diff --git a/src/compiler/failsafe.cc b/src/compiler/failsafe.cc index 8d301885..67b428f1 100644 --- a/src/compiler/failsafe.cc +++ b/src/compiler/failsafe.cc @@ -18,6 +18,8 @@ #include "./pred_transform.h" #include "./common/format_util.h" #include "./elf/elf_formatter.h" +#include "./native/main_template.h" +#include "./native/header_template.h" #if defined(_MSC_VER) || defined(_WIN32) #define DLLEXPORT_KEYWORD "__declspec(dllexport) " @@ -36,7 +38,7 @@ struct NodeStructValue { int cright; }; -const char* header_template = R"TREELITETEMPLATE( +const char* const header_template = R"TREELITETEMPLATE( #include #include #include @@ -62,23 +64,16 @@ struct Node {{ extern const struct Node nodes[]; extern const int nodes_row_ptr[]; -{dllexport}size_t get_num_output_group(void); -{dllexport}size_t get_num_feature(void); +{query_functions_prototype} {dllexport}{predict_function_signature}; )TREELITETEMPLATE"; -const char* main_template = R"TREELITETEMPLATE( +const char* const main_template = R"TREELITETEMPLATE( #include "header.h" {nodes_row_ptr} -size_t get_num_output_group(void) {{ - return {num_output_group}; -}} - -size_t get_num_feature(void) {{ - return {num_feature}; -}} +{query_functions_definition} {pred_transform_function} @@ -104,7 +99,7 @@ size_t get_num_feature(void) {{ }} )TREELITETEMPLATE"; -const char* return_multiclass_template = +const char* const return_multiclass_template = R"TREELITETEMPLATE( for (int i = 0; i < {num_output_group}; ++i) {{ result[i] = sum[i] + (float)({global_bias}); @@ -116,7 +111,7 @@ R"TREELITETEMPLATE( }} )TREELITETEMPLATE"; // only for multiclass classification -const char* return_template = +const char* const return_template = R"TREELITETEMPLATE( sum += (float)({global_bias}); if (!pred_margin) {{ @@ -126,7 +121,7 @@ R"TREELITETEMPLATE( }} )TREELITETEMPLATE"; -const char* arrays_template = R"TREELITETEMPLATE( +const char* const arrays_template = R"TREELITETEMPLATE( #include "header.h" {nodes} @@ -319,12 +314,22 @@ class FailSafeCompiler : public Compiler { std::tie(nodes, nodes_row_ptr) = FormatNodesArray(model_handle); } + const ModelParam model_param = model.GetParam(); + const std::string query_functions_definition + = fmt::format(native::query_functions_definition_template, + "num_output_group"_a = num_output_group_, + "num_feature"_a = num_feature_, + "pred_transform"_a = model_param.pred_transform, + "sigmoid_alpha"_a = model_param.sigmoid_alpha, + "global_bias"_a = model_param.global_bias, + "threshold_type_str"_a = TypeInfoToString(InferTypeInfoOf()), + "leaf_output_type_str"_a = TypeInfoToString(InferTypeInfoOf())); + main_program << fmt::format(main_template, "nodes_row_ptr"_a = nodes_row_ptr, + "query_functions_definition"_a = query_functions_definition, "pred_transform_function"_a = pred_tranform_func_, "predict_function_signature"_a = predict_function_signature, - "num_output_group"_a = num_output_group_, - "num_feature"_a = num_feature_, "num_tree"_a = model_handle.trees.size(), "compare_op"_a = GetCommonOp(model_handle), "accumulator_definition"_a = accumulator_definition, @@ -340,8 +345,12 @@ class FailSafeCompiler : public Compiler { "nodes"_a = nodes)); } + const std::string query_functions_prototype + = fmt::format(native::query_functions_prototype_template, + "dllexport"_a = DLLEXPORT_KEYWORD); files_["header.h"] = CompiledModel::FileEntry(fmt::format(header_template, "dllexport"_a = DLLEXPORT_KEYWORD, + "query_functions_prototype"_a = query_functions_prototype, 
"predict_function_signature"_a = predict_function_signature)); { diff --git a/src/compiler/native/header_template.h b/src/compiler/native/header_template.h index 7d9158aa..88e051ca 100644 --- a/src/compiler/native/header_template.h +++ b/src/compiler/native/header_template.h @@ -12,6 +12,17 @@ namespace treelite { namespace compiler { namespace native { +const char* const query_functions_prototype_template = +R"TREELITETEMPLATE( +{dllexport}size_t get_num_output_group(void); +{dllexport}size_t get_num_feature(void); +{dllexport}const char* get_pred_transform(void); +{dllexport}float get_sigmoid_alpha(void); +{dllexport}float get_global_bias(void); +{dllexport}const char* get_threshold_type(void); +{dllexport}const char* get_leaf_output_type(void); +)TREELITETEMPLATE"; + const char* const header_template = R"TREELITETEMPLATE( #include @@ -43,14 +54,8 @@ struct Node {{ extern const unsigned char is_categorical[]; -{dllexport}{get_num_output_group_function_signature}; -{dllexport}{get_num_feature_function_signature}; -{dllexport}{get_pred_transform_function_signature}; -{dllexport}{get_sigmoid_alpha_function_signature}; -{dllexport}{get_global_bias_function_signature}; +{query_functions_prototype} {dllexport}{predict_function_signature}; -{dllexport}{get_threshold_type_signature}; -{dllexport}{get_leaf_output_type_signature}; )TREELITETEMPLATE"; } // namespace native diff --git a/src/compiler/native/main_template.h b/src/compiler/native/main_template.h index fdb747d4..fa3b09ea 100644 --- a/src/compiler/native/main_template.h +++ b/src/compiler/native/main_template.h @@ -12,39 +12,44 @@ namespace treelite { namespace compiler { namespace native { -const char* const main_start_template = +const char* const query_functions_definition_template = R"TREELITETEMPLATE( -#include "header.h" - -{array_is_categorical}; - -{get_num_output_group_function_signature} {{ +size_t get_num_output_group(void) {{ return {num_output_group}; }} -{get_num_feature_function_signature} {{ +size_t get_num_feature(void) {{ return {num_feature}; }} -{get_pred_transform_function_signature} {{ +const char* get_pred_transform(void) {{ return "{pred_transform}"; }} -{get_sigmoid_alpha_function_signature} {{ +float get_sigmoid_alpha(void) {{ return {sigmoid_alpha}; }} -{get_global_bias_function_signature} {{ +float get_global_bias(void) {{ return {global_bias}; }} -{get_threshold_type_signature} {{ +const char* get_threshold_type(void) {{ return "{threshold_type_str}"; }} -{get_leaf_output_type_signature} {{ +const char* get_leaf_output_type(void) {{ return "{leaf_output_type_str}"; }} +)TREELITETEMPLATE"; + +const char* const main_start_template = +R"TREELITETEMPLATE( +#include "header.h" + +{array_is_categorical}; + +{query_functions_definition} {pred_transform_function} {predict_function_signature} {{ diff --git a/src/data/data.cc b/src/data/data.cc index e853b1be..d8baa758 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -13,14 +13,19 @@ namespace { -std::unique_ptr -CreateFromParser(dmlc::Parser* parser, int nthread, int verbose) { +template +inline static std::unique_ptr CreateFromParserImpl( + const char* filename, const char* format, int nthread, int verbose) { + std::unique_ptr> parser( + dmlc::Parser::Create(filename, 0, 1, format)); + const int max_thread = omp_get_max_threads(); nthread = (nthread == 0) ? 
max_thread : std::min(nthread, max_thread); - std::vector data; + std::vector data; std::vector col_ind; std::vector row_ptr; + row_ptr.resize(1, 0); size_t num_row = 0; size_t num_col = 0; size_t num_elem = 0; @@ -28,7 +33,7 @@ CreateFromParser(dmlc::Parser* parser, int nthread, int verbose std::vector max_col_ind(nthread, 0); parser->BeforeFirst(); while (parser->Next()) { - const dmlc::RowBlock& batch = parser->Value(); + const dmlc::RowBlock& batch = parser->Value(); num_row += batch.size; num_elem += batch.offset[batch.size]; const size_t top = data.size(); @@ -41,7 +46,9 @@ CreateFromParser(dmlc::Parser* parser, int nthread, int verbose i < static_cast(batch.offset[batch.size]); ++i) { const int tid = omp_get_thread_num(); const uint32_t index = batch.index[i]; - const float fvalue = (batch.value == nullptr) ? 1.0f : static_cast(batch.value[i]); + const ElementType fvalue + = ((batch.value == nullptr) ? static_cast(1) + : static_cast(batch.value[i])); const size_t offset = top + i - batch.offset[0]; data[offset] = fvalue; col_ind[offset] = index; @@ -63,6 +70,23 @@ CreateFromParser(dmlc::Parser* parser, int nthread, int verbose num_row, num_col); } +std::unique_ptr +CreateFromParser( + const char* filename, const char* format, treelite::TypeInfo dtype, int nthread, int verbose) { + switch (dtype) { + case treelite::TypeInfo::kFloat32: + return CreateFromParserImpl(filename, format, nthread, verbose); + case treelite::TypeInfo::kFloat64: + return CreateFromParserImpl(filename, format, nthread, verbose); + case treelite::TypeInfo::kUInt32: + return CreateFromParserImpl(filename, format, nthread, verbose); + default: + LOG(FATAL) << "Unrecognized TypeInfo: " << treelite::TypeInfoToString(dtype); + } + return CreateFromParserImpl(filename, format, nthread, verbose); + // avoid missing value warning +} + } // anonymous namespace namespace treelite { @@ -183,10 +207,10 @@ CSRDMatrix::Create(TypeInfo type, const void* data, const uint32_t* col_ind, con } std::unique_ptr -CSRDMatrix::Create(const char* filename, const char* format, int nthread, int verbose) { - std::unique_ptr> parser( - dmlc::Parser::Create(filename, 0, 1, format)); - return CreateFromParser(parser.get(), nthread, verbose); +CSRDMatrix::Create( + const char* filename, const char* format, const char* data_type, int nthread, int verbose) { + TypeInfo dtype = (data_type ? 
typeinfo_table.at(data_type) : TypeInfo::kFloat32); + return CreateFromParser(filename, format, dtype, nthread, verbose); } TypeInfo @@ -199,7 +223,7 @@ CSRDMatrixImpl::CSRDMatrixImpl( std::vector data, std::vector col_ind, std::vector row_ptr, size_t num_row, size_t num_col) : CSRDMatrix(), data(std::move(data)), col_ind(std::move(col_ind)), row_ptr(std::move(row_ptr)), - num_row(num_col), num_col(num_col) + num_row(num_row), num_col(num_col) {} template diff --git a/src/frontend/builder.cc b/src/frontend/builder.cc index bb0e724e..2d484189 100644 --- a/src/frontend/builder.cc +++ b/src/frontend/builder.cc @@ -127,29 +127,42 @@ Value::Create(T init_value) { return value; } +template +class CreateHandle { + public: + inline static std::shared_ptr Dispatch(const void* init_value) { + const auto* v_ptr = static_cast(init_value); + CHECK(v_ptr); + ValueType v = *v_ptr; + return std::make_shared(v); + } +}; + Value Value::Create(const void* init_value, TypeInfo type) { Value value; CHECK(type != TypeInfo::kInvalid) << "Type must be valid"; value.type_ = type; - value.Dispatch([init_value](auto& value_handle) { - using T = std::remove_reference_t; - T t = *static_cast(init_value); - value_handle = t; - }); + value.handle_ = DispatchWithTypeInfo(type, init_value); return value; } template T& Value::Get() { - return *static_cast(handle_.get()); + CHECK(handle_); + T* out = static_cast(handle_.get()); + CHECK(out); + return *out; } template const T& Value::Get() const { - return *static_cast(handle_.get()); + CHECK(handle_); + const T* out = static_cast(handle_.get()); + CHECK(out); + return *out; } TypeInfo diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index fb28d4f4..96166d01 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -34,8 +34,7 @@ struct InputToken { bool pred_margin; // whether to store raw margin or transformed scores const treelite::predictor::PredFunction* pred_func_; size_t rbegin, rend; // range of instances (rows) assigned to each worker - treelite::predictor::PredictorOutput* out_pred; - // buffer to store output from each worker + void* out_pred; // buffer to store output from each worker }; struct OutputToken { @@ -127,51 +126,6 @@ inline size_t PredLoop(const treelite::DMatrix* dmat, int num_feature, namespace treelite { namespace predictor { -std::unique_ptr -PredictorOutput::Create(TypeInfo leaf_output_type, size_t num_row, size_t num_output_group) { - switch (leaf_output_type) { - case TypeInfo::kFloat32: - return std::make_unique>(num_row, num_output_group); - case TypeInfo::kFloat64: - return std::make_unique>(num_row, num_output_group); - case TypeInfo::kUInt32: - return std::make_unique>(num_row, num_output_group); - case TypeInfo::kInvalid: - default: - LOG(FATAL) << "Invalid leaf_output_type: " << TypeInfoToString(leaf_output_type); - return std::unique_ptr(nullptr); - } -} - -template -PredictorOutputImpl::PredictorOutputImpl(size_t num_row, size_t num_output_group) - : preds_(num_row * num_output_group, static_cast(0)), - num_row_(num_row), num_output_group_(num_output_group) {} - -template -size_t -PredictorOutputImpl::GetNumRow() const { - return num_row_; -} - -template -size_t -PredictorOutputImpl::GetNumOutputGroup() const { - return num_output_group_; -} - -template -std::vector& -PredictorOutputImpl::GetPreds() { - return preds_; -} - -template -const std::vector& -PredictorOutputImpl::GetPreds() const { - return preds_; -} - SharedLibrary::SharedLibrary() : handle_(nullptr), libpath_() {} 
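/* With the PredictorOutput buffer classes removed above, callers of the runtime
 * C API allocate the output buffer themselves and hand it to the predictor as a
 * raw pointer. A minimal sketch of that flow, assuming a model with float32 leaf
 * outputs; the `predictor` and `dmat` handles are hypothetical and would come
 * from TreelitePredictorLoad() and one of the TreeliteDMatrixCreateFrom* calls:
 *
 *   size_t capacity = 0, result_size = 0;
 *   TreelitePredictorQueryResultSize(predictor, dmat, &capacity);
 *   std::vector<float> out_result(capacity);
 *   TreelitePredictorPredictBatch(predictor, dmat, 0, 0,
 *                                 out_result.data(), &result_size);
 *   // the first result_size entries of out_result now hold the predictions
 */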
SharedLibrary::~SharedLibrary() { @@ -263,8 +217,7 @@ PredFunctionImpl::GetLeafOutputType() const { template size_t PredFunctionImpl::PredictBatch( - const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, - PredictorOutput* out_pred) const { + const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, void* out_pred) const { /* Pass the correct prediction function to PredLoop. We also need to specify how the function should be called. */ size_t result_size; @@ -277,8 +230,6 @@ PredFunctionImpl::PredictBatch( << ", Given: " << TypeInfoToString(dmat->GetElementType()); CHECK(rbegin < rend && rend <= dmat->GetNumRow()); size_t num_row = rend - rbegin; - auto* out_pred_ = dynamic_cast*>(out_pred); - CHECK(out_pred_); if (num_output_group_ > 1) { // multi-class classification using PredFunc = size_t (*)(Entry*, int, LeafOutputType*); auto pred_func = reinterpret_cast(handle_); @@ -291,7 +242,8 @@ PredFunctionImpl::PredictBatch( }; result_size = PredLoop( - dmat, num_feature_, rbegin, rend, out_pred_->GetPreds().data(), pred_func_wrapper); + dmat, num_feature_, rbegin, rend, static_cast(out_pred), + pred_func_wrapper); } else { // everything else using PredFunc = LeafOutputType (*)(Entry*, int); auto pred_func = reinterpret_cast(handle_); @@ -304,7 +256,8 @@ PredFunctionImpl::PredictBatch( }; result_size = PredLoop( - dmat, num_feature_, rbegin, rend, out_pred_->GetPreds().data(), pred_func_wrapper); + dmat, num_feature_, rbegin, rend, static_cast(out_pred), + pred_func_wrapper); } return result_size; } @@ -380,15 +333,17 @@ Predictor::Load(const char* libpath) { [](SpscQueue* incoming_queue, SpscQueue* outgoing_queue, const Predictor* predictor) { - InputToken input; - while (incoming_queue->Pop(&input)) { - const size_t rbegin = input.rbegin; - const size_t rend = input.rend; - size_t query_result_size - = predictor->pred_func_->PredictBatch( - input.dmat, rbegin, rend, input.pred_margin, input.out_pred); - outgoing_queue->Push(OutputToken{query_result_size}); - } + predictor->exception_catcher_.Run([&]() { + InputToken input; + while (incoming_queue->Pop(&input)) { + const size_t rbegin = input.rbegin; + const size_t rend = input.rend; + size_t query_result_size + = predictor->pred_func_->PredictBatch( + input.dmat, rbegin, rend, input.pred_margin, input.out_pred); + outgoing_queue->Push(OutputToken{query_result_size}); + } + }); })); } @@ -416,9 +371,16 @@ std::vector SplitBatch(const DMatrix* dmat, size_t split_factor) { return row_ptr; } +template +class ShrinkResultToFit { + public: + inline static void Dispatch( + size_t num_row, size_t query_size_per_instance, size_t num_output_group, void* out_result); +}; + size_t Predictor::PredictBatch( - const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutput* out_result) const { + const DMatrix* dmat, int verbose, bool pred_margin, void* out_result) const { const double tstart = dmlc::GetTime(); const size_t num_row = dmat->GetNumRow(); @@ -454,12 +416,8 @@ Predictor::PredictBatch( const size_t query_size_per_instance = total_size / num_row; CHECK_GT(query_size_per_instance, 0); CHECK_LT(query_size_per_instance, num_output_group_); - for (size_t rid = 0; rid < num_row; ++rid) { - for (size_t k = 0; k < query_size_per_instance; ++k) { - out_result[rid * query_size_per_instance + k] - = out_result[rid * num_output_group_ + k]; - } - } + DispatchWithTypeInfo( + leaf_output_type_, num_row, query_size_per_instance, num_output_group_, out_result); } const double tend = dmlc::GetTime(); if (verbose > 0) { @@ -468,9 +426,16 
@@ Predictor::PredictBatch( return total_size; } -std::unique_ptr -Predictor::AllocateOutputBuffer(const DMatrix* dmat) const { - return PredictorOutput::Create(leaf_output_type_, dmat->GetNumRow(), num_output_group_); +template +void +ShrinkResultToFit::Dispatch( + size_t num_row, size_t query_size_per_instance, size_t num_output_group, void* out_result) { + auto* out_result_ = static_cast(out_result); + for (size_t rid = 0; rid < num_row; ++rid) { + for (size_t k = 0; k < query_size_per_instance; ++k) { + out_result_[rid * query_size_per_instance + k] = out_result_[rid * num_output_group + k]; + } + } } } // namespace predictor diff --git a/tests/python/conftest.py b/tests/python/conftest.py index 54743b9d..feb8a8b0 100644 --- a/tests/python/conftest.py +++ b/tests/python/conftest.py @@ -5,6 +5,7 @@ import pytest import treelite +import treelite_runtime from .metadata import dataset_db @@ -17,7 +18,7 @@ def compute_annotation(dataset): model_format=dataset_db[dataset].format) if dataset_db[dataset].dtrain is None: return None - dtrain = treelite.DMatrix(dataset_db[dataset].dtrain) + dtrain = treelite_runtime.DMatrix(dataset_db[dataset].dtrain) annotator = treelite.Annotator() annotator.annotate_branch(model=model, dmat=dtrain, verbose=True) annotation_path = os.path.join(tmpdir, f'{dataset}.json') diff --git a/tests/python/metadata.py b/tests/python/metadata.py index f494547d..87b89450 100644 --- a/tests/python/metadata.py +++ b/tests/python/metadata.py @@ -8,28 +8,31 @@ _dpath = os.path.abspath(os.path.join(_current_dir, os.path.pardir, 'examples')) Dataset = collections.namedtuple( - 'Dataset', 'model format dtrain dtest libname expected_prob expected_margin is_multiclass') + 'Dataset', + 'model format dtrain dtest libname expected_prob expected_margin is_multiclass dtype') _dataset_db = { 'mushroom': Dataset(model='mushroom.model', format='xgboost', dtrain='agaricus.train', dtest='agaricus.test', libname='agaricus', expected_prob='agaricus.test.prob', expected_margin='agaricus.test.margin', - is_multiclass=False), + is_multiclass=False, dtype='float32'), 'dermatology': Dataset(model='dermatology.model', format='xgboost', dtrain='dermatology.train', dtest='dermatology.test', libname='dermatology', expected_prob='dermatology.test.prob', - expected_margin='dermatology.test.margin', is_multiclass=True), + expected_margin='dermatology.test.margin', is_multiclass=True, + dtype='float32'), 'letor': Dataset(model='mq2008.model', format='xgboost', dtrain='mq2008.train', dtest='mq2008.test', libname='letor', expected_prob=None, - expected_margin='mq2008.test.pred', is_multiclass=False), + expected_margin='mq2008.test.pred', is_multiclass=False, dtype='float32'), 'toy_categorical': Dataset(model='toy_categorical_model.txt', format='lightgbm', dtrain=None, dtest='toy_categorical.test', libname='toycat', expected_prob=None, - expected_margin='toy_categorical.test.pred', is_multiclass=False), + expected_margin='toy_categorical.test.pred', is_multiclass=False, + dtype='float64'), 'sparse_categorical': Dataset(model='sparse_categorical_model.txt', format='lightgbm', dtrain=None, dtest='sparse_categorical.test', libname='sparsecat', expected_prob=None, expected_margin='sparse_categorical.test.margin', - is_multiclass=False) + is_multiclass=False, dtype='float64') } diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 8979f1c0..d9151b43 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -7,6 +7,7 @@ from zipfile import ZipFile import pytest +import numpy 
as np from scipy.sparse import csr_matrix import treelite import treelite_runtime @@ -122,10 +123,10 @@ def test_deficient_matrix(tmpdir): model.export_lib(toolchain=toolchain, libpath=libpath, params={'quantize': 1}, verbose=True) X = csr_matrix(([], ([], [])), shape=(3, 3)) - batch = treelite_runtime.Batch.from_csr(X) + dmat = treelite_runtime.DMatrix(X, dtype='float32') predictor = treelite_runtime.Predictor(libpath=libpath, verbose=True) assert predictor.num_feature == 127 - predictor.predict(batch) # should not crash + predictor.predict(dmat) # should not crash def test_too_wide_matrix(tmpdir): @@ -137,10 +138,10 @@ def test_too_wide_matrix(tmpdir): model.export_lib(toolchain=toolchain, libpath=libpath, params={'quantize': 1}, verbose=True) X = csr_matrix(([], ([], [])), shape=(3, 1000)) - batch = treelite_runtime.Batch.from_csr(X) + dmat = treelite_runtime.DMatrix(X, dtype='float32') predictor = treelite_runtime.Predictor(libpath=libpath, verbose=True) assert predictor.num_feature == 127 - pytest.raises(treelite_runtime.TreeliteRuntimeError, predictor.predict, batch) + pytest.raises(treelite_runtime.TreeliteRuntimeError, predictor.predict, dmat) def test_set_tree_limit(): diff --git a/tests/python/test_lightgbm_integration.py b/tests/python/test_lightgbm_integration.py index b61b56a8..93b8d6c0 100644 --- a/tests/python/test_lightgbm_integration.py +++ b/tests/python/test_lightgbm_integration.py @@ -44,8 +44,8 @@ def test_lightgbm_multiclass_classification(tmpdir, objective, toolchain): model.export_lib(toolchain=toolchain, libpath=libpath, params={'quantize': 1}, verbose=True) predictor = treelite_runtime.Predictor(libpath=libpath, verbose=True) - batch = treelite_runtime.Batch.from_npy2d(X_test) - out_pred = predictor.predict(batch) + dmat = treelite_runtime.DMatrix(X_test, dtype='float64') + out_pred = predictor.predict(dmat) expected_pred = bst.predict(X_test) np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5) @@ -72,12 +72,12 @@ def test_lightgbm_binary_classification(tmpdir, objective, toolchain): model = treelite.Model.load(model_path, model_format='lightgbm') libpath = os.path.join(tmpdir, f'agaricus_{objective}' + _libext()) - batch = treelite_runtime.Batch.from_csr(treelite.DMatrix(dtest_path)) + dmat = treelite_runtime.DMatrix(dtest_path, dtype='float64') model.export_lib(toolchain=toolchain, libpath=libpath, params={}, verbose=True) predictor = treelite_runtime.Predictor(libpath, verbose=True) - out_prob = predictor.predict(batch) + out_prob = predictor.predict(dmat) np.testing.assert_almost_equal(out_prob, expected_prob, decimal=5) - out_margin = predictor.predict(batch, pred_margin=True) + out_margin = predictor.predict(dmat, pred_margin=True) np.testing.assert_almost_equal(out_margin, expected_margin, decimal=5) diff --git a/tests/python/test_model_builder.py b/tests/python/test_model_builder.py index f305aa90..6e31cf30 100644 --- a/tests/python/test_model_builder.py +++ b/tests/python/test_model_builder.py @@ -97,7 +97,7 @@ def test_model_builder(tmpdir, use_annotation, quantize, toolchain): annotation_path = os.path.join(tmpdir, 'annotation.json') if use_annotation: - dtrain = treelite.DMatrix(dataset_db['mushroom'].dtrain) + dtrain = treelite_runtime.DMatrix(dataset_db['mushroom'].dtrain, dtype='float32') annotator = treelite.Annotator() annotator.annotate_branch(model=model, dmat=dtrain, verbose=True) annotator.save(path=annotation_path) @@ -119,8 +119,8 @@ def test_model_builder(tmpdir, use_annotation, quantize, toolchain): def 
test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): # pylint: disable=too-many-locals """Convert scikit-learn multi-class classifier""" - if clazz == RandomForestClassifier: - pytest.xfail(reason='Need to use float64 thresholds to obtain correct result') + #if clazz == RandomForestClassifier: + # pytest.xfail(reason='Need to use float64 thresholds to obtain correct result') X, y = load_iris(return_X_y=True) kwargs = {} @@ -136,7 +136,7 @@ def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): assert (model.num_tree == clf.n_estimators * (clf.n_classes_ if clazz == GradientBoostingClassifier else 1)) - dtrain = treelite.DMatrix(X) + dtrain = treelite_runtime.DMatrix(X, dtype='float64') annotation_path = os.path.join(tmpdir, 'annotation.json') annotator = treelite.Annotator() annotator.annotate_branch(model=model, dmat=dtrain, verbose=True) @@ -152,8 +152,7 @@ def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): ('softmax' if clazz == GradientBoostingClassifier else 'identity_multiclass')) assert predictor.global_bias == 0.0 assert predictor.sigmoid_alpha == 1.0 - batch = treelite_runtime.Batch.from_npy2d(X) - out_prob = predictor.predict(batch) + out_prob = predictor.predict(dtrain) np.testing.assert_almost_equal(out_prob, expected_prob) @@ -176,7 +175,7 @@ def test_skl_converter_binary_classifier(tmpdir, clazz, toolchain): assert model.num_output_group == 1 assert model.num_tree == clf.n_estimators - dtrain = treelite.DMatrix(X) + dtrain = treelite_runtime.DMatrix(X, dtype='float64') annotation_path = os.path.join(tmpdir, 'annotation.json') annotator = treelite.Annotator() annotator.annotate_branch(model=model, dmat=dtrain, verbose=True) @@ -192,8 +191,7 @@ def test_skl_converter_binary_classifier(tmpdir, clazz, toolchain): == ('sigmoid' if clazz == GradientBoostingClassifier else 'identity')) assert predictor.global_bias == 0.0 assert predictor.sigmoid_alpha == 1.0 - batch = treelite_runtime.Batch.from_npy2d(X) - out_prob = predictor.predict(batch) + out_prob = predictor.predict(dtrain) np.testing.assert_almost_equal(out_prob, expected_prob) @@ -215,7 +213,7 @@ def test_skl_converter_regressor(tmpdir, clazz, toolchain): # pylint: disable=t assert model.num_output_group == 1 assert model.num_tree == clf.n_estimators - dtrain = treelite.DMatrix(X) + dtrain = treelite_runtime.DMatrix(X, dtype='float64') annotation_path = os.path.join(tmpdir, 'annotation.json') annotator = treelite.Annotator() annotator.annotate_branch(model=model, dmat=dtrain, verbose=True) @@ -230,8 +228,7 @@ def test_skl_converter_regressor(tmpdir, clazz, toolchain): # pylint: disable=t assert predictor.pred_transform == 'identity' assert predictor.global_bias == 0.0 assert predictor.sigmoid_alpha == 1.0 - batch = treelite_runtime.Batch.from_npy2d(X) - out_pred = predictor.predict(batch) + out_pred = predictor.predict(dtrain) np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5) @@ -275,8 +272,9 @@ def test_node_insert_delete(tmpdir, toolchain): for f0 in [-0.5, 0.5, 1.5, np.nan]: for f1 in [0, 1, 2, 3, 4, np.nan]: for f2 in [-1.0, -0.5, 1.0, np.nan]: - x = np.array([f0, f1, f2]) - pred = predictor.predict_instance(x) + x = np.array([[f0, f1, f2]]) + dmat = treelite_runtime.DMatrix(x, dtype='float32') + pred = predictor.predict(dmat) if f1 in [1, 2, 4] or np.isnan(f1): expected_pred = 2.0 elif f0 <= 0.5 and not np.isnan(f0): diff --git a/tests/python/test_single_inst.py b/tests/python/test_single_inst.py deleted file mode 100644 index 10040bce..00000000 --- 
a/tests/python/test_single_inst.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- coding: utf-8 -*- -"""Tests for single-instance prediction""" -import os - -import pytest -import numpy as np -import treelite -import treelite_runtime -from treelite.util import has_sklearn -from treelite.contrib import _libext -from .metadata import dataset_db -from .util import os_compatible_toolchains, check_predictor_output - - -@pytest.mark.skipif(not has_sklearn(), reason='Needs scikit-learn') -@pytest.mark.parametrize('toolchain', os_compatible_toolchains()) -@pytest.mark.parametrize('dataset', ['mushroom', 'dermatology', 'toy_categorical']) -def test_single_inst(tmpdir, annotation, dataset, toolchain): - # pylint: disable=too-many-locals - """Run end-to-end test""" - libpath = os.path.join(tmpdir, dataset_db[dataset].libname + _libext()) - model = treelite.Model.load(dataset_db[dataset].model, model_format=dataset_db[dataset].format) - annotation_path = os.path.join(tmpdir, 'annotation.json') - - if annotation[dataset] is None: - annotation_path = None - else: - with open(annotation_path, 'w') as f: - f.write(annotation[dataset]) - - params = { - 'annotate_in': (annotation_path if annotation_path else 'NULL'), - 'quantize': 1, 'parallel_comp': model.num_tree - } - model.export_lib(toolchain=toolchain, libpath=libpath, params=params, verbose=True) - predictor = treelite_runtime.Predictor(libpath=libpath, verbose=True) - - from sklearn.datasets import load_svmlight_file - - X_test, _ = load_svmlight_file(dataset_db[dataset].dtest, zero_based=True) - out_prob = [[] for _ in range(4)] - out_margin = [[] for _ in range(4)] - for i in range(X_test.shape[0]): - x = X_test[i, :] - # Scipy CSR matrix - out_prob[0].append(predictor.predict_instance(x)) - out_margin[0].append(predictor.predict_instance(x, pred_margin=True)) - # NumPy 1D array with 0 as missing value - x = x.toarray().flatten() - out_prob[1].append(predictor.predict_instance(x, missing=0.0)) - out_margin[1].append(predictor.predict_instance(x, missing=0.0, pred_margin=True)) - # NumPy 1D array with np.nan as missing value - np.place(x, x == 0.0, [np.nan]) - out_prob[2].append(predictor.predict_instance(x, missing=np.nan)) - out_margin[2].append(predictor.predict_instance(x, missing=np.nan, pred_margin=True)) - # NumPy 1D array with np.nan as missing value - # (default when `missing` parameter is unspecified) - out_prob[3].append(predictor.predict_instance(x)) - out_margin[3].append(predictor.predict_instance(x, pred_margin=True)) - - for i in range(4): - check_predictor_output(dataset, X_test.shape, - out_margin=np.squeeze(np.array(out_margin[i])), - out_prob=np.squeeze(np.array(out_prob[i]))) diff --git a/tests/python/test_xgboost_integration.py b/tests/python/test_xgboost_integration.py index 108234dd..d5ea2247 100644 --- a/tests/python/test_xgboost_integration.py +++ b/tests/python/test_xgboost_integration.py @@ -49,8 +49,8 @@ def test_xgb_boston(tmpdir, toolchain): # pylint: disable=too-many-locals assert predictor.pred_transform == 'identity' assert predictor.global_bias == 0.5 assert predictor.sigmoid_alpha == 1.0 - batch = treelite_runtime.Batch.from_npy2d(X_test) - out_pred = predictor.predict(batch) + dmat = treelite_runtime.DMatrix(X_test, dtype='float32') + out_pred = predictor.predict(dmat) expected_pred = bst.predict(dtest) np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5) @@ -86,8 +86,8 @@ def test_xgb_iris(tmpdir, toolchain): # pylint: disable=too-many-locals assert predictor.pred_transform == 'max_index' assert 
predictor.global_bias == 0.5 assert predictor.sigmoid_alpha == 1.0 - batch = treelite_runtime.Batch.from_npy2d(X_test) - out_pred = predictor.predict(batch) + dmat = treelite_runtime.DMatrix(X_test, dtype='float32') + out_pred = predictor.predict(dmat) expected_pred = bst.predict(dtest) np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5) @@ -126,7 +126,7 @@ def test_nonlinear_objective(tmpdir, objective, max_label, global_bias, toolchai assert predictor.pred_transform == expected_pred_transform[objective] np.testing.assert_almost_equal(predictor.global_bias, global_bias, decimal=5) assert predictor.sigmoid_alpha == 1.0 - batch = treelite_runtime.Batch.from_npy2d(X) - out_pred = predictor.predict(batch) + dmat = treelite_runtime.DMatrix(X, dtype='float32') + out_pred = predictor.predict(dmat) expected_pred = bst.predict(dtrain) np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5) diff --git a/tests/python/util.py b/tests/python/util.py index b91f2ad0..c0a392c1 100644 --- a/tests/python/util.py +++ b/tests/python/util.py @@ -5,7 +5,6 @@ from contextlib import contextmanager import numpy as np -import treelite import treelite_runtime from treelite.contrib import _libext from .metadata import dataset_db @@ -56,11 +55,10 @@ def does_not_raise(): def check_predictor(predictor, dataset): """Check whether a predictor produces correct predictions for a given dataset""" - dtest = treelite.DMatrix(dataset_db[dataset].dtest) - batch = treelite_runtime.Batch.from_csr(dtest) - out_margin = predictor.predict(batch, pred_margin=True) - out_prob = predictor.predict(batch) - check_predictor_output(dataset, dtest.shape, out_margin, out_prob) + dmat = treelite_runtime.DMatrix(dataset_db[dataset].dtest, dtype=dataset_db[dataset].dtype) + out_margin = predictor.predict(dmat, pred_margin=True) + out_prob = predictor.predict(dmat) + check_predictor_output(dataset, dmat.shape, out_margin, out_prob) def check_predictor_output(dataset, shape, out_margin, out_prob): @@ -68,6 +66,8 @@ def check_predictor_output(dataset, shape, out_margin, out_prob): expected_margin = load_txt(dataset_db[dataset].expected_margin) if dataset_db[dataset].is_multiclass: expected_margin = expected_margin.reshape((shape[0], -1)) + assert out_margin.shape == expected_margin.shape, \ + f'out_margin.shape = {out_margin.shape}, expected_margin.shape = {expected_margin.shape}' np.testing.assert_almost_equal(out_margin, expected_margin, decimal=5) if dataset_db[dataset].expected_prob is not None: From 7aef6dbc580193193fd289a608db666ceaf4db9c Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 3 Sep 2020 22:29:24 -0700 Subject: [PATCH 18/38] Emit correct data type for leaf outputs --- src/compiler/ast_native.cc | 12 ++++++++---- tests/python/test_model_builder.py | 3 --- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index e0adb997..e74fc64c 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -554,6 +554,8 @@ class ASTNativeCompiler : public Compiler { template inline std::string ExtractNumericalCondition(const NumericalConditionNode* node) { + const std::string threshold_type + = native::TypeInfoToCTypeString(InferTypeInfoOf()); std::string result; if (node->quantized) { // quantized threshold result = fmt::format("data[{split_index}].qvalue {opname} {threshold}", @@ -566,10 +568,12 @@ class ASTNativeCompiler : public Compiler { result = (CompareWithOp(static_cast(0), node->op, node->threshold.float_val) ? 
"1" : "0"); } else { // finite threshold - result = fmt::format("data[{split_index}].fvalue {opname} (float){threshold}", - "split_index"_a = node->split_index, - "opname"_a = OpName(node->op), - "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val)); + result + = fmt::format("data[{split_index}].fvalue {opname} ({threshold_type}){threshold}", + "split_index"_a = node->split_index, + "opname"_a = OpName(node->op), + "threshold_type"_a = threshold_type, + "threshold"_a = common_util::ToStringHighPrecision(node->threshold.float_val)); } return result; } diff --git a/tests/python/test_model_builder.py b/tests/python/test_model_builder.py index 6e31cf30..b5fbe3e0 100644 --- a/tests/python/test_model_builder.py +++ b/tests/python/test_model_builder.py @@ -119,9 +119,6 @@ def test_model_builder(tmpdir, use_annotation, quantize, toolchain): def test_skl_converter_multiclass_classifier(tmpdir, clazz, toolchain): # pylint: disable=too-many-locals """Convert scikit-learn multi-class classifier""" - #if clazz == RandomForestClassifier: - # pytest.xfail(reason='Need to use float64 thresholds to obtain correct result') - X, y = load_iris(return_X_y=True) kwargs = {} if clazz == GradientBoostingClassifier: From f2715d9febea33819306ea3f52eff14d7af839d9 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 3 Sep 2020 22:40:01 -0700 Subject: [PATCH 19/38] Remove Model Type; consolidate type dispatching logic --- include/treelite/tree.h | 13 ------ include/treelite/tree_impl.h | 91 +++++++----------------------------- src/compiler/failsafe.cc | 3 +- 3 files changed, 20 insertions(+), 87 deletions(-) diff --git a/include/treelite/tree.h b/include/treelite/tree.h index 3660929b..9d85d9f4 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -419,18 +419,8 @@ static_assert(std::is_standard_layout::value, inline void InitParamAndCheck(ModelParam* param, const std::vector>& cfg); -enum class ModelType : uint16_t { - // Threshold type, - kInvalid = 0, - kFloat32ThresholdUInt32LeafOutput = 1, - kFloat32ThresholdFloat32LeafOutput = 2, - kFloat64ThresholdUInt32LeafOutput = 3, - kFloat64ThresholdFloat64LeafOutput = 4 -}; - class Model { private: - ModelType type_; TypeInfo threshold_type_; TypeInfo leaf_output_type_; virtual void GetPyBuffer(std::vector* dest) = 0; @@ -441,11 +431,8 @@ class Model { Model() = default; virtual ~Model() = default; template - inline static ModelType InferModelTypeOf(); - template inline static std::unique_ptr Create(); inline static std::unique_ptr Create(TypeInfo threshold_type, TypeInfo leaf_output_type); - inline ModelType GetModelType() const; inline TypeInfo GetThresholdType() const; inline TypeInfo GetLeafOutputType() const; template diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 8dc1dc02..b88ce6ef 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -754,11 +754,6 @@ Tree::SetGain(int nid, double gain) { node.gain_present_ = true; } -inline ModelType -Model::GetModelType() const { - return type_; -} - inline TypeInfo Model::GetThresholdType() const { return threshold_type_; @@ -769,53 +764,17 @@ Model::GetLeafOutputType() const { return leaf_output_type_; } -template -inline ModelType -Model::InferModelTypeOf() { - const std::string error_msg - = std::string("Unsupported combination of ThresholdType (") - + TypeInfoToString(InferTypeInfoOf()) + ") and LeafOutputType (" - + TypeInfoToString(InferTypeInfoOf()) + ")"; - static_assert(std::is_same::value - || std::is_same::value, - 
"ThresholdType should be either float32 or float64"); - static_assert(std::is_same::value - || std::is_same::value - || std::is_same::value, - "LeafOutputType should be uint32, float32 or float64"); - if (std::is_same::value) { - if (std::is_same::value) { - return ModelType::kFloat32ThresholdUInt32LeafOutput; - } else if (std::is_same::value) { - return ModelType::kFloat32ThresholdFloat32LeafOutput; - } else { - throw std::runtime_error(error_msg); - } - } else if (std::is_same::value) { - if (std::is_same::value) { - return ModelType::kFloat64ThresholdUInt32LeafOutput; - } else if (std::is_same::value) { - return ModelType::kFloat64ThresholdFloat64LeafOutput; - } else { - throw std::runtime_error(error_msg); - } - } - throw std::runtime_error(error_msg); - return ModelType::kInvalid; -} - template inline std::unique_ptr Model::Create() { std::unique_ptr model = std::make_unique>(); - model->type_ = InferModelTypeOf(); model->threshold_type_ = InferTypeInfoOf(); model->leaf_output_type_ = InferTypeInfoOf(); return model; } template -class ModelCreateDispatcher { +class ModelCreateImpl { public: inline static std::unique_ptr Dispatch() { return Model::Create(); @@ -824,47 +783,33 @@ class ModelCreateDispatcher { inline std::unique_ptr Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { - return DispatchWithModelTypes(threshold_type, leaf_output_type); + return DispatchWithModelTypes(threshold_type, leaf_output_type); } +template +class ModelDispatchImpl { + public: + template + inline static auto Dispatch(Model* model, Func func) { + return func(*dynamic_cast*>(model)); + } + + template + inline static auto Dispatch(const Model* model, Func func) { + return func(*dynamic_cast*>(model)); + } +}; + template inline auto Model::Dispatch(Func func) { - switch (type_) { - case ModelType::kFloat32ThresholdUInt32LeafOutput: - return func(*dynamic_cast*>(this)); - case ModelType::kFloat32ThresholdFloat32LeafOutput: - return func(*dynamic_cast*>(this)); - case ModelType::kFloat64ThresholdUInt32LeafOutput: - return func(*dynamic_cast*>(this)); - case ModelType::kFloat64ThresholdFloat64LeafOutput: - return func(*dynamic_cast*>(this)); - default: - throw std::runtime_error(std::string("Unknown type detected: ") - + std::to_string(static_cast(type_))); - return func(*dynamic_cast*>(this)); - // avoid "missing return" warning - } + return DispatchWithModelTypes(threshold_type_, leaf_output_type_, this, func); } template inline auto Model::Dispatch(Func func) const { - switch (type_) { - case ModelType::kFloat32ThresholdUInt32LeafOutput: - return func(*dynamic_cast*>(this)); - case ModelType::kFloat32ThresholdFloat32LeafOutput: - return func(*dynamic_cast*>(this)); - case ModelType::kFloat64ThresholdUInt32LeafOutput: - return func(*dynamic_cast*>(this)); - case ModelType::kFloat64ThresholdFloat64LeafOutput: - return func(*dynamic_cast*>(this)); - default: - throw std::runtime_error(std::string("Unknown type detected: ") - + std::to_string(static_cast(type_))); - return func(*dynamic_cast*>(this)); - // avoid "missing return" warning - } + return DispatchWithModelTypes(threshold_type_, leaf_output_type_, this, func); } inline std::vector diff --git a/src/compiler/failsafe.cc b/src/compiler/failsafe.cc index 67b428f1..7806399d 100644 --- a/src/compiler/failsafe.cc +++ b/src/compiler/failsafe.cc @@ -260,7 +260,8 @@ class FailSafeCompiler : public Compiler { } CompiledModel Compile(const Model& model) override { - CHECK(model.GetModelType() == ModelType::kFloat32ThresholdFloat32LeafOutput) + 
CHECK(model.GetThresholdType() == TypeInfo::kFloat32 + && model.GetLeafOutputType() == TypeInfo::kFloat32) << "Failsafe compiler only supports models with float32 thresholds and float32 leaf outputs"; const auto& model_handle = dynamic_cast&>(model); From 03498f0d28c71daaacb10a9bdace96c0ff28aed7 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 01:23:58 -0700 Subject: [PATCH 20/38] Fix all scikit-learn tests: input must be float32 + thresholds are float64 --- include/treelite/base.h | 4 +- include/treelite/data.h | 2 - src/annotator.cc | 82 ++++++++++++++++++------------ src/data/data.cc | 5 ++ src/predictor/predictor.cc | 63 +++++++++++++++-------- tests/python/test_model_builder.py | 2 +- 6 files changed, 100 insertions(+), 58 deletions(-) diff --git a/include/treelite/base.h b/include/treelite/base.h index 748473f6..56539cb7 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -59,8 +59,8 @@ inline std::string OpName(Operator op) { * \param rhs float on the right hand side * \return whether [lhs] [op] [rhs] is true or not */ -template -inline bool CompareWithOp(ThresholdType lhs, Operator op, ThresholdType rhs) { +template +inline bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs) { switch (op) { case Operator::kEQ: return lhs == rhs; case Operator::kLT: return lhs < rhs; diff --git a/include/treelite/data.h b/include/treelite/data.h index 0c51e355..6ca682b2 100644 --- a/include/treelite/data.h +++ b/include/treelite/data.h @@ -77,8 +77,6 @@ class DenseDMatrixImpl : public DenseDMatrix { DMatrixType GetType() const override; friend class DenseDMatrix; - static_assert(std::is_same::value || std::is_same::value, - "ElementType must be either float32 or float64"); }; class CSRDMatrix : public DMatrix { diff --git a/src/annotator.cc b/src/annotator.cc index 49f30bb9..1c7fa7ba 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -14,15 +14,15 @@ namespace { -template +template union Entry { int missing; - ThresholdType fvalue; + ElementType fvalue; }; -template +template void Traverse_(const treelite::Tree& tree, - const Entry* data, int nid, size_t* out_counts) { + const Entry* data, int nid, size_t* out_counts) { ++out_counts[nid]; if (!tree.IsLeaf(nid)) { const unsigned split_index = tree.SplitIndex(nid); @@ -34,7 +34,7 @@ void Traverse_(const treelite::Tree& tree, if (tree.SplitType(nid) == treelite::SplitFeatureType::kNumerical) { const ThresholdType threshold = tree.Threshold(nid); const treelite::Operator op = tree.ComparisonOp(nid); - const auto fvalue = static_cast(data[split_index].fvalue); + const auto fvalue = static_cast(data[split_index].fvalue); result = treelite::CompareWithOp(fvalue, op, threshold); } else { const auto fvalue = data[split_index].fvalue; @@ -52,17 +52,18 @@ void Traverse_(const treelite::Tree& tree, } } -template +template void Traverse(const treelite::Tree& tree, - const Entry* data, size_t* out_counts) { + const Entry* data, size_t* out_counts) { Traverse_(tree, data, 0, out_counts); } -template +template inline void ComputeBranchLoopImpl( const treelite::ModelImpl& model, - const treelite::DenseDMatrixImpl* dmat, size_t rbegin, size_t rend, int nthread, - const size_t* count_row_ptr, size_t* counts_tloc, Entry* inst) { + const treelite::DenseDMatrixImpl* dmat, size_t rbegin, size_t rend, int nthread, + const size_t* count_row_ptr, size_t* counts_tloc) { + std::vector> inst(nthread * dmat->num_col, {-1}); const size_t ntree = model.trees.size(); CHECK_LE(rbegin, rend); CHECK_LT(static_cast(rend), 
std::numeric_limits::max()); @@ -74,7 +75,7 @@ inline void ComputeBranchLoopImpl( #pragma omp parallel for schedule(static) num_threads(nthread) for (int64_t rid = rbegin_i; rid < rend_i; ++rid) { const int tid = omp_get_thread_num(); - const ThresholdType* row = &dmat->data[rid * num_col]; + const ElementType* row = &dmat->data[rid * num_col]; const size_t off = dmat->num_col * tid; const size_t off2 = count_row_ptr[ntree] * tid; for (size_t j = 0; j < num_col; ++j) { @@ -94,11 +95,12 @@ inline void ComputeBranchLoopImpl( } } -template +template inline void ComputeBranchLoopImpl( const treelite::ModelImpl& model, - const treelite::CSRDMatrixImpl* dmat, size_t rbegin, size_t rend, int nthread, - const size_t* count_row_ptr, size_t* counts_tloc, Entry* inst) { + const treelite::CSRDMatrixImpl* dmat, size_t rbegin, size_t rend, int nthread, + const size_t* count_row_ptr, size_t* counts_tloc) { + std::vector> inst(nthread * dmat->num_col, {-1}); const size_t ntree = model.trees.size(); CHECK_LE(rbegin, rend); CHECK_LT(static_cast(rend), std::numeric_limits::max()); @@ -123,26 +125,48 @@ inline void ComputeBranchLoopImpl( } } +template +class ComputeBranchLoopDispatcherWithDenseDMatrix { + public: + template + inline static void Dispatch( + const treelite::ModelImpl& model, + const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread, + const size_t* count_row_ptr, size_t* counts_tloc) { + const auto* dmat_ = static_cast*>(dmat); + CHECK(dmat_) << "Dangling data matrix reference detected"; + ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc); + } +}; + +template +class ComputeBranchLoopDispatcherWithCSRDMatrix { + public: + template + inline static void Dispatch( + const treelite::ModelImpl& model, + const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread, + const size_t* count_row_ptr, size_t* counts_tloc) { + const auto* dmat_ = static_cast*>(dmat); + CHECK(dmat_) << "Dangling data matrix reference detected"; + ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc); + } +}; + template inline void ComputeBranchLoop(const treelite::ModelImpl& model, const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread, const size_t* count_row_ptr, - size_t* counts_tloc, Entry* inst) { - CHECK(dmat->GetElementType() == treelite::InferTypeInfoOf()) - << "DMatrix has a wrong type. 
DMatrix has " - << treelite::TypeInfoToString(dmat->GetElementType()) << " whereas the model expects " - << treelite::TypeInfoToString(treelite::InferTypeInfoOf()); + size_t* counts_tloc) { switch (dmat->GetType()) { case treelite::DMatrixType::kDense: { - const auto* dmat_ = dynamic_cast*>(dmat); - CHECK(dmat_) << "Dangling data matrix reference detected"; - ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc, inst); + treelite::DispatchWithTypeInfo( + dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc); break; } case treelite::DMatrixType::kSparseCSR: { - const auto* dmat_ = dynamic_cast*>(dmat); - CHECK(dmat_) << "Dangling data matrix reference detected"; - ComputeBranchLoopImpl(model, dmat_, rbegin, rend, nthread, count_row_ptr, counts_tloc, inst); + treelite::DispatchWithTypeInfo( + dmat->GetElementType(), model, dmat, rbegin, rend, nthread, count_row_ptr, counts_tloc); break; } default: @@ -178,13 +202,11 @@ AnnotateImpl( const size_t num_row = dmat->GetNumRow(); const size_t num_col = dmat->GetNumCol(); - std::vector> inst(nthread * num_col, {-1}); const size_t pstep = (num_row + 19) / 20; // interval to display progress for (size_t rbegin = 0; rbegin < num_row; rbegin += pstep) { const size_t rend = std::min(rbegin + pstep, num_row); - ComputeBranchLoop(model, dmat, rbegin, rend, nthread, - &count_row_ptr[0], &counts_tloc[0], &inst[0]); + ComputeBranchLoop(model, dmat, rbegin, rend, nthread, &count_row_ptr[0], &counts_tloc[0]); if (verbose > 0) { LOG(INFO) << rend << " of " << num_row << " rows processed"; } @@ -209,10 +231,6 @@ void BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose) { TypeInfo threshold_type = model.GetThresholdType(); model.Dispatch([this, dmat, nthread, verbose, threshold_type](auto& handle) { - CHECK(dmat->GetElementType() == threshold_type) - << "BranchAnnotator: the matrix type must match the threshold type of the model." 
- << "(current matrix type = " << TypeInfoToString(dmat->GetElementType()) - << " vs threshold type = " << TypeInfoToString(threshold_type) << ")"; AnnotateImpl(handle, dmat, nthread, verbose, &this->counts); }); } diff --git a/src/data/data.cc b/src/data/data.cc index d8baa758..1edf0f03 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -250,4 +250,9 @@ CSRDMatrixImpl::GetType() const { return DMatrixType::kSparseCSR; } +template class DenseDMatrixImpl; +template class DenseDMatrixImpl; +template class CSRDMatrixImpl; +template class CSRDMatrixImpl; + } // namespace treelite diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 96166d01..ac98e851 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -44,11 +44,11 @@ struct OutputToken { using PredThreadPool = treelite::predictor::ThreadPool; -template +template inline size_t PredLoop(const treelite::CSRDMatrixImpl* dmat, int num_feature, size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { CHECK_LE(dmat->num_col, static_cast(num_feature)); - std::vector> inst( + std::vector> inst( std::max(dmat->num_col, static_cast(num_feature)), {-1}); CHECK(rbegin < rend && rend <= dmat->num_row); const size_t num_col = dmat->num_col; @@ -70,12 +70,12 @@ inline size_t PredLoop(const treelite::CSRDMatrixImpl* dmat, int nu return total_output_size; } -template +template inline size_t PredLoop(const treelite::DenseDMatrixImpl* dmat, int num_feature, size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { const bool nan_missing = treelite::math::CheckNAN(dmat->missing_value); CHECK_LE(dmat->num_col, static_cast(num_feature)); - std::vector> inst( + std::vector> inst( std::max(dmat->num_col, static_cast(num_feature)), {-1}); CHECK(rbegin < rend && rend <= dmat->num_row); const size_t num_col = dmat->num_col; @@ -101,18 +101,46 @@ inline size_t PredLoop(const treelite::DenseDMatrixImpl* dmat, int return total_output_size; } -template -inline size_t PredLoop(const treelite::DMatrix* dmat, int num_feature, +template +class PredLoopDispatcherWithDenseDMatrix { + public: + template + inline static size_t Dispatch( + const treelite::DMatrix* dmat, ThresholdType test_val, + int num_feature, size_t rbegin, size_t rend, + LeafOutputType* out_pred, PredFunc func) { + const auto* dmat_ = static_cast*>(dmat); + return PredLoop( + dmat_, num_feature, rbegin, rend, out_pred, func); + } +}; + +template +class PredLoopDispatcherWithCSRDMatrix { + public: + template + inline static size_t Dispatch( + const treelite::DMatrix* dmat, ThresholdType test_val, + int num_feature, size_t rbegin, size_t rend, + LeafOutputType* out_pred, PredFunc func) { + const auto* dmat_ = static_cast*>(dmat); + return PredLoop( + dmat_, num_feature, rbegin, rend, out_pred, func); + } +}; + +template +inline size_t PredLoop(const treelite::DMatrix* dmat, ThresholdType test_val, int num_feature, size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { treelite::DMatrixType dmat_type = dmat->GetType(); switch (dmat_type) { case treelite::DMatrixType::kDense: { - const auto* dmat_ = static_cast*>(dmat); - return PredLoop(dmat_, num_feature, rbegin, rend, out_pred, func); + return treelite::DispatchWithTypeInfo( + dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func); } case treelite::DMatrixType::kSparseCSR: { - const auto* dmat_ = static_cast*>(dmat); - return PredLoop(dmat_, num_feature, rbegin, rend, out_pred, func); + return treelite::DispatchWithTypeInfo( + 
dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func); } default: LOG(FATAL) << "Unrecognized data matrix type: " << static_cast(dmat_type); @@ -225,9 +253,6 @@ PredFunctionImpl::PredictBatch( // can be either [num_data] or [num_class]*[num_data]. // Note that size of prediction may be smaller than out_pred (this occurs // when pred_function is set to "max_index"). - CHECK(dmat->GetElementType() == GetThresholdType()) - << "Mismatched data type in the data matrix. Expected: " << TypeInfoToString(GetThresholdType()) - << ", Given: " << TypeInfoToString(dmat->GetElementType()); CHECK(rbegin < rend && rend <= dmat->GetNumRow()); size_t num_row = rend - rbegin; if (num_output_group_ > 1) { // multi-class classification @@ -240,10 +265,8 @@ PredFunctionImpl::PredictBatch( return pred_func(inst, static_cast(pred_margin), &out_pred[rid * num_output_group]); }; - result_size = - PredLoop( - dmat, num_feature_, rbegin, rend, static_cast(out_pred), - pred_func_wrapper); + result_size = PredLoop(dmat, static_cast(0), num_feature_, rbegin, rend, + static_cast(out_pred), pred_func_wrapper); } else { // everything else using PredFunc = LeafOutputType (*)(Entry*, int); auto pred_func = reinterpret_cast(handle_); @@ -254,10 +277,8 @@ PredFunctionImpl::PredictBatch( out_pred[rid] = pred_func(inst, static_cast(pred_margin)); return 1; }; - result_size = - PredLoop( - dmat, num_feature_, rbegin, rend, static_cast(out_pred), - pred_func_wrapper); + result_size = PredLoop(dmat, static_cast(0), num_feature_, rbegin, rend, + static_cast(out_pred), pred_func_wrapper); } return result_size; } diff --git a/tests/python/test_model_builder.py b/tests/python/test_model_builder.py index b5fbe3e0..0159757c 100644 --- a/tests/python/test_model_builder.py +++ b/tests/python/test_model_builder.py @@ -210,7 +210,7 @@ def test_skl_converter_regressor(tmpdir, clazz, toolchain): # pylint: disable=t assert model.num_output_group == 1 assert model.num_tree == clf.n_estimators - dtrain = treelite_runtime.DMatrix(X, dtype='float64') + dtrain = treelite_runtime.DMatrix(X, dtype='float32') annotation_path = os.path.join(tmpdir, 'annotation.json') annotator = treelite.Annotator() annotator.annotate_branch(model=model, dmat=dtrain, verbose=True) From 88080ec68daeac27607f31e495c5d3820815007a Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 01:47:14 -0700 Subject: [PATCH 21/38] Fix lint --- include/treelite/predictor.h | 1 + include/treelite/tree_impl.h | 1 + include/treelite/typeinfo.h | 1 + python/treelite/annotator.py | 2 +- python/treelite/core.py | 3 --- python/treelite/frontend.py | 26 +++++++------------- python/treelite/util.py | 34 +++++++++++++++++++++++++++ src/compiler/native/typeinfo_ctypes.h | 12 ++++++---- src/data/data.cc | 9 +++---- src/optable.cc | 2 +- src/predictor/predictor.cc | 5 ++-- src/typeinfo.cc | 2 +- tests/python/test_basic.py | 1 - 13 files changed, 61 insertions(+), 38 deletions(-) diff --git a/include/treelite/predictor.h b/include/treelite/predictor.h index 52626d76..8aaab455 100644 --- a/include/treelite/predictor.h +++ b/include/treelite/predictor.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace treelite { diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index b88ce6ef..c5fdacd1 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include diff --git a/include/treelite/typeinfo.h b/include/treelite/typeinfo.h 
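A detail of the predictor change above: PredLoop now takes an extra test_val argument of ThresholdType whose value is never consumed; it appears to exist only so that the threshold type can ride along as an ordinary, deducible function argument down to the dispatcher. A small illustration of that idiom, using made-up names (CountNonZero, type_tag) rather than the actual treelite functions:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // T cannot be deduced from the type-erased buffer, so the caller passes a
    // throw-away value of type T purely to fix the template parameter.
    template <typename T>
    std::size_t CountNonZero(const void* data, std::size_t len, T /* type_tag */) {
      const T* ptr = static_cast<const T*>(data);
      std::size_t n = 0;
      for (std::size_t i = 0; i < len; ++i) {
        if (ptr[i] != static_cast<T>(0)) ++n;
      }
      return n;
    }

    int main() {
      std::vector<double> buf{0.0, 1.5, 0.0, 2.0};
      // static_cast<double>(0) is never used as data; it only fixes T = double.
      std::cout << CountNonZero(buf.data(), buf.size(), static_cast<double>(0))
                << std::endl;  // prints 2
      return 0;
    }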
index c10428b6..90cb645e 100644 --- a/include/treelite/typeinfo.h +++ b/include/treelite/typeinfo.h @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace treelite { diff --git a/python/treelite/annotator.py b/python/treelite/annotator.py index 14ec8c2d..53674e3c 100644 --- a/python/treelite/annotator.py +++ b/python/treelite/annotator.py @@ -2,10 +2,10 @@ """branch annotator module""" import ctypes +from treelite_runtime import DMatrix from .util import c_str, TreeliteError from .core import _LIB, _check_call from .frontend import Model -from treelite_runtime import DMatrix class Annotator(): diff --git a/python/treelite/core.py b/python/treelite/core.py index ccd7fde2..120a612f 100644 --- a/python/treelite/core.py +++ b/python/treelite/core.py @@ -5,9 +5,6 @@ import sys import ctypes -import numpy as np -import scipy.sparse - from .util import py_str, _log_callback, TreeliteError from .libpath import find_lib_path, TreeliteLibraryNotFound diff --git a/python/treelite/frontend.py b/python/treelite/frontend.py index 7413c6aa..b7d6164d 100644 --- a/python/treelite/frontend.py +++ b/python/treelite/frontend.py @@ -9,7 +9,7 @@ import numpy as np -from .util import c_str, TreeliteError +from .util import c_str, TreeliteError, type_info_to_ctypes_type, type_info_to_numpy_type from .core import _LIB, c_array, _check_call from .contrib import create_shared, generate_makefile, generate_cmakelists, _toolchain_exist_check @@ -400,27 +400,16 @@ class Value: Parameters ---------- - type : str + dtype : str Initial value of model handle """ - CTYPES_PTR = { - 'uint32': ctypes.c_uint32, - 'float32': ctypes.c_float, - 'float64': ctypes.c_double - } - NUMPY_TYPE = { - 'uint32': np.uint32, - 'float32': np.float32, - 'float64': np.float64 - } - - def __init__(self, init_value, type): - self.type = type + def __init__(self, init_value, dtype): + self.type = dtype self.handle = ctypes.c_void_p() - val = np.array([init_value], dtype=self.NUMPY_TYPE[type], order='C') + val = np.array([init_value], dtype=type_info_to_numpy_type(dtype), order='C') _check_call(_LIB.TreeliteTreeBuilderCreateValue( - val.ctypes.data_as(ctypes.POINTER(self.CTYPES_PTR[type])), - c_str(type), + val.ctypes.data_as(ctypes.POINTER(type_info_to_ctypes_type(dtype))), + c_str(dtype), ctypes.byref(self.handle) )) @@ -692,6 +681,7 @@ def __repr__(self): return '\n' \ .format(len(self.nodes)) + # pylint: disable=R0913 def __init__(self, num_feature, num_output_group=1, random_forest=False, threshold_type='float32', leaf_output_type='float32', **kwargs): if not isinstance(num_feature, int): diff --git a/python/treelite/util.py b/python/treelite/util.py index 70a56282..02a5046e 100644 --- a/python/treelite/util.py +++ b/python/treelite/util.py @@ -5,6 +5,25 @@ import inspect import ctypes import time +import numpy as np + +_CTYPES_TYPE_TABLE = { + 'uint32': ctypes.c_uint32, + 'float32': ctypes.c_float, + 'float64': ctypes.c_double +} + +_NUMPY_TYPE_TABLE = { + 'uint32': np.uint32, + 'float32': np.float32, + 'float64': np.float64 +} + +_NUMPY_TYPE_TABLE_INV = { + np.uint32: 'unit32', + np.float32: 'float32', + np.float64: 'float64' +} class TreeliteError(Exception): @@ -52,3 +71,18 @@ def has_sklearn(): return True except ImportError: return False + + +def type_info_to_ctypes_type(type_info): + """Obtain ctypes type corresponding to a given TypeInfo""" + return _CTYPES_TYPE_TABLE[type_info] + + +def type_info_to_numpy_type(type_info): + """Obtain ctypes type corresponding to a given TypeInfo""" + return 
_NUMPY_TYPE_TABLE[type_info] + + +def numpy_type_to_type_info(type_info): + """Obtain TypeInfo corresponding to a given NumPy type""" + return _NUMPY_TYPE_TABLE_INV[type_info] diff --git a/src/compiler/native/typeinfo_ctypes.h b/src/compiler/native/typeinfo_ctypes.h index ef41f44c..ce965b91 100644 --- a/src/compiler/native/typeinfo_ctypes.h +++ b/src/compiler/native/typeinfo_ctypes.h @@ -1,12 +1,16 @@ -// -// Created by Philip Hyunsu Cho on 8/31/20. -// +/*! + * Copyright (c) 2020 by Contributors + * \file typeinfo_ctypes.h + * \author Hyunsu Cho + * \brief Look up C symbols corresponding to TypeInfo + */ + #ifndef TREELITE_COMPILER_NATIVE_TYPEINFO_CTYPES_H_ #define TREELITE_COMPILER_NATIVE_TYPEINFO_CTYPES_H_ -#include #include +#include namespace treelite { namespace compiler { diff --git a/src/data/data.cc b/src/data/data.cc index 1edf0f03..856ac833 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -96,8 +96,7 @@ std::unique_ptr DenseDMatrix::Create( std::vector data, ElementType missing_value, size_t num_row, size_t num_col) { std::unique_ptr matrix = std::make_unique>( - std::move(data), missing_value, num_row, num_col - ); + std::move(data), missing_value, num_row, num_col); matrix->element_type_ = InferTypeInfoOf(); return matrix; } @@ -168,8 +167,7 @@ std::unique_ptr CSRDMatrix::Create(std::vector data, std::vector col_ind, std::vector row_ptr, size_t num_row, size_t num_col) { std::unique_ptr matrix = std::make_unique>( - std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col - ); + std::move(data), std::move(col_ind), std::move(row_ptr), num_row, num_col); matrix->element_type_ = InferTypeInfoOf(); return matrix; } @@ -185,8 +183,7 @@ CSRDMatrix::Create(const void* data, const uint32_t* col_ind, std::vector(col_ind, col_ind + num_elem), std::vector(row_ptr, row_ptr + num_row + 1), num_row, - num_col - ); + num_col); } std::unique_ptr diff --git a/src/optable.cc b/src/optable.cc index af4e4581..44b6d5e6 100644 --- a/src/optable.cc +++ b/src/optable.cc @@ -5,9 +5,9 @@ * \brief Conversion tables to obtain Operator from string */ +#include #include #include -#include namespace treelite { diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index ac98e851..9fa783d5 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -33,7 +33,7 @@ struct InputToken { const treelite::DMatrix* dmat; // input data bool pred_margin; // whether to store raw margin or transformed scores const treelite::predictor::PredFunction* pred_func_; - size_t rbegin, rend; // range of instances (rows) assigned to each worker + size_t rbegin, rend; // range of instances (rows) assigned to each worker void* out_pred; // buffer to store output from each worker }; @@ -146,7 +146,6 @@ inline size_t PredLoop(const treelite::DMatrix* dmat, ThresholdType test_val, in LOG(FATAL) << "Unrecognized data matrix type: " << static_cast(dmat_type); return 0; } - } } // anonymous namespace @@ -260,7 +259,7 @@ PredFunctionImpl::PredictBatch( auto pred_func = reinterpret_cast(handle_); CHECK(pred_func) << "The predict_multiclass() function has incorrect signature."; auto pred_func_wrapper - = [pred_func, num_output_group=num_output_group_, pred_margin] + = [pred_func, num_output_group = num_output_group_, pred_margin] (int64_t rid, Entry* inst, LeafOutputType* out_pred) -> size_t { return pred_func(inst, static_cast(pred_margin), &out_pred[rid * num_output_group]); diff --git a/src/typeinfo.cc b/src/typeinfo.cc index a58b2904..048c0e79 100644 --- a/src/typeinfo.cc +++ 
b/src/typeinfo.cc @@ -7,9 +7,9 @@ // Do not include other Treelite headers here, to minimize cross-header dependencies +#include #include #include -#include namespace treelite { diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index d9151b43..967c6f71 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -7,7 +7,6 @@ from zipfile import ZipFile import pytest -import numpy as np from scipy.sparse import csr_matrix import treelite import treelite_runtime From 7676edf3a67944b54324cc8bd9ede5f4f541d2ab Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 02:01:27 -0700 Subject: [PATCH 22/38] Expose type info as property of Predictor --- runtime/python/treelite_runtime/predictor.py | 10 ++++++++++ tests/python/test_model_query.py | 16 ++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/runtime/python/treelite_runtime/predictor.py b/runtime/python/treelite_runtime/predictor.py index 02da0143..5e901b3f 100644 --- a/runtime/python/treelite_runtime/predictor.py +++ b/runtime/python/treelite_runtime/predictor.py @@ -218,6 +218,16 @@ def sigmoid_alpha(self): """Query sigmoid alpha of the model""" return self.sigmoid_alpha_ + @property + def threshold_type(self): + """Query threshold type of the model""" + return self.threshold_type_ + + @property + def leaf_output_type(self): + """Query threshold type of the model""" + return self.leaf_output_type_ + class DMatrix: """Data matrix used in Treelite. diff --git a/tests/python/test_model_query.py b/tests/python/test_model_query.py index bf040013..07a05954 100644 --- a/tests/python/test_model_query.py +++ b/tests/python/test_model_query.py @@ -12,13 +12,15 @@ from .util import os_platform, os_compatible_toolchains ModelFact = collections.namedtuple( - 'ModelFact', 'num_tree num_feature num_output_group pred_transform global_bias sigmoid_alpha') + 'ModelFact', + 'num_tree num_feature num_output_group pred_transform global_bias sigmoid_alpha ' + 'threshold_type leaf_output_type') _model_facts = { - 'mushroom': ModelFact(2, 127, 1, 'sigmoid', 0.0, 1.0), - 'dermatology': ModelFact(60, 33, 6, 'softmax', 0.5, 1.0), - 'letor': ModelFact(713, 47, 1, 'identity', 0.5, 1.0), - 'toy_categorical': ModelFact(30, 2, 1, 'identity', 0.0, 1.0), - 'sparse_categorical': ModelFact(1, 5057, 1, 'sigmoid', 0.0, 1.0) + 'mushroom': ModelFact(2, 127, 1, 'sigmoid', 0.0, 1.0, 'float32', 'float32'), + 'dermatology': ModelFact(60, 33, 6, 'softmax', 0.5, 1.0, 'float32', 'float32'), + 'letor': ModelFact(713, 47, 1, 'identity', 0.5, 1.0, 'float32', 'float32'), + 'toy_categorical': ModelFact(30, 2, 1, 'identity', 0.0, 1.0, 'float64', 'float64'), + 'sparse_categorical': ModelFact(1, 5057, 1, 'sigmoid', 0.0, 1.0, 'float64', 'float64') } @@ -46,3 +48,5 @@ def test_model_query(tmpdir, dataset): assert predictor.pred_transform == _model_facts[dataset].pred_transform assert predictor.global_bias == _model_facts[dataset].global_bias assert predictor.sigmoid_alpha == _model_facts[dataset].sigmoid_alpha + assert predictor.threshold_type == _model_facts[dataset].threshold_type + assert predictor.leaf_output_type == _model_facts[dataset].leaf_output_type From 30ef99abba633cf63475847469aeef5ebd379563 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 02:15:58 -0700 Subject: [PATCH 23/38] Fix build with CentOS 6 + GCC 5 --- include/treelite/typeinfo.h | 1 + src/compiler/ast_native.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/include/treelite/typeinfo.h b/include/treelite/typeinfo.h 
index 90cb645e..7b1337ea 100644 --- a/include/treelite/typeinfo.h +++ b/include/treelite/typeinfo.h @@ -9,6 +9,7 @@ #define TREELITE_TYPEINFO_H_ #include +#include #include #include #include diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index e74fc64c..256f63f1 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -124,7 +124,7 @@ class ASTNativeCompiler : public Compiler { CompiledModel Compile(const Model& model) override { this->pred_tranform_func_ = PredTransformFunction("native", model); return model.Dispatch([this](const auto& model_handle) { - return CompileImpl(model_handle); + return this->CompileImpl(model_handle); }); } From d0659f249897e9dd55657e3c93970253ac55e727 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 02:17:12 -0700 Subject: [PATCH 24/38] Move data/data.cc -> data.cc --- src/CMakeLists.txt | 2 +- src/{data => }/data.cc | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/{data => }/data.cc (100%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index dd260151..6be940fc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -117,7 +117,7 @@ target_sources(objtreelite_common c_api/c_api_common.cc c_api/c_api_error.cc c_api/c_api_error.h - data/data.cc + data.cc logging.cc typeinfo.cc ${PROJECT_SOURCE_DIR}/include/treelite/c_api_common.h diff --git a/src/data/data.cc b/src/data.cc similarity index 100% rename from src/data/data.cc rename to src/data.cc From d2d4e3fdc82bb62bd13aaa786e7d057821b3d21e Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 20:16:55 +0000 Subject: [PATCH 25/38] Fix build on MSVC --- include/treelite/compiler.h | 2 +- include/treelite/frontend.h | 2 +- src/c_api/c_api.cc | 1 + src/compiler/ast/builder.h | 2 +- src/predictor/predictor.cc | 2 +- 5 files changed, 5 insertions(+), 4 deletions(-) diff --git a/include/treelite/compiler.h b/include/treelite/compiler.h index a5dae46f..c4157746 100644 --- a/include/treelite/compiler.h +++ b/include/treelite/compiler.h @@ -17,7 +17,7 @@ namespace treelite { -struct Model; // forward declaration +class Model; // forward declaration namespace compiler { diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index 50326c79..822cbe87 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -14,7 +14,7 @@ namespace treelite { -struct Model; // forward declaration +class Model; // forward declaration namespace frontend { diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 1b2cc36b..4dfe98df 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include diff --git a/src/compiler/ast/builder.h b/src/compiler/ast/builder.h index 7552c6bd..f5e5e52c 100644 --- a/src/compiler/ast/builder.h +++ b/src/compiler/ast/builder.h @@ -70,7 +70,7 @@ class ASTBuilder { } private: - friend bool treelite::compiler::fold_code<>(ASTNode*, CodeFoldingContext*, ASTBuilder*); + friend bool treelite::compiler::fold_code<>(ASTNode*, CodeFoldingContext*, ASTBuilder*); template NodeType* AddNode(ASTNode* parent, Args&& ...args) { diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index 9fa783d5..dbb54910 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -168,7 +168,7 @@ SharedLibrary::~SharedLibrary() { void SharedLibrary::Load(const char* libpath) { #ifdef _WIN32 - HMODULE handle = LoadLibraryA(name); + HMODULE handle = LoadLibraryA(libpath); #else void* handle = dlopen(libpath, RTLD_LAZY 
| RTLD_LOCAL); #endif From c1119ffbfec5f0bc6d5a9f7cea3427005731f433 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 20:24:26 +0000 Subject: [PATCH 26/38] Define _CRT_SECURE_NO_WARNINGS to remove unneeded warnings in MSVC --- cmake/ExternalLibs.cmake | 6 ++++++ src/CMakeLists.txt | 1 + tests/cpp/CMakeLists.txt | 5 +++++ 3 files changed, 12 insertions(+) diff --git a/cmake/ExternalLibs.cmake b/cmake/ExternalLibs.cmake index b190f784..37df92ba 100644 --- a/cmake/ExternalLibs.cmake +++ b/cmake/ExternalLibs.cmake @@ -6,6 +6,12 @@ FetchContent_Declare( GIT_TAG v0.4 ) FetchContent_MakeAvailable(dmlccore) +target_compile_options(dmlc PRIVATE + -D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE) +if (TARGET dmlc_unit_tests) + target_compile_options(dmlc_unit_tests PRIVATE + -D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE) +endif (TARGET dmlc_unit_tests) FetchContent_Declare( fmtlib diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6be940fc..c307eb95 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,6 +31,7 @@ foreach(lib objtreelite objtreelite_runtime objtreelite_common) if(MSVC) target_compile_options(${lib} PRIVATE /MP) target_compile_definitions(${lib} PRIVATE -DNOMINMAX) + target_compile_options(${lib} PRIVATE /utf-8 -D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE) else() target_compile_options(${lib} PRIVATE -funroll-loops) endif() diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 5a8d5ab4..0dfe53b4 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -7,6 +7,11 @@ target_link_libraries(treelite_cpp_test PRIVATE objtreelite objtreelite_runtime objtreelite_common GTest::GTest) set_output_directory(treelite_cpp_test ${PROJECT_BINARY_DIR}) +if(MSVC) + target_compile_options(treelite_cpp_test PRIVATE + /utf-8 -D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE) +endif() + if(TEST_COVERAGE) if(MSVC) message(FATAL_ERROR "Test coverage not available on Windows") From 0056ebe7e8ff1480927ac65035924bb7c1648b05 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 20:55:56 +0000 Subject: [PATCH 27/38] Fix MSVC warnings --- include/treelite/tree_impl.h | 2 +- src/CMakeLists.txt | 1 - src/annotator.cc | 2 +- src/compiler/ast/build.cc | 2 +- src/compiler/ast/builder.cc | 21 --------------------- src/compiler/ast/quantize.cc | 1 - src/compiler/ast_native.cc | 6 +++--- src/compiler/common/format_util.h | 2 +- src/frontend/builder.cc | 18 ++++++++++++------ src/frontend/lightgbm.cc | 3 +-- src/predictor/predictor.cc | 4 ++-- src/predictor/thread_pool/thread_pool.h | 5 +++-- 12 files changed, 25 insertions(+), 42 deletions(-) delete mode 100644 src/compiler/ast/builder.cc diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index c5fdacd1..3cf17669 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -473,7 +473,7 @@ Tree::Init() { left_categories_offset_.Resize(2, 0); nodes_.Resize(1); nodes_[0].Init(); - SetLeaf(0, 0.0f); + SetLeaf(0, static_cast(0)); } template diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c307eb95..b91342e5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -60,7 +60,6 @@ target_sources(objtreelite compiler/ast/ast.h compiler/ast/build.cc compiler/ast/builder.h - compiler/ast/builder.cc compiler/ast/dump.cc compiler/ast/fold_code.cc compiler/ast/is_categorical_array.cc diff --git a/src/annotator.cc b/src/annotator.cc index 1c7fa7ba..c90953a4 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -68,7 +68,7 @@ inline 
void ComputeBranchLoopImpl( CHECK_LE(rbegin, rend); CHECK_LT(static_cast(rend), std::numeric_limits::max()); const size_t num_col = dmat->num_col; - const ThresholdType missing_value = dmat->missing_value; + const ElementType missing_value = dmat->missing_value; const bool nan_missing = treelite::math::CheckNAN(missing_value); const auto rbegin_i = static_cast(rbegin); const auto rend_i = static_cast(rend); diff --git a/src/compiler/ast/build.cc b/src/compiler/ast/build.cc index dc3808c7..0f007740 100644 --- a/src/compiler/ast/build.cc +++ b/src/compiler/ast/build.cc @@ -22,7 +22,7 @@ ASTBuilder::BuildAST( this->main_node = AddNode(nullptr, model.param.global_bias, model.random_forest_flag, - model.trees.size(), + static_cast(model.trees.size()), model.num_feature); ASTNode* ac = AddNode(this->main_node); this->main_node->children.push_back(ac); diff --git a/src/compiler/ast/builder.cc b/src/compiler/ast/builder.cc deleted file mode 100644 index 45e19374..00000000 --- a/src/compiler/ast/builder.cc +++ /dev/null @@ -1,21 +0,0 @@ -/*! - * Copyright (c) 2017-2020 by Contributors - * \file builder.cc - * \brief Explicit template specializations for the ASTBuilder class - * \author Hyunsu Cho - */ - -#include "./builder.h" - -namespace treelite { -namespace compiler { - -// Explicit template specializations -// (https://docs.microsoft.com/en-us/cpp/cpp/source-code-organization-cpp-templates) -template class ASTBuilder; -template class ASTBuilder; -template class ASTBuilder; -template class ASTBuilder; - -} // namespace compiler -} // namespace treelite diff --git a/src/compiler/ast/quantize.cc b/src/compiler/ast/quantize.cc index cb4c9538..7be0c4ee 100644 --- a/src/compiler/ast/quantize.cc +++ b/src/compiler/ast/quantize.cc @@ -17,7 +17,6 @@ template static void scan_thresholds(ASTNode* node, std::vector>* cut_pts) { NumericalConditionNode* num_cond; - CategoricalConditionNode* cat_cond; if ( (num_cond = dynamic_cast*>(node)) ) { CHECK(!num_cond->quantized) << "should not be already quantized"; const ThresholdType threshold = num_cond->threshold.float_val; diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index 256f63f1..fdd29e05 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -306,8 +306,8 @@ class ASTNativeCompiler : public Compiler { condition_with_na_check = ExtractCategoricalCondition(t2); } if (node->children[0]->data_count && node->children[1]->data_count) { - const int left_freq = node->children[0]->data_count.value(); - const int right_freq = node->children[1]->data_count.value(); + const size_t left_freq = node->children[0]->data_count.value(); + const size_t right_freq = node->children[1]->data_count.value(); condition_with_na_check = fmt::format(" {keyword}( {condition} ) ", "keyword"_a = ((left_freq > right_freq) ? 
"LIKELY" : "UNLIKELY"), @@ -333,7 +333,7 @@ class ASTNativeCompiler : public Compiler { template void HandleTUNode(const TranslationUnitNode* node, const std::string& dest, - int indent) { + size_t indent) { const int unit_id = node->unit_id; const std::string new_file = fmt::format("tu{}.c", unit_id); const std::string leaf_output_type diff --git a/src/compiler/common/format_util.h b/src/compiler/common/format_util.h index 45b12b61..c3270245 100644 --- a/src/compiler/common/format_util.h +++ b/src/compiler/common/format_util.h @@ -66,7 +66,7 @@ class ArrayFormatter { */ ArrayFormatter(size_t text_width, size_t indent, char delimiter = ',') : oss_(), text_width_(text_width), indent_(indent), delimiter_(delimiter), - default_precision_(oss_.precision()), line_length_(indent), + default_precision_(static_cast(oss_.precision())), line_length_(indent), is_empty_(true) {} /*! diff --git a/src/frontend/builder.cc b/src/frontend/builder.cc index 2d484189..c0b75b87 100644 --- a/src/frontend/builder.cc +++ b/src/frontend/builder.cc @@ -465,9 +465,12 @@ ModelBuilderImpl::CommitModelImpl(ModelImpl* out_ CHECK(node->left_child->parent == node) << "CommitModel: left child has wrong parent"; CHECK(node->right_child->parent == node) << "CommitModel: right child has wrong parent"; tree.AddChilds(nid); - node->threshold.Dispatch([&tree, nid, node](const auto& threshold) { - tree.SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, node->op); - }); + CHECK(node->threshold.GetValueType() == InferTypeInfoOf()) + << "CommitModel: The specified threshold has incorrect type. Expected: " + << TypeInfoToString(InferTypeInfoOf()) + << " Given: " << TypeInfoToString(node->threshold.GetValueType()); + ThresholdType threshold = node->threshold.Get(); + tree.SetNumericalSplit(nid, node->feature_id, threshold, node->default_left, node->op); Q.push({node->left_child, tree.LeftChild(nid)}); Q.push({node->right_child, tree.RightChild(nid)}); } else if (node->status == NodeDraft::Status::kCategoricalTest) { @@ -497,9 +500,12 @@ ModelBuilderImpl::CommitModelImpl(ModelImpl* out_ << "CommitModel: Inconsistent use of leaf vector: if one leaf node does not use a leaf " << "vector, *no other* leaf node can use a leaf vector"; flag_leaf_vector = 0; // now no leaf can use leaf vector - node->leaf_value.Dispatch([&tree, nid](const auto& leaf_value) { - tree.SetLeaf(nid, leaf_value); - }); + CHECK(node->leaf_value.GetValueType() == InferTypeInfoOf()) + << "CommitModel: The specified leaf value has incorrect type. 
Expected: " + << TypeInfoToString(InferTypeInfoOf()) + << " Given: " << TypeInfoToString(node->leaf_value.GetValueType()); + LeafOutputType leaf_value = node->leaf_value.Get(); + tree.SetLeaf(nid, leaf_value); } } } diff --git a/src/frontend/lightgbm.cc b/src/frontend/lightgbm.cc index 145a8a3b..06cbbe89 100644 --- a/src/frontend/lightgbm.cc +++ b/src/frontend/lightgbm.cc @@ -350,8 +350,7 @@ inline std::unique_ptr ParseStream(dmlc::Stream* fi) { CHECK(it != dict.end() && !it->second.empty()) << "Ill-formed LightGBM model file: need cat_threshold"; tree.cat_threshold - = TextToArray(it->second, - tree.cat_boundaries.back()); + = TextToArray(it->second, static_cast(tree.cat_boundaries.back())); } it = dict.find("split_feature"); diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index dbb54910..e4bbe160 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -60,7 +60,7 @@ inline size_t PredLoop(const treelite::CSRDMatrixImpl* dmat, int nu const size_t ibegin = row_ptr[rid]; const size_t iend = row_ptr[rid + 1]; for (size_t i = ibegin; i < iend; ++i) { - inst[col_ind[i]].fvalue = data[i]; + inst[col_ind[i]].fvalue = static_cast(data[i]); } total_output_size += func(rid, &inst[0], out_pred); for (size_t i = ibegin; i < iend; ++i) { @@ -90,7 +90,7 @@ inline size_t PredLoop(const treelite::DenseDMatrixImpl* dmat, int CHECK(nan_missing) << "The missing_value argument must be set to NaN if there is any NaN in the matrix."; } else if (nan_missing || row[j] != missing_value) { - inst[j].fvalue = row[j]; + inst[j].fvalue = static_cast(row[j]); } } total_output_size += func(rid, &inst[0], out_pred); diff --git a/src/predictor/thread_pool/thread_pool.h b/src/predictor/thread_pool/thread_pool.h index cf56b8aa..16f07c1e 100644 --- a/src/predictor/thread_pool/thread_pool.h +++ b/src/predictor/thread_pool/thread_pool.h @@ -29,7 +29,8 @@ class ThreadPool { ThreadPool(int num_worker, const TaskContext* context, TaskFunc task) : num_worker_(num_worker), context_(context), task_(task) { - CHECK(num_worker_ >= 0 && num_worker_ < std::thread::hardware_concurrency()) + CHECK(num_worker_ >= 0 + && static_cast(num_worker_) < std::thread::hardware_concurrency()) << "Number of worker threads must be between 0 and " << (std::thread::hardware_concurrency() - 1); for (int i = 0; i < num_worker_; ++i) { @@ -78,7 +79,7 @@ class ThreadPool { SetThreadAffinityMask(GetCurrentThread(), 0x1); for (int i = 0; i < num_worker_; ++i) { const int core_id = i + 1; - SetThreadAffinityMask(thread_[i].native_handle(), (1 << core_id)); + SetThreadAffinityMask(thread_[i].native_handle(), (1ULL << core_id)); } #elif defined(__APPLE__) && defined(__MACH__) #include From 40225be7610058fa3d0f8103514d98553394688a Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 21:30:35 +0000 Subject: [PATCH 28/38] In CMake pkg test, specify x64 for MSVC --- tests/python/test_basic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 967c6f71..12481964 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -105,7 +105,8 @@ def test_srcpkg_cmake(tmpdir, dataset): # pylint: disable=R0914 build_dir = os.path.join(tmpdir, dataset_db[dataset].libname, 'build') os.mkdir(build_dir) nproc = os.cpu_count() - subprocess.check_call(['cmake', '..'], cwd=build_dir) + win_opts = ['-A', 'x64'] if os_platform() == 'windows' else [] + subprocess.check_call(['cmake', '..'] + win_opts, cwd=build_dir) 
subprocess.check_call(['cmake', '--build', '.', '--config', 'Release', '--parallel', str(nproc)], cwd=build_dir) From 53bdf642be574889e83578b84d74eab8d529281f Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 14:34:18 -0700 Subject: [PATCH 29/38] Fix lint --- src/compiler/ast/builder.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler/ast/builder.h b/src/compiler/ast/builder.h index f5e5e52c..8531c55d 100644 --- a/src/compiler/ast/builder.h +++ b/src/compiler/ast/builder.h @@ -70,7 +70,8 @@ class ASTBuilder { } private: - friend bool treelite::compiler::fold_code<>(ASTNode*, CodeFoldingContext*, ASTBuilder*); + friend bool treelite::compiler::fold_code<>(ASTNode*, CodeFoldingContext*, + ASTBuilder*); template NodeType* AddNode(ASTNode* parent, Args&& ...args) { From 01a6a9c899d99202c925f7847a1f8a0c48ab442a Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 21:03:20 -0700 Subject: [PATCH 30/38] Fix Java/Scala runtime --- {src/c_api => include/treelite}/c_api_error.h | 0 runtime/java/CMakeLists.txt | 2 +- runtime/java/treelite4j/pom.xml | 5 + .../java/ml/dmlc/treelite4j/java/DMatrix.java | 136 ++++++ ...{BatchBuilder.java => DMatrixBuilder.java} | 36 +- .../ml/dmlc/treelite4j/java/DenseBatch.java | 62 --- .../ml/dmlc/treelite4j/java/Predictor.java | 237 +++++----- .../ml/dmlc/treelite4j/java/SparseBatch.java | 67 --- .../ml/dmlc/treelite4j/java/TreeliteJNI.java | 79 ++-- .../scala/ml/dmlc/treelite4j/DataPoint.scala | 7 +- .../ml/dmlc/treelite4j/DataPointFloat64.scala | 14 + .../ml/dmlc/treelite4j/scala/Predictor.scala | 25 +- .../scala/spark/TreeLiteModel.scala | 16 +- .../java/treelite4j/src/native/treelite4j.cpp | 436 ++++++++++-------- .../java/treelite4j/src/native/treelite4j.h | 101 ++-- .../ml/dmlc/treelite4j/java/DMatrixTest.java | 87 ++++ .../dmlc/treelite4j/java/DenseBatchTest.java | 32 -- .../dmlc/treelite4j/java/PredictorTest.java | 66 +-- .../dmlc/treelite4j/java/SparseBatchTest.java | 74 --- .../resources/mushroom_example/mushroom.c | 8 + .../dmlc/treelite4j/scala/PredictorTest.scala | 38 +- .../treelite4j/scala/TreeLiteModelTest.scala | 4 +- src/CMakeLists.txt | 2 +- src/c_api/c_api.cc | 2 +- src/c_api/c_api_common.cc | 2 +- src/c_api/c_api_error.cc | 2 +- src/c_api/c_api_runtime.cc | 2 +- 27 files changed, 782 insertions(+), 760 deletions(-) rename {src/c_api => include/treelite}/c_api_error.h (100%) create mode 100644 runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrix.java rename runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/{BatchBuilder.java => DMatrixBuilder.java} (79%) delete mode 100644 runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DenseBatch.java delete mode 100644 runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/SparseBatch.java create mode 100644 runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala create mode 100644 runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DMatrixTest.java delete mode 100644 runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DenseBatchTest.java delete mode 100644 runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/SparseBatchTest.java diff --git a/src/c_api/c_api_error.h b/include/treelite/c_api_error.h similarity index 100% rename from src/c_api/c_api_error.h rename to include/treelite/c_api_error.h diff --git a/runtime/java/CMakeLists.txt b/runtime/java/CMakeLists.txt index bb3d6dbd..f656ba19 100644 --- a/runtime/java/CMakeLists.txt +++ 
b/runtime/java/CMakeLists.txt @@ -8,5 +8,5 @@ target_include_directories(treelite4j PUBLIC ${JNI_INCLUDE_DIRS}) set_target_properties(treelite4j PROPERTIES POSITION_INDEPENDENT_CODE ON - CXX_STANDARD 11 + CXX_STANDARD 14 CXX_STANDARD_REQUIRED ON) diff --git a/runtime/java/treelite4j/pom.xml b/runtime/java/treelite4j/pom.xml index f0bc79c0..fd799262 100644 --- a/runtime/java/treelite4j/pom.xml +++ b/runtime/java/treelite4j/pom.xml @@ -83,6 +83,11 @@ ${spark.version} provided + + org.nd4j + nd4j-native-platform + 1.0.0-beta7 + diff --git a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrix.java b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrix.java new file mode 100644 index 00000000..4683fa68 --- /dev/null +++ b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrix.java @@ -0,0 +1,136 @@ +package ml.dmlc.treelite4j.java; + +import org.apache.commons.lang.ArrayUtils; + +import java.util.List; + +/** + * An opaque data matrix class. The actual object is stored in the C++ object handle. + * @author Hyunsu Cho + */ +public class DMatrix { + private long num_row, num_col, num_elem; // dimensions of the data matrix + private long handle; // handle to C++ DMatrix object + + /** + * Create a data matrix representing a 2D sparse matrix + * @param data nonzero (non-missing) entries, float32 type + * @param col_ind corresponding column indices, should be of same length as ``data`` + * @param row_ptr offsets to define each instance, should be of length ``[num_row]+1`` + * @param num_row number of rows (data points) in the matrix + * @param num_col number of columns (features) in the matrix + * @throws TreeliteError error during matrix construction + */ + public DMatrix(float[] data, int[] col_ind, long[] row_ptr, long num_row, long num_col) + throws TreeliteError { + long[] out = new long[1]; + TreeliteJNI.checkCall(TreeliteJNI.TreeliteDMatrixCreateFromCSRWithFloat32In( + data, col_ind, row_ptr, num_row, num_col, out)); + this.handle = out[0]; + setDims(); + } + + public DMatrix(double[] data, int[] col_ind, long[] row_ptr, long num_row, long num_col) + throws TreeliteError { + long[] out = new long[1]; + TreeliteJNI.checkCall(TreeliteJNI.TreeliteDMatrixCreateFromCSRWithFloat64In( + data, col_ind, row_ptr, num_row, num_col, out)); + this.handle = out[0]; + setDims(); + } + + /** + * Create a data matrix representing a 2D dense matrix + * @param data array of entries, should be of length ``[num_row]*[num_col]`` and of float32 type + * @param missing_value floating-point value representing a missing value; + * usually set of ``Float.NaN``. + * @param num_row number of rows (data instances) in the matrix + * @param num_col number of columns (features) in the matrix + * @throws TreeliteError error during matrix construction + */ + public DMatrix(float[] data, float missing_value, long num_row, long num_col) + throws TreeliteError { + long[] out = new long[1]; + TreeliteJNI.checkCall(TreeliteJNI.TreeliteDMatrixCreateFromMatWithFloat32In( + data, num_row, num_col, missing_value, out)); + this.handle = out[0]; + setDims(); + } + + /** + * Create a data matrix representing a 2D dense matrix (float64 type) + * @param data array of entries, should be of length ``[num_row]*[num_col]`` and of float64 type + * @param missing_value floating-point value representing a missing value; + * usually set of ``Double.NaN``. 
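The new DMatrix constructors above accept a sparse matrix in CSR form: data holds the non-missing values row by row, col_ind their column indices, and row_ptr is an offset array of length num_row + 1 delimiting each row. A standalone C++ sketch (hypothetical BuildCSR helper, not part of the patch) showing how a row-major dense matrix with NaN as the missing value maps onto these three arrays:

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct CSR {
      std::vector<float> data;             // non-missing values, row by row
      std::vector<std::uint32_t> col_ind;  // column index of each stored value
      std::vector<std::size_t> row_ptr;    // row i spans [row_ptr[i], row_ptr[i+1])
    };

    // Convert a row-major dense matrix (NaN = missing) into CSR arrays.
    CSR BuildCSR(const std::vector<float>& dense, std::size_t num_row, std::size_t num_col) {
      CSR csr;
      csr.row_ptr.push_back(0);
      for (std::size_t i = 0; i < num_row; ++i) {
        for (std::size_t j = 0; j < num_col; ++j) {
          const float v = dense[i * num_col + j];
          if (!std::isnan(v)) {
            csr.data.push_back(v);
            csr.col_ind.push_back(static_cast<std::uint32_t>(j));
          }
        }
        csr.row_ptr.push_back(csr.data.size());
      }
      return csr;
    }

    int main() {
      const float nan = std::nanf("");
      // 2 x 3 matrix: [[1, NaN, 2], [NaN, NaN, 3]]
      std::vector<float> dense{1.0f, nan, 2.0f, nan, nan, 3.0f};
      CSR csr = BuildCSR(dense, 2, 3);
      std::cout << "nnz = " << csr.data.size() << std::endl;  // nnz = 3
      return 0;
    }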
+ * @param num_row number of rows (data instances) in the matrix + * @param num_col number of columns (features) in the matrix + * @throws TreeliteError error during matrix construction + */ + public DMatrix(double[] data, double missing_value, long num_row, long num_col) + throws TreeliteError { + long[] out = new long[1]; + TreeliteJNI.checkCall(TreeliteJNI.TreeliteDMatrixCreateFromMatWithFloat64In( + data, num_row, num_col, missing_value, out)); + this.handle = out[0]; + setDims(); + } + + private void setDims() throws TreeliteError { + long[] out_num_row = new long[1]; + long[] out_num_col = new long[1]; + long[] out_num_elem = new long[1]; + TreeliteJNI.checkCall(TreeliteJNI.TreeliteDMatrixGetDimension( + this.handle, out_num_row, out_num_col, out_num_elem)); + this.num_row = out_num_row[0]; + this.num_col = out_num_col[0]; + this.num_elem = out_num_elem[0]; + } + + /** + * Get the underlying native handle + * @return Integer representing memory address + */ + public long getHandle() { + return this.handle; + } + + /** + * Get the number of rows in the matrix + * @return Number of rows in the matrix + */ + public long getNumRow() { + return this.num_row; + } + + /** + * Get the number of columns in the matrix + * @return Number of columns in the matrix + */ + public long getNumCol() { + return this.num_col; + } + + /** + * Get the number of elements in the matrix + * @return Number of elements in the matrix + */ + public long getNumElements() { + return this.num_elem; + } + + @Override + protected void finalize() throws Throwable { + super.finalize(); + dispose(); + } + + /** + * Destructor, to be called when the object is garbage collected + */ + public synchronized void dispose() { + if (this.handle != 0L) { + TreeliteJNI.TreeliteDMatrixFree(this.handle); + this.handle = 0; + } + } +} diff --git a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/BatchBuilder.java b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrixBuilder.java similarity index 79% rename from runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/BatchBuilder.java rename to runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrixBuilder.java index d74efab3..b9460772 100644 --- a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/BatchBuilder.java +++ b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrixBuilder.java @@ -13,20 +13,20 @@ import java.util.List; /** - * Collection of utility functions to create batch objects + * Collection of utility functions to create data matrices * * @author Hyunsu Cho */ -public class BatchBuilder { +public class DMatrixBuilder { /** - * Assemble a sparse batch from a list of data points + * Build a sparse (CSR layout) matrix from a list of data points (float32) * * @param dIter Iterator of data points - * @return Created sparse batch + * @return Created sparse data matrix * @throws TreeliteError Treelite error * @throws IOException IO error */ - public static SparseBatch CreateSparseBatch(Iterator dIter) + public static DMatrix createSparseCSRDMatrixFloat32(Iterator dIter) throws TreeliteError, IOException { ArrayList data = new ArrayList<>(); ArrayList col_ind = new ArrayList<>(); @@ -37,8 +37,8 @@ public static SparseBatch CreateSparseBatch(Iterator dIter) while (dIter.hasNext()) { ++num_row; DataPoint inst = dIter.next(); - int nnz = 0; // count number of nonzero feature values for current row - for (float e : inst.values()) { + int nnz = 0; // count number of nonzero feature values for current row 
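DMatrix.java releases its native handle in dispose() and zeroes the handle to make the call idempotent; the finalizer is only a safety net. The analogous C++ idiom for owning such a handle is an RAII wrapper, sketched below with a hypothetical C-style API (CreateHandle/FreeHandle), not the actual Treelite C API:

    #include <iostream>

    // Hypothetical C-style resource functions standing in for a native API.
    using Handle = void*;
    Handle CreateHandle() { return new int(42); }
    void FreeHandle(Handle h) { delete static_cast<int*>(h); }

    // RAII wrapper: the destructor plays the role of DMatrix.dispose().
    class HandleOwner {
     public:
      HandleOwner() : handle_(CreateHandle()) {}
      ~HandleOwner() { reset(); }

      // Non-copyable, movable: exactly one owner ever frees the handle.
      HandleOwner(const HandleOwner&) = delete;
      HandleOwner& operator=(const HandleOwner&) = delete;
      HandleOwner(HandleOwner&& other) noexcept : handle_(other.handle_) {
        other.handle_ = nullptr;
      }
      HandleOwner& operator=(HandleOwner&& other) noexcept {
        if (this != &other) {
          reset();
          handle_ = other.handle_;
          other.handle_ = nullptr;
        }
        return *this;
      }

      void reset() {
        if (handle_ != nullptr) {  // idempotent, like the null check in dispose()
          FreeHandle(handle_);
          handle_ = nullptr;
        }
      }

     private:
      Handle handle_;
    };

    int main() {
      HandleOwner owner;  // handle acquired here
      return 0;           // released automatically, even on early return
    }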
+ for (Float e : inst.values()) { data.add(e); ++nnz; } @@ -57,24 +57,22 @@ public static SparseBatch CreateSparseBatch(Iterator dIter) } row_ptr.add(row_ptr.get(row_ptr.size() - 1) + (long) nnz); } - float[] data_arr - = ArrayUtils.toPrimitive(data.toArray(new Float[0])); - int[] col_ind_arr - = ArrayUtils.toPrimitive(col_ind.toArray(new Integer[0])); - long[] row_ptr_arr - = ArrayUtils.toPrimitive(row_ptr.toArray(new Long[0])); - return new SparseBatch(data_arr, col_ind_arr, row_ptr_arr, num_row, num_col); + float[] data_arr = ArrayUtils.toPrimitive(data.toArray(new Float[0])); + int[] col_ind_arr = ArrayUtils.toPrimitive(col_ind.toArray(new Integer[0])); + long[] row_ptr_arr = ArrayUtils.toPrimitive(row_ptr.toArray(new Long[0])); + return new DMatrix(data_arr, col_ind_arr, row_ptr_arr, num_row, num_col); } /** - * Assemble a dense batch from a list of data points + * Assemble a dense matrix from a list of data points (float32) * * @param dIter Iterator of data points * @return Created dense batch * @throws TreeliteError Treelite error * @throws IOException IO error */ - public static DenseBatch CreateDenseBatch(Iterator dIter) + public static DMatrix createDenseDMatrixFloat32( + Iterator dIter) throws TreeliteError, IOException { int num_row = 0; int num_col = 0; @@ -117,7 +115,7 @@ public static DenseBatch CreateDenseBatch(Iterator dIter) } assert row_id == num_row; - return new DenseBatch(data, Float.NaN, num_row, num_col); + return new DMatrix(data, Float.NaN, num_row, num_col); } /** @@ -128,11 +126,11 @@ public static DenseBatch CreateDenseBatch(Iterator dIter) * @throws TreeliteError Treelite error * @throws IOException IO error */ - public static List LoadDatasetFromLibSVM(String filename) + public static List LoadDatasetFromLibSVMFloat32(String filename) throws TreeliteError, IOException { File file = new File(filename); LineIterator it = FileUtils.lineIterator(file, "UTF-8"); - ArrayList dmat = new ArrayList(); + ArrayList dmat = new ArrayList<>(); try { while (it.hasNext()) { String line = it.nextLine(); diff --git a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DenseBatch.java b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DenseBatch.java deleted file mode 100644 index 33057ae5..00000000 --- a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DenseBatch.java +++ /dev/null @@ -1,62 +0,0 @@ -package ml.dmlc.treelite4j.java; - -/** - * 2D dense batch, laid out in row-major layout - * @author Hyunsu Cho - */ -public class DenseBatch { - private float[] data; - private float missing_value; - private int num_row; - private int num_col; - - private long handle; - - /** - * Create a dense batch representing a 2D dense matrix - * @param data array of entries, should be of length ``[num_row]*[num_col]`` - * @param missing_value floating-point value representing a missing value; - * usually set of ``Float.NaN``. 
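LoadDatasetFromLibSVMFloat32 above reads the usual LibSVM text format, where each line is a label followed by index:value pairs for the non-missing features. A compact C++ parsing sketch for one such line (hypothetical ParseLibSVMLine; the Java code delegates to DataPoint instead):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    struct Row {
      float label;
      std::vector<std::uint32_t> indices;  // feature indices of present values
      std::vector<float> values;           // corresponding feature values
    };

    // Parse one LibSVM line, e.g. "1 3:5.5 10:-2".
    Row ParseLibSVMLine(const std::string& line) {
      Row row;
      std::istringstream ss(line);
      ss >> row.label;
      std::string tok;
      while (ss >> tok) {
        const std::size_t colon = tok.find(':');
        row.indices.push_back(static_cast<std::uint32_t>(std::stoul(tok.substr(0, colon))));
        row.values.push_back(std::stof(tok.substr(colon + 1)));
      }
      return row;
    }

    int main() {
      Row r = ParseLibSVMLine("1 3:5.5 10:-2");
      std::cout << r.label << " has " << r.values.size() << " features" << std::endl;
      return 0;  // prints "1 has 2 features"
    }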
- * @param num_row number of rows (data instances) in the matrix - * @param num_row number of columns (features) in the matrix - * @return Created dense batch - * @throws TreeliteError - */ - public DenseBatch( - float[] data, float missing_value, int num_row, int num_col) - throws TreeliteError { - this.data = data; - this.missing_value = missing_value; - this.num_row = num_row; - this.num_col = num_col; - - long[] out = new long[1]; - TreeliteJNI.checkCall(TreeliteJNI.TreeliteAssembleDenseBatch( - this.data, this.missing_value, this.num_row, this.num_col, out)); - handle = out[0]; - } - - /** - * Get the underlying native handle - * @return Integer representing memory address - */ - public long getHandle() { - return this.handle; - } - - @Override - protected void finalize() throws Throwable { - super.finalize(); - dispose(); - } - - /** - * Destructor, to be called when the object is garbage collected - */ - public synchronized void dispose() { - if (handle != 0L) { - TreeliteJNI.TreeliteDeleteDenseBatch(handle, this.data); - handle = 0; - } - } -} diff --git a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/Predictor.java b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/Predictor.java index c2a1561f..c55b4fd8 100644 --- a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/Predictor.java +++ b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/Predictor.java @@ -7,8 +7,9 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; -import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; import java.io.Serializable; @@ -47,11 +48,9 @@ public class Predictor implements Serializable, KryoSerializable { * CPU cores. Setting ``nthread=1`` indicates that the main * thread should be exclusively used. * @param verbose Whether to print extra diagnostic messages - * @return Created Predictor - * @throws TreeliteError + * @throws TreeliteError error during loading the shared lib */ - public Predictor( - String libpath, int nthread, boolean verbose) throws TreeliteError { + public Predictor(String libpath, int nthread, boolean verbose) throws TreeliteError { this.num_thread = nthread; this.verbose = verbose; initNativeLibrary(libpath); @@ -170,122 +169,92 @@ public float GetGlobalBias() { } /** - * Perform single-instance prediction. Prediction is run by the calling thread. 
- * - * @param inst array of data entires comprising the instance - * @param pred_margin whether to predict a probability or a raw margin score - * @return Resulting predictions, of dimension ``[num_output_group]`` - */ - public float[] predict(Data[] inst, boolean pred_margin) - throws TreeliteError, IOException { - - assert inst.length > 0; - - // query result size - long[] out = new long[1]; - TreeliteJNI.checkCall( - TreeliteJNI.TreelitePredictorQueryResultSizeSingleInst(this.handle, out)); - int result_size = (int) out[0]; - float[] out_result = new float[result_size]; - - // serialize instance as byte array - ByteArrayOutputStream os = new ByteArrayOutputStream(); - for (int i = 0; i < inst.length; ++i) { - inst[i].write(os); - } - TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictInst( - this.handle, os.toByteArray(), pred_margin, out_result, out)); - int actual_result_size = (int) out[0]; - float[] result = new float[actual_result_size]; - for (int i = 0; i < actual_result_size; ++i) { - result[i] = out_result[i]; - } - return result; - } - - /** - * Perform batch prediction with a 2D sparse data matrix. Worker threads + * Perform batch prediction with a 2D data matrix. Worker threads * will internally divide up work for batch prediction. **Note that this - * function will be blocked by mutex when worker_thread > 1.** In order to use - * multiple threads to process multiple prediction requests simultaneously, - * use :java:meth:`Predictor.predict(Data[], boolean)` instead or keep worker_thread = 1. + * function will be blocked by mutex when worker_thread > 1.** * - * @param batch a :java:ref:`SparseBatch`, representing a slice of a 2D - * sparse matrix + * @param batch a data matrix of type :java:ref:`DMatrix` * @param verbose whether to print extra diagnostic messages * @param pred_margin whether to predict probabilities or raw margin scores * @return Resulting predictions, of dimension ``[num_row]*[num_output_group]`` */ - public float[][] predict( - SparseBatch batch, boolean verbose, boolean pred_margin) - throws TreeliteError { + public INDArray predict(DMatrix batch, boolean verbose, boolean pred_margin) + throws TreeliteError { long[] out = new long[1]; TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorQueryResultSize( - this.handle, batch.getHandle(), true, out)); - int result_size = (int) out[0]; - float[] out_result = new float[result_size]; - if (num_thread == 1) { - TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatch( - this.handle, batch.getHandle(), true, verbose, pred_margin, - out_result, out)); - } else { - synchronized (this) { - TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatch( - this.handle, batch.getHandle(), true, verbose, pred_margin, + this.handle, batch.getHandle(), out)); + int result_size = (int)out[0]; + + String[] s_out = new String[1]; + TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorQueryLeafOutputType(this.handle, s_out)); + String leaf_output_type = s_out[0]; + switch (leaf_output_type) { + case "float32": { + float[] out_result = new float[result_size]; + if (num_thread == 1) { + TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatchWithFloat32Out( + this.handle, batch.getHandle(), verbose, pred_margin, out_result, out)); + } else { + synchronized (this) { + TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatchWithFloat32Out( + this.handle, batch.getHandle(), verbose, pred_margin, out_result, out)); + } + } + int actual_result_size = (int) out[0]; + return reshape(out_result, 
actual_result_size, this.num_output_group); } - } - int actual_result_size = (int) out[0]; - return reshape(out_result, actual_result_size, this.num_output_group); - } - - /** - * Perform batch prediction with a 2D dense data matrix. Worker threads - * will internally divide up work for batch prediction. **Note that this - * function will be blocked by mutex when worker_thread > 1.** In order to use - * multiple threads to process multiple prediction requests simultaneously, - * use :java:meth:`Predictor.predict(Data[], boolean)` instead or keep worker_thread = 1. - * - * @param batch a :java:ref:`DenseBatch`, representing a slice of a 2D dense - * matrix - * @param verbose whether to print extra diagnostic messages - * @param pred_margin whether to predict probabilities or raw margin scores - * @return Resulting predictions, of dimension ``[num_row]*[num_output_group]`` - */ - public float[][] predict( - DenseBatch batch, boolean verbose, boolean pred_margin) - throws TreeliteError { - long[] out = new long[1]; - TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorQueryResultSize( - this.handle, batch.getHandle(), false, out)); - int result_size = (int) out[0]; - float[] out_result = new float[result_size]; - if (num_thread == 1) { - TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatch( - this.handle, batch.getHandle(), false, verbose, pred_margin, - out_result, out)); - } else { - synchronized (this) { - TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatch( - this.handle, batch.getHandle(), false, verbose, pred_margin, + case "float64": { + double[] out_result = new double[result_size]; + if (num_thread == 1) { + TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatchWithFloat64Out( + this.handle, batch.getHandle(), verbose, pred_margin, out_result, out)); + } else { + synchronized (this) { + TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatchWithFloat64Out( + this.handle, batch.getHandle(), verbose, pred_margin, out_result, out)); + } + } + int actual_result_size = (int) out[0]; + return reshape(out_result, actual_result_size, this.num_output_group); } + case "uint32": { + int[] out_result = new int[result_size]; + if (num_thread == 1) { + TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatchWithUInt32Out( + this.handle, batch.getHandle(), verbose, pred_margin, out_result, out)); + } else { + synchronized (this) { + TreeliteJNI.checkCall(TreeliteJNI.TreelitePredictorPredictBatchWithUInt32Out( + this.handle, batch.getHandle(), verbose, pred_margin, + out_result, out)); + } + } + int actual_result_size = (int) out[0]; + return reshape(out_result, actual_result_size, this.num_output_group); + } + default: + throw new TreeliteError("Unknown leaf output type: " + leaf_output_type); } - int actual_result_size = (int) out[0]; - return reshape(out_result, actual_result_size, this.num_output_group); } - private float[][] reshape(float[] array, int rend, int num_col) { + private INDArray reshape(float[] array, int rend, int num_col) { assert rend <= array.length; assert rend % num_col == 0; - float[][] res; - res = new float[rend / num_col][num_col]; - for (int i = 0; i < rend; ++i) { - int r = i / num_col; - int c = i % num_col; - res[r][c] = array[i]; - } - return res; + return Nd4j.create(array, new int[]{rend / num_col, num_col}, 'c'); + } + + private INDArray reshape(double[] array, int rend, int num_col) { + assert rend <= array.length; + assert rend % num_col == 0; + return Nd4j.create(array, new int[]{rend / num_col, num_col}, 'c'); + } + 
+ private INDArray reshape(int[] array, int rend, int num_col) { + assert rend <= array.length; + assert rend % num_col == 0; + return Nd4j.create(array, new int[]{rend / num_col, num_col}, 'c'); } @Override @@ -338,40 +307,40 @@ private void writeObject(java.io.ObjectOutputStream out) throws IOException { @Override public void write(Kryo kryo, Output out) { - out.writeInt(this.num_thread); - out.writeBoolean(this.verbose); - byte[] libext = this.libext.getBytes(); - out.writeShort(libext.length); - out.write(libext); - try { - byte[] lib_data = Files.readAllBytes(Paths.get(libpath)); - out.writeInt(lib_data.length); - out.write(lib_data); - } catch (IOException e) { - logger.error("Error while loading TreeLite dynamic shared library!"); - } + out.writeInt(this.num_thread); + out.writeBoolean(this.verbose); + byte[] libext = this.libext.getBytes(); + out.writeShort(libext.length); + out.write(libext); + try { + byte[] lib_data = Files.readAllBytes(Paths.get(libpath)); + out.writeInt(lib_data.length); + out.write(lib_data); + } catch (IOException e) { + logger.error("Error while loading TreeLite dynamic shared library!"); + } } @Override public void read(Kryo kryo, Input in) { - this.num_thread = in.readInt(); - this.verbose = in.readBoolean(); - byte[] libext = new byte[in.readShort()]; - in.read(libext); - File libpath = null; - try { - libpath = File.createTempFile("TreeLite_", new String(libext)); - byte[] lib_data = new byte[in.readInt()]; - in.read(lib_data); - FileUtils.writeByteArrayToFile(libpath, lib_data); - initNativeLibrary(libpath.getAbsolutePath()); - } catch (Exception ex) { - ex.printStackTrace(); - logger.error("Error while loading TreeLite dynamic shared library!"); - } finally { - if (libpath != null) { - libpath.delete(); - } + this.num_thread = in.readInt(); + this.verbose = in.readBoolean(); + byte[] libext = new byte[in.readShort()]; + in.read(libext); + File libpath = null; + try { + libpath = File.createTempFile("TreeLite_", new String(libext)); + byte[] lib_data = new byte[in.readInt()]; + in.read(lib_data); + FileUtils.writeByteArrayToFile(libpath, lib_data); + initNativeLibrary(libpath.getAbsolutePath()); + } catch (Exception ex) { + ex.printStackTrace(); + logger.error("Error while loading TreeLite dynamic shared library!"); + } finally { + if (libpath != null) { + libpath.delete(); } + } } } diff --git a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/SparseBatch.java b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/SparseBatch.java deleted file mode 100644 index 243636f4..00000000 --- a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/SparseBatch.java +++ /dev/null @@ -1,67 +0,0 @@ -package ml.dmlc.treelite4j.java; - -/** - * 2D sparse batch, laid out in CSR (Compressed Sparse Row) layout - * @author Hyunsu Cho - */ -public class SparseBatch { - private float[] data; - private int[] col_ind; - private long[] row_ptr; - private int num_row; - private int num_col; - - private long handle; - - /** - * Create a sparse batch representing a 2D sparse matrix - * @param data nonzero (non-missing) entries - * @param col_ind corresponding column indices, should be of same length as - * ``data`` - * @param row_ptr offsets to define each instance, should be of length - * ``[num_row]+1`` - * @param num_row number of rows (data instances) in the matrix - * @param num_row number of columns (features) in the matrix - * @return Created sparse batch - * @throws TreeliteError - */ - public SparseBatch( - float[] data, int[] col_ind, 
long[] row_ptr, int num_row, int num_col) - throws TreeliteError { - this.data = data; - this.col_ind = col_ind; - this.row_ptr = row_ptr; - this.num_row = num_row; - this.num_col = num_col; - - long[] out = new long[1]; - TreeliteJNI.checkCall(TreeliteJNI.TreeliteAssembleSparseBatch( - this.data, this.col_ind, this.row_ptr, this.num_row, this.num_col, out)); - handle = out[0]; - } - - /** - * Get the underlying native handle - * @return Integer representing memory address - */ - public long getHandle() { - return this.handle; - } - - @Override - protected void finalize() throws Throwable { - super.finalize(); - dispose(); - } - - /** - * Destructor, to be called when the object is garbage collected - */ - public synchronized void dispose() { - if (handle != 0L) { - TreeliteJNI.TreeliteDeleteSparseBatch( - handle, this.data, this.col_ind, this.row_ptr); - handle = 0; - } - } -} diff --git a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/TreeliteJNI.java b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/TreeliteJNI.java index a764bd2f..43b8de57 100644 --- a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/TreeliteJNI.java +++ b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/TreeliteJNI.java @@ -29,56 +29,65 @@ static void checkCall(int ret) throws TreeliteError { } } - public final static native String TreeliteGetLastError(); + public static native String TreeliteGetLastError(); - public final static native int TreeliteAssembleSparseBatch( - float[] data, int[] col_ind, long[] row_ptr, long num_row, long num_col, - long[] out); + public static native int TreeliteDMatrixCreateFromCSRWithFloat32In( + float[] data, int[] col_ind, long[] row_ptr, long num_row, long num_col, long[] out); - public final static native int TreeliteDeleteSparseBatch( - long handle, float[] data, int[] col_ind, long[] row_ptr); + public static native int TreeliteDMatrixCreateFromCSRWithFloat64In( + double[] data, int[] col_ind, long[] row_ptr, long num_row, long num_col, long[] out); - public final static native int TreeliteAssembleDenseBatch( - float[] data, float missing_value, long num_row, long num_col, long[] out); + public static native int TreeliteDMatrixCreateFromMatWithFloat32In( + float[] data, long num_row, long num_col, float missing_value, long[] out); - public final static native int TreeliteDeleteDenseBatch( - long handle, float[] data); + public static native int TreeliteDMatrixCreateFromMatWithFloat64In( + double[] data, long num_row, long num_col, double missing_value, long[] out); - public final static native int TreeliteBatchGetDimension( - long handle, boolean batch_sparse, long[] out_num_row, long[] out_num_col); + public static native int TreeliteDMatrixGetDimension( + long handle, long[] out_num_row, long[] out_num_col, long[] out_nelem); - public final static native int TreelitePredictorLoad( - String library_path, int num_worker_thread, long[] out); + public static native int TreeliteDMatrixFree( + long handle); - public final static native int TreelitePredictorPredictBatch( - long handle, long batch, boolean batch_sparse, boolean verbose, - boolean pred_margin, float[] out_result, long[] out_result_size); + public static native int TreelitePredictorLoad( + String library_path, int num_worker_thread, long[] out); - public final static native int TreelitePredictorPredictInst( - long handle, byte[] inst, boolean pred_margin, float[] out_result, - long[] out_result_size); + public static native int 
TreelitePredictorPredictBatchWithFloat32Out( + long handle, long batch, boolean verbose, boolean pred_margin, float[] out_result, + long[] out_result_size); - public final static native int TreelitePredictorQueryResultSize( - long handle, long batch, boolean batch_sparse, long[] out); + public static native int TreelitePredictorPredictBatchWithFloat64Out( + long handle, long batch, boolean verbose, boolean pred_margin, double[] out_result, + long[] out_result_size); - public final static native int TreelitePredictorQueryResultSizeSingleInst( - long handle, long[] out); + public static native int TreelitePredictorPredictBatchWithUInt32Out( + long handle, long batch, boolean verbose, boolean pred_margin, int[] out_result, + long[] out_result_size); - public final static native int TreelitePredictorQueryNumOutputGroup( - long handle, long[] out); + public static native int TreelitePredictorQueryResultSize( + long handle, long batch, long[] out); - public final static native int TreelitePredictorQueryNumFeature( - long handle, long[] out); + public static native int TreelitePredictorQueryNumOutputGroup( + long handle, long[] out); - public final static native int TreelitePredictorQueryPredTransform( - long handle, String[] out); + public static native int TreelitePredictorQueryNumFeature( + long handle, long[] out); - public final static native int TreelitePredictorQuerySigmoidAlpha( - long handle, float[] out); + public static native int TreelitePredictorQueryPredTransform( + long handle, String[] out); - public final static native int TreelitePredictorQueryGlobalBias( - long handle, float[] out); + public static native int TreelitePredictorQuerySigmoidAlpha( + long handle, float[] out); - public final static native int TreelitePredictorFree(long handle); + public static native int TreelitePredictorQueryGlobalBias( + long handle, float[] out); + + public static native int TreelitePredictorQueryThresholdType( + long handle, String[] out); + + public static native int TreelitePredictorQueryLeafOutputType( + long handle, String[] out); + + public static native int TreelitePredictorFree(long handle); } diff --git a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPoint.scala b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPoint.scala index e309d593..644d59ec 100644 --- a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPoint.scala +++ b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPoint.scala @@ -15,7 +15,7 @@ package ml.dmlc.treelite4j /** - * A data point (instance) + * A data point (instance) with float32 values * * @param indices Feature indices of this point or `null` if the data is dense * @param values Feature values of this point @@ -23,5 +23,6 @@ package ml.dmlc.treelite4j case class DataPoint( indices: Array[Int], values: Array[Float]) extends Serializable { - require(indices == null || indices.length == values.length, "indices and values must have the same number of elements") -} + require(indices == null || indices.length == values.length, + "indices and values must have the same number of elements") +} \ No newline at end of file diff --git a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala new file mode 100644 index 00000000..a91ae4ae --- /dev/null +++ b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala @@ -0,0 +1,14 @@ +package ml.dmlc.treelite4j + +/** + * A data point (instance) 
with float64 values + * + * @param indices Feature indices of this point or `null` if the data is dense + * @param values Feature values of this point + */ +case class DataPointFloat64( + indices: Array[Int], + values: Array[Double]) extends Serializable { + require(indices == null || indices.length == values.length, + "indices and values must have the same number of elements") +} \ No newline at end of file diff --git a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/Predictor.scala b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/Predictor.scala index 16a9e8f8..dc4ddbf1 100644 --- a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/Predictor.scala +++ b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/Predictor.scala @@ -2,7 +2,8 @@ package ml.dmlc.treelite4j.scala import java.io.IOException -import ml.dmlc.treelite4j.java.{Data, DenseBatch, SparseBatch, TreeliteError, Predictor => JPredictor} +import ml.dmlc.treelite4j.java.{DMatrix, Data, TreeliteError, Predictor => JPredictor} +import org.nd4j.linalg.api.ndarray.INDArray import scala.reflect.ClassTag @@ -32,26 +33,10 @@ class Predictor private[treelite4j](private[treelite4j] val pred: JPredictor) def globalBias: Float = pred.GetGlobalBias() @throws(classOf[TreeliteError]) - @throws(classOf[IOException]) - def predictInst[T <: Data : ClassTag]( - inst: Array[T], - predMargin: Boolean = false): Array[Float] = { - pred.predict(inst.asInstanceOf[Array[Data]], predMargin) - } - - @throws(classOf[TreeliteError]) - def predictSparseBatch( - batch: SparseBatch, - predMargin: Boolean = false, - verbose: Boolean = false): Array[Array[Float]] = { - pred.predict(batch, verbose, predMargin) - } - - @throws(classOf[TreeliteError]) - def predictDenseBatch( - batch: DenseBatch, + def predictBatch( + batch: DMatrix, predMargin: Boolean = false, - verbose: Boolean = false): Array[Array[Float]] = { + verbose: Boolean = false): INDArray = { pred.predict(batch, verbose, predMargin) } diff --git a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/spark/TreeLiteModel.scala b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/spark/TreeLiteModel.scala index f60cf1f0..8bf16a61 100644 --- a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/spark/TreeLiteModel.scala +++ b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/spark/TreeLiteModel.scala @@ -1,7 +1,7 @@ package ml.dmlc.treelite4j.scala.spark import ml.dmlc.treelite4j.DataPoint -import ml.dmlc.treelite4j.java.BatchBuilder +import ml.dmlc.treelite4j.java.DMatrixBuilder import ml.dmlc.treelite4j.scala.Predictor import org.apache.spark.ml.PredictionModel import org.apache.spark.ml.linalg._ @@ -70,15 +70,15 @@ class TreeLiteModel private[spark]( } val result = batchRow.head.getAs[Vector]($(featuresCol)) match { case _: SparseVector => - val batch = BatchBuilder.CreateSparseBatch(dataPoints.asJava) - val ret = broadcastModel.value.predictSparseBatch(batch, $(predictMargin), $(verbose)) + val batch = DMatrixBuilder.createSparseCSRDMatrixFloat32(dataPoints.asJava) + val ret = broadcastModel.value.predictBatch(batch, $(predictMargin), $(verbose)) batch.dispose() - ret.map(Row.apply(_)) + ret.toFloatMatrix.map(Row.apply(_)) case _: DenseVector => - val batch = BatchBuilder.CreateDenseBatch(dataPoints.asJava) - val ret = broadcastModel.value.predictDenseBatch(batch, $(predictMargin), $(verbose)) + val batch = DMatrixBuilder.createDenseDMatrixFloat32(dataPoints.asJava) + val ret = 
broadcastModel.value.predictBatch(batch, $(predictMargin), $(verbose)) batch.dispose() - ret.map(Row.apply(_)) + ret.toFloatMatrix.map(Row.apply(_)) } batchRow.zip(result).map { case (origin, ret) => Row.merge(origin, ret) @@ -87,7 +87,7 @@ class TreeLiteModel private[spark]( } // append result columns to schema val schema = StructType(dataset.schema.fields ++ Seq( - StructField($(predictionCol), ArrayType(FloatType, false), nullable = false))) + StructField($(predictionCol), ArrayType(FloatType, containsNull = false), nullable = false))) dataset.sparkSession.createDataFrame(resultRDD, schema) } diff --git a/runtime/java/treelite4j/src/native/treelite4j.cpp b/runtime/java/treelite4j/src/native/treelite4j.cpp index 84887b29..06d753df 100644 --- a/runtime/java/treelite4j/src/native/treelite4j.cpp +++ b/runtime/java/treelite4j/src/native/treelite4j.cpp @@ -1,10 +1,13 @@ #include #include +#include #include +#include #include #include #include #include +#include #include #include "./treelite4j.h" @@ -13,9 +16,9 @@ namespace { // set handle void setHandle(JNIEnv* jenv, jlongArray jhandle, void* handle) { #ifdef __APPLE__ - jlong out = (long)handle; + auto out = static_cast(reinterpret_cast(handle)); #else - int64_t out = (int64_t)handle; + auto out = reinterpret_cast(handle); #endif jenv->SetLongArrayRegion(jhandle, 0, 1, &out); } @@ -29,8 +32,8 @@ void setHandle(JNIEnv* jenv, jlongArray jhandle, void* handle) { */ JNIEXPORT jstring JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteGetLastError( - JNIEnv* jenv, jclass jcls) { - jstring jresult = 0; + JNIEnv* jenv, jclass jcls) { + jstring jresult = nullptr; const char* result = TreeliteGetLastError(); if (result) { jresult = jenv->NewStringUTF(result); @@ -40,104 +43,135 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteGetLastError( /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteAssembleSparseBatch + * Method: TreeliteDMatrixCreateFromCSRWithFloat32In * Signature: ([F[I[JJJ[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteAssembleSparseBatch( - JNIEnv* jenv, jclass jcls, jfloatArray jdata, jintArray jcol_ind, - jlongArray jrow_ptr, jlong jnum_row, jlong jnum_col, jlongArray jout) { - - jfloat* data = jenv->GetFloatArrayElements(jdata, 0); - jint* col_ind = jenv->GetIntArrayElements(jcol_ind, 0); - jlong* row_ptr = jenv->GetLongArrayElements(jrow_ptr, 0); - CSRBatchHandle out; - jint ret; - if (sizeof(size_t) == sizeof(uint64_t)) { - ret = (jint)TreeliteAssembleSparseBatch((const float*)data, - (const uint32_t*)col_ind, (const size_t*)row_ptr, - (size_t)jnum_row, (size_t)jnum_col, &out); - } else { - LOG(FATAL) << "32-bit platform not supported yet"; - } +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixCreateFromCSRWithFloat32In( + JNIEnv* jenv, jclass jcls, jfloatArray jdata, jintArray jcol_ind, jlongArray jrow_ptr, + jlong jnum_row, jlong jnum_col, jlongArray jout) { + jfloat* data = jenv->GetFloatArrayElements(jdata, nullptr); + jint* col_ind = jenv->GetIntArrayElements(jcol_ind, nullptr); + jlong* row_ptr = jenv->GetLongArrayElements(jrow_ptr, nullptr); + DMatrixHandle out = nullptr; + const int ret = TreeliteDMatrixCreateFromCSR( + static_cast(data), "float32", reinterpret_cast(col_ind), + reinterpret_cast(row_ptr), static_cast(jnum_row), + static_cast(jnum_col), &out); setHandle(jenv, jout, out); + // release arrays + jenv->ReleaseFloatArrayElements(jdata, data, 0); + jenv->ReleaseIntArrayElements(jcol_ind, col_ind, 0); + jenv->ReleaseLongArrayElements(jrow_ptr, row_ptr, 
0); - return ret; + return static_cast(ret); } /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteDeleteSparseBatch - * Signature: (J[F[I[J)I + * Method: TreeliteDMatrixCreateFromCSRWithFloat64In + * Signature: ([D[I[JJJ[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDeleteSparseBatch( - JNIEnv* jenv, jclass jcls, jlong jhandle, - jfloatArray jdata, jintArray jcol_ind, jlongArray jrow_ptr) { - - treelite::CSRBatch* batch = (treelite::CSRBatch*)jhandle; - jenv->ReleaseFloatArrayElements(jdata, (jfloat*)batch->data, 0); - jenv->ReleaseIntArrayElements(jcol_ind, (jint*)batch->col_ind, 0); - jenv->ReleaseLongArrayElements(jrow_ptr, (jlong*)batch->row_ptr, 0); - return (jint)TreeliteDeleteSparseBatch((CSRBatchHandle)batch); +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixCreateFromCSRWithFloat64In( + JNIEnv* jenv, jclass jcls, jdoubleArray jdata, jintArray jcol_ind, jlongArray jrow_ptr, + jlong jnum_row, jlong jnum_col, jlongArray jout) { + jdouble* data = jenv->GetDoubleArrayElements(jdata, nullptr); + jint* col_ind = jenv->GetIntArrayElements(jcol_ind, nullptr); + jlong* row_ptr = jenv->GetLongArrayElements(jrow_ptr, nullptr); + DMatrixHandle out = nullptr; + const int ret = TreeliteDMatrixCreateFromCSR( + static_cast(data), "float64", reinterpret_cast(col_ind), + reinterpret_cast(row_ptr), static_cast(jnum_row), + static_cast(jnum_col), &out); + setHandle(jenv, jout, out); + // release arrays + jenv->ReleaseDoubleArrayElements(jdata, data, 0); + jenv->ReleaseIntArrayElements(jcol_ind, col_ind, 0); + jenv->ReleaseLongArrayElements(jrow_ptr, row_ptr, 0); + + return static_cast(ret); } /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteAssembleDenseBatch - * Signature: ([FFJJ[J)I + * Method: TreeliteDMatrixCreateFromMatWithFloat32In + * Signature: ([FJJF[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteAssembleDenseBatch( - JNIEnv* jenv, jclass jcls, jfloatArray jdata, jfloat jmissing_value, - jlong jnum_row, jlong jnum_col, jlongArray jout) { - - jfloat* data = jenv->GetFloatArrayElements(jdata, 0); - DenseBatchHandle out; - const jint ret = (jint)TreeliteAssembleDenseBatch((const float*)data, - (float)jmissing_value, (size_t)jnum_row, (size_t)jnum_col, &out); +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixCreateFromMatWithFloat32In( + JNIEnv* jenv, jclass jcls, jfloatArray jdata, jlong jnum_row, jlong jnum_col, + jfloat jmissing_value, jlongArray jout) { + jfloat* data = jenv->GetFloatArrayElements(jdata, nullptr); + float missing_value = static_cast(jmissing_value); + DMatrixHandle out = nullptr; + const int ret = TreeliteDMatrixCreateFromMat( + static_cast(data), "float32", static_cast(jnum_row), + static_cast(jnum_col), &missing_value, &out); setHandle(jenv, jout, out); + // release arrays + jenv->ReleaseFloatArrayElements(jdata, data, 0); - return ret; + return static_cast(ret); } /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteDeleteDenseBatch - * Signature: (J[F)I + * Method: TreeliteDMatrixCreateFromMatWithFloat64In + * Signature: ([DJJD[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDeleteDenseBatch( - JNIEnv* jenv, jclass jcls, jlong jhandle, jfloatArray jdata) { +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixCreateFromMatWithFloat64In( + JNIEnv* jenv, jclass jcls, jdoubleArray jdata, jlong jnum_row, jlong jnum_col, + jdouble jmissing_value, jlongArray jout) { + jdouble* data = 
jenv->GetDoubleArrayElements(jdata, nullptr); + double missing_value = static_cast(jmissing_value); + DMatrixHandle out = nullptr; + const int ret = TreeliteDMatrixCreateFromMat( + static_cast(data), "float64", static_cast(jnum_row), + static_cast(jnum_col), &missing_value, &out); + setHandle(jenv, jout, out); + // release arrays + jenv->ReleaseDoubleArrayElements(jdata, data, 0); - treelite::DenseBatch* batch = (treelite::DenseBatch*)jhandle; - jenv->ReleaseFloatArrayElements(jdata, (jfloat*)batch->data, 0); - return (jint)TreeliteDeleteDenseBatch((DenseBatchHandle)batch); + return static_cast(ret); } /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteBatchGetDimension - * Signature: (JZ[J[J)I + * Method: TreeliteDMatrixGetDimension + * Signature: (J[J[J[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteBatchGetDimension( - JNIEnv* jenv, jclass jcls, jlong jhandle, jboolean jbatch_sparse, - jlongArray jout_num_row, jlongArray jout_num_col) { - - size_t num_row, num_col; - const jint ret = (jint)TreeliteBatchGetDimension((void*)jhandle, - (jbatch_sparse == JNI_TRUE ? 1 : 0), &num_row, &num_col); - +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixGetDimension( + JNIEnv* jenv, jclass jcls, jlong jmat, jlongArray jout_num_row, jlongArray jout_num_col, + jlongArray jout_nelem) { + DMatrixHandle dmat = reinterpret_cast(jmat); + size_t num_row = 0, num_col = 0, num_elem = 0; + const int ret = TreeliteDMatrixGetDimension(dmat, &num_row, &num_col, &num_elem); // save dimensions - jlong* out_num_row = jenv->GetLongArrayElements(jout_num_row, 0); - jlong* out_num_col = jenv->GetLongArrayElements(jout_num_col, 0); - out_num_row[0] = (jlong)num_row; - out_num_col[0] = (jlong)num_col; + jlong* out_num_row = jenv->GetLongArrayElements(jout_num_row, nullptr); + jlong* out_num_col = jenv->GetLongArrayElements(jout_num_col, nullptr); + jlong* out_nelem = jenv->GetLongArrayElements(jout_nelem, nullptr); + out_num_row[0] = static_cast(num_row); + out_num_col[0] = static_cast(num_col); + out_nelem[0] = static_cast(num_elem); + // release arrays jenv->ReleaseLongArrayElements(jout_num_row, out_num_row, 0); jenv->ReleaseLongArrayElements(jout_num_col, out_num_col, 0); + jenv->ReleaseLongArrayElements(jout_nelem, out_nelem, 0); - return ret; + return static_cast(ret); +} + +/* + * Class: ml_dmlc_treelite4j_java_TreeliteJNI + * Method: TreeliteDMatrixFree + * Signature: (J)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixFree( + JNIEnv* jenv, jclass jcls, jlong jdmat) { + return static_cast(TreeliteDMatrixFree(reinterpret_cast(jdmat))); } /* @@ -147,124 +181,114 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteBatchGetDimension( */ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorLoad( - JNIEnv* jenv, jclass jcls, jstring jlibrary_path, jint jnum_worker_thread, - jlongArray jout) { - - const char* library_path = jenv->GetStringUTFChars(jlibrary_path, 0); - PredictorHandle out; - const jint ret = (jint)TreelitePredictorLoad(library_path, - (int)jnum_worker_thread, &out); + JNIEnv* jenv, jclass jcls, jstring jlibrary_path, jint jnum_worker_thread, jlongArray jout) { + const char* library_path = jenv->GetStringUTFChars(jlibrary_path, nullptr); + PredictorHandle out = nullptr; + const int ret = TreelitePredictorLoad(library_path, static_cast(jnum_worker_thread), &out); setHandle(jenv, jout, out); - return ret; + return static_cast(ret); } /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * 
Method: TreelitePredictorPredictBatch - * Signature: (JJZZZ[F)I + * Method: TreelitePredictorPredictBatchWithFloat32Out + * Signature: (JJZZ[F[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictBatch( - JNIEnv* jenv, jclass jcls, jlong jhandle, jlong jbatch, - jboolean jbatch_sparse, jboolean jverbose, jboolean jpred_margin, - jfloatArray jout_result, jlongArray jout_result_size) { - - jfloat* out_result = jenv->GetFloatArrayElements(jout_result, 0); - jlong* out_result_size = jenv->GetLongArrayElements(jout_result_size, 0); - size_t out_result_size_tmp; - const jint ret = (jint)TreelitePredictorPredictBatch( - (PredictorHandle)jhandle, (void*)jbatch, - (jbatch_sparse == JNI_TRUE ? 1 : 0), (jverbose == JNI_TRUE ? 1 : 0), - (jpred_margin == JNI_TRUE ? 1 : 0), (float*)out_result, - &out_result_size_tmp); - out_result_size[0] = (jlong)out_result_size_tmp; - +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictBatchWithFloat32Out( + JNIEnv* jenv, jclass jcls, jlong jpredictor, jlong jbatch, jboolean jverbose, + jboolean jpred_margin, jfloatArray jout_result, jlongArray jout_result_size) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + DMatrixHandle dmat = reinterpret_cast(jbatch); + jfloat* out_result = jenv->GetFloatArrayElements(jout_result, nullptr); + jlong* out_result_size = jenv->GetLongArrayElements(jout_result_size, nullptr); + size_t out_result_size_tmp = 0; + const int ret = TreelitePredictorPredictBatch( + predictor, dmat, (jverbose == JNI_TRUE ? 1 : 0), + (jpred_margin == JNI_TRUE ? 1 : 0), static_cast(out_result), + &out_result_size_tmp); + out_result_size[0] = static_cast(out_result_size_tmp); // release arrays jenv->ReleaseFloatArrayElements(jout_result, out_result, 0); jenv->ReleaseLongArrayElements(jout_result_size, out_result_size, 0); - return ret; + return static_cast(ret); } /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreelitePredictorPredictInst - * Signature: (J[BZ[F[J)I + * Method: TreelitePredictorPredictBatchWithFloat64Out + * Signature: (JJZZ[D[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictInst( - JNIEnv* jenv, jclass jcls, jlong jhandle, jbyteArray jinst, - jboolean jpred_margin, jfloatArray jout_result, jlongArray jout_result_size) { - - // read Entry[] array from bytes - jbyte* inst_bytes = jenv->GetByteArrayElements(jinst, 0); - const size_t nbytes = jenv->GetArrayLength(jinst); - CHECK_EQ(nbytes % sizeof(TreelitePredictorEntry), 0); - const size_t num_elem = nbytes / sizeof(TreelitePredictorEntry); - if (!DMLC_LITTLE_ENDIAN) { // re-order bytes on big-endian machines - dmlc::ByteSwap((void*)inst_bytes, nbytes, num_elem); - } - dmlc::MemoryFixedSizeStream fs((void*)inst_bytes, nbytes); - std::vector inst(num_elem, {-1}); - for (int i = 0; i < num_elem; ++i) { - fs.Read(&inst[i], sizeof(TreelitePredictorEntry)); - } - - jfloat* out_result = jenv->GetFloatArrayElements(jout_result, 0); - jlong* out_result_size = jenv->GetLongArrayElements(jout_result_size, 0); - size_t out_result_size_tmp; - const jint ret = (jint)TreelitePredictorPredictInst((PredictorHandle)jhandle, - inst.data(), (jpred_margin == JNI_TRUE ? 
1 : 0), - (float*)out_result, &out_result_size_tmp); - out_result_size[0] = (jlong)out_result_size_tmp; - +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictBatchWithFloat64Out( + JNIEnv* jenv, jclass jcls, jlong jpredictor, jlong jbatch, jboolean jverbose, + jboolean jpred_margin, jdoubleArray jout_result, jlongArray jout_result_size) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + DMatrixHandle dmat = reinterpret_cast(jbatch); + jdouble* out_result = jenv->GetDoubleArrayElements(jout_result, nullptr); + jlong* out_result_size = jenv->GetLongArrayElements(jout_result_size, nullptr); + size_t out_result_size_tmp = 0; + const int ret = TreelitePredictorPredictBatch( + predictor, dmat, (jverbose == JNI_TRUE ? 1 : 0), + (jpred_margin == JNI_TRUE ? 1 : 0), static_cast(out_result), + &out_result_size_tmp); + out_result_size[0] = static_cast(out_result_size_tmp); // release arrays - jenv->ReleaseByteArrayElements(jinst, inst_bytes, 0); - jenv->ReleaseFloatArrayElements(jout_result, out_result, 0); + jenv->ReleaseDoubleArrayElements(jout_result, out_result, 0); jenv->ReleaseLongArrayElements(jout_result_size, out_result_size, 0); - return ret; + return static_cast(ret); } /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreelitePredictorQueryResultSize - * Signature: (JJZ[J)I + * Method: TreelitePredictorPredictBatchWithUInt32Out + * Signature: (JJZZ[I[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryResultSize( - JNIEnv* jenv, jclass jcls, jlong jhandle, jlong jbatch, - jboolean jbatch_sparse, jlongArray jout) { - - size_t result_size; - const jint ret = (jint)TreelitePredictorQueryResultSize( - (PredictorHandle)jhandle, (void*)jbatch, - (jbatch_sparse == JNI_TRUE ? 1 : 0), &result_size); - // store dimension - jlong* out = jenv->GetLongArrayElements(jout, 0); - out[0] = (jlong)result_size; - jenv->ReleaseLongArrayElements(jout, out, 0); +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictBatchWithUInt32Out( + JNIEnv* jenv, jclass jcls, jlong jpredictor, jlong jbatch, jboolean jverbose, + jboolean jpred_margin, jintArray jout_result, jlongArray jout_result_size) { + API_BEGIN(); + PredictorHandle predictor = reinterpret_cast(jpredictor); + DMatrixHandle dmat = reinterpret_cast(jbatch); + CHECK_EQ(sizeof(jint), sizeof(uint32_t)); + jint* out_result = jenv->GetIntArrayElements(jout_result, nullptr); + jlong* out_result_size = jenv->GetLongArrayElements(jout_result_size, nullptr); + size_t out_result_size_tmp = 0; + const int ret = TreelitePredictorPredictBatch( + predictor, dmat, (jverbose == JNI_TRUE ? 1 : 0), + (jpred_margin == JNI_TRUE ? 
1 : 0), static_cast(out_result), + &out_result_size_tmp); + out_result_size[0] = static_cast(out_result_size_tmp); + // release arrays + jenv->ReleaseIntArrayElements(jout_result, out_result, 0); + jenv->ReleaseLongArrayElements(jout_result_size, out_result_size, 0); - return ret; + return static_cast(ret); + API_END(); } /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreelitePredictorQueryResultSizeSingleInst - * Signature: (J[J)I + * Method: TreelitePredictorQueryResultSize + * Signature: (JJ[J)I */ JNIEXPORT jint JNICALL -Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryResultSizeSingleInst( - JNIEnv* jenv, jclass jcls, jlong jhandle, jlongArray jout) { - size_t result_size; - const jint ret = (jint)TreelitePredictorQueryResultSizeSingleInst( - (PredictorHandle)jhandle, &result_size); +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryResultSize( + JNIEnv* jenv, jclass jcls, jlong jpredictor, jlong jbatch, jlongArray jout) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + DMatrixHandle dmat = reinterpret_cast(jbatch); + size_t result_size = 0; + const int ret = TreelitePredictorQueryResultSize(predictor, dmat, &result_size); // store dimension - jlong* out = jenv->GetLongArrayElements(jout, 0); - out[0] = (jlong)result_size; + jlong* out = jenv->GetLongArrayElements(jout, nullptr); + out[0] = static_cast(result_size); jenv->ReleaseLongArrayElements(jout, out, 0); - return ret; + + return static_cast(ret); } /* @@ -274,17 +298,16 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryResultSizeSingleI */ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryNumOutputGroup( - JNIEnv* jenv, jclass jcls, jlong jhandle, jlongArray jout) { - - size_t num_output_group; - const jint ret = (jint)TreelitePredictorQueryNumOutputGroup( - (PredictorHandle)jhandle, &num_output_group); + JNIEnv* jenv, jclass jcls, jlong jpredictor, jlongArray jout) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + size_t num_output_group = 0; + const int ret = TreelitePredictorQueryNumOutputGroup(predictor, &num_output_group); // store dimension - jlong* out = jenv->GetLongArrayElements(jout, 0); - out[0] = (jlong)num_output_group; + jlong* out = jenv->GetLongArrayElements(jout, nullptr); + out[0] = static_cast(num_output_group); jenv->ReleaseLongArrayElements(jout, out, 0); - return ret; + return static_cast(ret); } /* @@ -294,17 +317,16 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryNumOutputGroup( */ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryNumFeature( - JNIEnv* jenv, jclass jcls, jlong jhandle, jlongArray jout) { - - size_t num_feature; - const jint ret = (jint)TreelitePredictorQueryNumFeature( - (PredictorHandle)jhandle, &num_feature); + JNIEnv* jenv, jclass jcls, jlong jpredictor, jlongArray jout) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + size_t num_feature = 0; + const int ret = TreelitePredictorQueryNumFeature(predictor, &num_feature); // store dimension - jlong* out = jenv->GetLongArrayElements(jout, 0); - out[0] = (jlong)num_feature; + jlong* out = jenv->GetLongArrayElements(jout, nullptr); + out[0] = static_cast(num_feature); jenv->ReleaseLongArrayElements(jout, out, 0); - return ret; + return static_cast(ret); } /* @@ -314,11 +336,10 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryNumFeature( */ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryPredTransform( 
- JNIEnv* jenv, jclass jcls, jlong jhandle, jobjectArray jout) { - - const char* pred_transform; - const jint ret = (jint)TreelitePredictorQueryPredTransform( - (PredictorHandle)jhandle, &pred_transform); + JNIEnv* jenv, jclass jcls, jlong jpredictor, jobjectArray jout) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + const char* pred_transform = nullptr; + const int ret = TreelitePredictorQueryPredTransform(predictor, &pred_transform); // store data jstring out = nullptr; if (pred_transform != nullptr) { @@ -326,7 +347,7 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryPredTransform( } jenv->SetObjectArrayElement(jout, 0, out); - return ret; + return static_cast(ret); } /* @@ -336,17 +357,16 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryPredTransform( */ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQuerySigmoidAlpha( - JNIEnv* jenv, jclass jcls, jlong jhandle, jfloatArray jout) { - - float alpha; - const jint ret = (jint)TreelitePredictorQuerySigmoidAlpha( - (PredictorHandle)jhandle, &alpha); + JNIEnv* jenv, jclass jcls, jlong jpredictor, jfloatArray jout) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + float alpha = std::numeric_limits::quiet_NaN(); + const int ret = TreelitePredictorQuerySigmoidAlpha(predictor, &alpha); // store data - jfloat* out = jenv->GetFloatArrayElements(jout, 0); - out[0] = (jlong)alpha; + jfloat* out = jenv->GetFloatArrayElements(jout, nullptr); + out[0] = static_cast(alpha); jenv->ReleaseFloatArrayElements(jout, out, 0); - return ret; + return static_cast(ret); } /* @@ -356,16 +376,57 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQuerySigmoidAlpha( */ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryGlobalBias( - JNIEnv* jenv, jclass jcls, jlong jhandle, jfloatArray jout) { - - float bias; - const jint ret = (jint)TreelitePredictorQueryGlobalBias( - (PredictorHandle)jhandle, &bias); + JNIEnv* jenv, jclass jcls, jlong jpredictor, jfloatArray jout) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + float bias = std::numeric_limits::quiet_NaN(); + const int ret = TreelitePredictorQueryGlobalBias(predictor, &bias); // store data - jfloat* out = jenv->GetFloatArrayElements(jout, 0); - out[0] = (jlong)bias; + jfloat* out = jenv->GetFloatArrayElements(jout, nullptr); + out[0] = static_cast(bias); jenv->ReleaseFloatArrayElements(jout, out, 0); + return static_cast(ret); +} + +/* + * Class: ml_dmlc_treelite4j_java_TreeliteJNI + * Method: TreelitePredictorQueryThresholdType + * Signature: (J[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryThresholdType( + JNIEnv* jenv, jclass jcls, jlong jpredictor, jobjectArray jout) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + const char* threshold_type = nullptr; + const int ret = TreelitePredictorQueryThresholdType(predictor, &threshold_type); + // store data + jstring out = nullptr; + if (threshold_type != nullptr) { + out = jenv->NewStringUTF(threshold_type); + } + jenv->SetObjectArrayElement(jout, 0, out); + + return static_cast(ret); +} + +/* + * Class: ml_dmlc_treelite4j_java_TreeliteJNI + * Method: TreelitePredictorQueryLeafOutputType + * Signature: (J[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryLeafOutputType( + JNIEnv* jenv, jclass jcls, jlong jpredictor, jobjectArray jout) { + PredictorHandle predictor = 
reinterpret_cast(jpredictor); + const char* leaf_output_type = nullptr; + const jint ret = (jint)TreelitePredictorQueryLeafOutputType(predictor, &leaf_output_type); + // store data + jstring out = nullptr; + if (leaf_output_type != nullptr) { + out = jenv->NewStringUTF(leaf_output_type); + } + jenv->SetObjectArrayElement(jout, 0, out); + return ret; } @@ -376,6 +437,7 @@ Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryGlobalBias( */ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorFree( - JNIEnv* jenv, jclass jcls, jlong jhandle) { - return (jint)TreelitePredictorFree((PredictorHandle)jhandle); + JNIEnv* jenv, jclass jcls, jlong jpredictor) { + PredictorHandle predictor = reinterpret_cast(jpredictor); + return static_cast(TreelitePredictorFree(predictor)); } diff --git a/runtime/java/treelite4j/src/native/treelite4j.h b/runtime/java/treelite4j/src/native/treelite4j.h index 857030f4..a461f419 100644 --- a/runtime/java/treelite4j/src/native/treelite4j.h +++ b/runtime/java/treelite4j/src/native/treelite4j.h @@ -17,43 +17,57 @@ JNIEXPORT jstring JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteGetLa /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteAssembleSparseBatch + * Method: TreeliteDMatrixCreateFromCSRWithFloat32In * Signature: ([F[I[JJJ[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteAssembleSparseBatch +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixCreateFromCSRWithFloat32In (JNIEnv *, jclass, jfloatArray, jintArray, jlongArray, jlong, jlong, jlongArray); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteDeleteSparseBatch - * Signature: (J[F[I[J)I + * Method: TreeliteDMatrixCreateFromCSRWithFloat64In + * Signature: ([D[I[JJJ[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDeleteSparseBatch - (JNIEnv *, jclass, jlong, jfloatArray, jintArray, jlongArray); +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixCreateFromCSRWithFloat64In + (JNIEnv *, jclass, jdoubleArray, jintArray, jlongArray, jlong, jlong, jlongArray); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteAssembleDenseBatch - * Signature: ([FFJJ[J)I + * Method: TreeliteDMatrixCreateFromMatWithFloat32In + * Signature: ([FJJF[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteAssembleDenseBatch - (JNIEnv *, jclass, jfloatArray, jfloat, jlong, jlong, jlongArray); +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixCreateFromMatWithFloat32In + (JNIEnv *, jclass, jfloatArray, jlong, jlong, jfloat, jlongArray); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteDeleteDenseBatch - * Signature: (J[F)I + * Method: TreeliteDMatrixCreateFromMatWithFloat64In + * Signature: ([DJJD[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDeleteDenseBatch - (JNIEnv *, jclass, jlong, jfloatArray); +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixCreateFromMatWithFloat64In + (JNIEnv *, jclass, jdoubleArray, jlong, jlong, jdouble, jlongArray); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreeliteBatchGetDimension - * Signature: (JZ[J[J)I + * Method: TreeliteDMatrixGetDimension + * Signature: (J[J[J[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteBatchGetDimension - (JNIEnv *, jclass, jlong, jboolean, jlongArray, jlongArray); +JNIEXPORT jint 
JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixGetDimension + (JNIEnv *, jclass, jlong, jlongArray, jlongArray, jlongArray); + +/* + * Class: ml_dmlc_treelite4j_java_TreeliteJNI + * Method: TreeliteDMatrixFree + * Signature: (J)I + */ +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreeliteDMatrixFree + (JNIEnv *, jclass, jlong); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI @@ -65,35 +79,38 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredicto /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreelitePredictorPredictBatch - * Signature: (JJZZZ[F[J)I + * Method: TreelitePredictorPredictBatchWithFloat32Out + * Signature: (JJZZ[F[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictBatch - (JNIEnv *, jclass, jlong, jlong, jboolean, jboolean, jboolean, jfloatArray, jlongArray); +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictBatchWithFloat32Out + (JNIEnv *, jclass, jlong, jlong, jboolean, jboolean, jfloatArray, jlongArray); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreelitePredictorPredictInst - * Signature: (J[BZ[F[J)I + * Method: TreelitePredictorPredictBatchWithFloat64Out + * Signature: (JJZZ[D[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictInst - (JNIEnv *, jclass, jlong, jbyteArray, jboolean, jfloatArray, jlongArray); +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictBatchWithFloat64Out + (JNIEnv *, jclass, jlong, jlong, jboolean, jboolean, jdoubleArray, jlongArray); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreelitePredictorQueryResultSize - * Signature: (JJZ[J)I + * Method: TreelitePredictorPredictBatchWithUInt32Out + * Signature: (JJZZ[I[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryResultSize - (JNIEnv *, jclass, jlong, jlong, jboolean, jlongArray); +JNIEXPORT jint JNICALL +Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorPredictBatchWithUInt32Out + (JNIEnv *, jclass, jlong, jlong, jboolean, jboolean, jintArray, jlongArray); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI - * Method: TreelitePredictorQueryResultSizeSingleInst - * Signature: (J[J)I + * Method: TreelitePredictorQueryResultSize + * Signature: (JJ[J)I */ -JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryResultSizeSingleInst - (JNIEnv *, jclass, jlong, jlongArray); +JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryResultSize + (JNIEnv *, jclass, jlong, jlong, jlongArray); /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI @@ -135,6 +152,22 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredicto JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryGlobalBias (JNIEnv *, jclass, jlong, jfloatArray); +/* + * Class: ml_dmlc_treelite4j_java_TreeliteJNI + * Method: TreelitePredictorQueryThresholdType + * Signature: (J[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryThresholdType + (JNIEnv *, jclass, jlong, jobjectArray); + +/* + * Class: ml_dmlc_treelite4j_java_TreeliteJNI + * Method: TreelitePredictorQueryLeafOutputType + * Signature: (J[Ljava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ml_dmlc_treelite4j_java_TreeliteJNI_TreelitePredictorQueryLeafOutputType + (JNIEnv *, jclass, jlong, 
jobjectArray); + /* * Class: ml_dmlc_treelite4j_java_TreeliteJNI * Method: TreelitePredictorFree diff --git a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DMatrixTest.java b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DMatrixTest.java new file mode 100644 index 00000000..487805fb --- /dev/null +++ b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DMatrixTest.java @@ -0,0 +1,87 @@ +package ml.dmlc.treelite4j.java; + +import junit.framework.TestCase; +import ml.dmlc.treelite4j.DataPoint; +import org.apache.commons.lang.ArrayUtils; +import org.junit.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +/** + * Test cases for data matrix + * + * @author Hyunsu Cho + */ +public class DMatrixTest { + @Test + public void testDenseDMatrixBasicFloat32() throws TreeliteError { + for (int i = 0; i < 1000; ++i) { + int num_row = ThreadLocalRandom.current().nextInt(1, 100); + int num_col = ThreadLocalRandom.current().nextInt(1, 100); + float[] data = new float[num_row * num_col]; + for (int k = 0; k < num_row * num_col; ++k) { + data[k] = ThreadLocalRandom.current().nextFloat() - 0.5f; + } + DMatrix dmat = new DMatrix(data, Float.NaN, num_row, num_col); + TestCase.assertEquals(num_row, dmat.getNumRow()); + TestCase.assertEquals(num_col, dmat.getNumCol()); + TestCase.assertEquals(num_row * num_col, dmat.getNumElements()); + } + } + + @Test + public void testSparseDMatrixBasicFloat32() throws TreeliteError { + float kDensity = 0.1f; // % of nonzeros in matrix + float kProbNextRow = 0.1f; // transition probability from one row to next + for (int case_id = 0; case_id < 1000; ++case_id) { + int num_row = ThreadLocalRandom.current().nextInt(1, 100); + int num_col = ThreadLocalRandom.current().nextInt(1, 100); + int nnz = (int)(num_row * num_col * kDensity); + float[] data = new float[nnz]; + int[] col_ind = new int[nnz]; + ArrayList row_ptr = new ArrayList(); + row_ptr.add(0L); + for (int k = 0; k < nnz; ++k) { + data[k] = ThreadLocalRandom.current().nextFloat() - 0.5f; + col_ind[k] = ThreadLocalRandom.current().nextInt(0, num_col); + if (ThreadLocalRandom.current().nextFloat() < kProbNextRow + && row_ptr.size() < num_row) { + Arrays.sort(col_ind, row_ptr.get(row_ptr.size() - 1).intValue(), k + 1); + row_ptr.add(k + 1L); + } + } + Arrays.sort(col_ind, row_ptr.get(row_ptr.size() - 1).intValue(), nnz); + row_ptr.add((long)nnz); + while (row_ptr.size() < num_row + 1) { + row_ptr.add(row_ptr.get(row_ptr.size() - 1)); + } + TestCase.assertEquals(row_ptr.size(), num_row + 1); + long[] row_ptr_arr = ArrayUtils.toPrimitive(row_ptr.toArray(new Long[0])); + DMatrix dmat = new DMatrix(data, col_ind, row_ptr_arr, num_row, num_col); + TestCase.assertEquals(num_row, dmat.getNumRow()); + TestCase.assertEquals(num_col, dmat.getNumCol()); + TestCase.assertEquals(nnz, dmat.getNumElements()); + } + } + + @Test + public void testSparseDMatrixBuilder() throws TreeliteError, IOException { + List data_list = new ArrayList() { + { + add(new DataPoint(new int[]{0, 1}, new float[]{10f, 20f})); + add(new DataPoint(new int[]{1, 3}, new float[]{30f, 40f})); + add(new DataPoint(new int[]{2, 3, 4}, new float[]{50f, 60f, 70f})); + add(new DataPoint(new int[]{5}, new float[]{80f})); + } + }; + DMatrix dmat = DMatrixBuilder.createSparseCSRDMatrixFloat32(data_list.iterator()); + + // should get 4-by-6 matrix + TestCase.assertEquals(4, dmat.getNumRow()); + TestCase.assertEquals(6, 
dmat.getNumCol()); + } +} diff --git a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DenseBatchTest.java b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DenseBatchTest.java deleted file mode 100644 index 2a8e7457..00000000 --- a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DenseBatchTest.java +++ /dev/null @@ -1,32 +0,0 @@ -package ml.dmlc.treelite4j.java; - -import junit.framework.TestCase; -import org.junit.Test; - -import java.util.concurrent.ThreadLocalRandom; - -/** - * Test cases for dense batch - * - * @author Hyunsu Cho - */ -public class DenseBatchTest { - @Test - public void testDenseBatchBasic() throws TreeliteError { - for (int i = 0; i < 100000; ++i) { - int num_row = ThreadLocalRandom.current().nextInt(1, 100); - int num_col = ThreadLocalRandom.current().nextInt(1, 100); - float[] data = new float[num_row * num_col]; - for (int k = 0; k < num_row * num_col; ++k) { - data[k] = ThreadLocalRandom.current().nextFloat() - 0.5f; - } - DenseBatch batch = new DenseBatch(data, Float.NaN, num_row, num_col); - long[] out_num_row = new long[1]; - long[] out_num_col = new long[1]; - TreeliteJNI.checkCall(TreeliteJNI.TreeliteBatchGetDimension( - batch.getHandle(), false, out_num_row, out_num_col)); - TestCase.assertEquals(num_row, out_num_row[0]); - TestCase.assertEquals(num_col, out_num_col[0]); - } - } -} diff --git a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/PredictorTest.java b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/PredictorTest.java index 79a84548..92e8c118 100644 --- a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/PredictorTest.java +++ b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/PredictorTest.java @@ -49,21 +49,18 @@ public void testPredictorBasic() throws TreeliteError { public void testPredict() throws TreeliteError, IOException { Predictor predictor = new Predictor(mushroomLibLocation, -1, true); List dmat - = BatchBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation); - SparseBatch sparse_batch = BatchBuilder.CreateSparseBatch(dmat.iterator()); - DenseBatch dense_batch = BatchBuilder.CreateDenseBatch(dmat.iterator()); - float[] expected_result - = LoadArrayFromText(mushroomTestDataPredProbResultLocation); + = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation); + DMatrix sparse_dmat = DMatrixBuilder.createSparseCSRDMatrixFloat32(dmat.iterator()); + DMatrix dense_dmat = DMatrixBuilder.createDenseDMatrixFloat32(dmat.iterator()); + float[] expected_result = LoadArrayFromText(mushroomTestDataPredProbResultLocation); - /* sparse batch */ - float[][] result = predictor.predict(sparse_batch, true, false); + float[][] result = predictor.predict(sparse_dmat, true, false).toFloatMatrix(); for (int i = 0; i < result.length; ++i) { TestCase.assertEquals(1, result[i].length); TestCase.assertEquals(expected_result[i], result[i][0]); } - /* dense batch */ - result = predictor.predict(dense_batch, true, false); + result = predictor.predict(dense_dmat, true, false).toFloatMatrix(); for (int i = 0; i < result.length; ++i) { TestCase.assertEquals(1, result[i].length); TestCase.assertEquals(expected_result[i], result[i][0]); @@ -74,33 +71,27 @@ public void testPredict() throws TreeliteError, IOException { public void testPredictMargin() throws TreeliteError, IOException { Predictor predictor = new Predictor(mushroomLibLocation, -1, true); List dmat - = BatchBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation); - SparseBatch sparse_batch = 
BatchBuilder.CreateSparseBatch(dmat.iterator()); - DenseBatch dense_batch = BatchBuilder.CreateDenseBatch(dmat.iterator()); + = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation); + DMatrix sparse_batch = DMatrixBuilder.createSparseCSRDMatrixFloat32(dmat.iterator()); + DMatrix dense_batch = DMatrixBuilder.createDenseDMatrixFloat32(dmat.iterator()); float[] expected_result = LoadArrayFromText(mushroomTestDataPredMarginResultLocation); /* sparse batch */ - float[][] result = predictor.predict(sparse_batch, true, true); + float[][] result = predictor.predict(sparse_batch, true, true).toFloatMatrix(); for (int i = 0; i < result.length; ++i) { TestCase.assertEquals(1, result[i].length); TestCase.assertEquals(expected_result[i], result[i][0]); } /* dense batch */ - result = predictor.predict(dense_batch, true, true); + result = predictor.predict(dense_batch, true, true).toFloatMatrix(); for (int i = 0; i < result.length; ++i) { TestCase.assertEquals(1, result[i].length); TestCase.assertEquals(expected_result[i], result[i][0]); } } - @Test - public void testPredictInst() throws TreeliteError, IOException { - Predictor predictor = new Predictor(mushroomLibLocation, -1, true); - mushroomLibPredictionTest(predictor); - } - @Test public void testSerialization() throws TreeliteError, IOException, ClassNotFoundException { Predictor predictor = new Predictor(mushroomLibLocation, -1, true); @@ -110,34 +101,15 @@ public void testSerialization() throws TreeliteError, IOException, ClassNotFound TestCase.assertEquals(predictor.GetPredTransform(), predictor2.GetPredTransform()); TestCase.assertEquals(predictor.GetSigmoidAlpha(), predictor2.GetSigmoidAlpha()); TestCase.assertEquals(predictor.GetGlobalBias(), predictor2.GetGlobalBias()); - mushroomLibPredictionTest(predictor2); - } - - private void mushroomLibPredictionTest(Predictor predictor) throws IOException, TreeliteError { - Entry[] inst_arr = new Entry[predictor.GetNumFeature()]; - for (int i = 0; i < inst_arr.length; ++i) { - inst_arr[i] = new Entry(); - inst_arr[i].setMissing(); - } - - float[] expected_result - = LoadArrayFromText(mushroomTestDataPredProbResultLocation); - List dmat - = BatchBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation); - int row_id = 0; - for (DataPoint inst : dmat) { - int[] indices = inst.indices(); - float[] values = inst.values(); - for (int i = 0; i < indices.length; ++i) { - inst_arr[indices[i]].setFValue(values[i]); - } - float[] result = predictor.predict(inst_arr, false); - TestCase.assertEquals(1, result.length); - TestCase.assertEquals(expected_result[row_id++], result[0]); - for (int i = 0; i < inst_arr.length; ++i) { - inst_arr[i].setMissing(); - } + List dataset + = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation); + DMatrix dmat = DMatrixBuilder.createSparseCSRDMatrixFloat32(dataset.iterator()); + float[] expected_result = LoadArrayFromText(mushroomTestDataPredProbResultLocation); + float[][] result = predictor.predict(dmat, true, false).toFloatMatrix(); + for (int i = 0; i < result.length; ++i) { + TestCase.assertEquals(1, result[i].length); + TestCase.assertEquals(expected_result[i], result[i][0]); } } diff --git a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/SparseBatchTest.java b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/SparseBatchTest.java deleted file mode 100644 index 1a73b469..00000000 --- a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/SparseBatchTest.java +++ /dev/null @@ -1,74 +0,0 @@ -package 
ml.dmlc.treelite4j.java; - -import junit.framework.TestCase; -import ml.dmlc.treelite4j.DataPoint; -import org.apache.commons.lang.ArrayUtils; -import org.junit.Test; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ThreadLocalRandom; - -/** - * Test cases for sparse batch - * - * @author Hyunsu Cho - */ -public class SparseBatchTest { - @Test - public void testSparseBatchBasic() throws TreeliteError { - float kDensity = 0.1f; // % of nonzeros in matrix - float kProbNextRow = 0.1f; // transition probability from one row to next - for (int case_id = 0; case_id < 100000; ++case_id) { - int num_row = ThreadLocalRandom.current().nextInt(1, 100); - int num_col = ThreadLocalRandom.current().nextInt(1, 100); - int nnz = (int)(num_row * num_col * kDensity); - float[] data = new float[nnz]; - int[] col_ind = new int[nnz]; - ArrayList row_ptr = new ArrayList(); - row_ptr.add(0L); - for (int k = 0; k < data.length; ++k) { - data[k] = ThreadLocalRandom.current().nextFloat() - 0.5f; - col_ind[k] = ThreadLocalRandom.current().nextInt(0, num_col); - if (ThreadLocalRandom.current().nextFloat() < kProbNextRow - || k == data.length - 1) { - Arrays.sort(col_ind, row_ptr.get(row_ptr.size()-1).intValue(), k+1); - row_ptr.add(k + 1L); - } - } - long[] row_ptr_arr - = ArrayUtils.toPrimitive(row_ptr.toArray(new Long[0])); - SparseBatch batch - = new SparseBatch(data, col_ind, row_ptr_arr, num_row, num_col); - long[] out_num_row = new long[1]; - long[] out_num_col = new long[1]; - TreeliteJNI.checkCall(TreeliteJNI.TreeliteBatchGetDimension( - batch.getHandle(), true, out_num_row, out_num_col)); - TestCase.assertEquals(num_row, out_num_row[0]); - TestCase.assertEquals(num_col, out_num_col[0]); - } - } - - @Test - public void testSparseBatchBuilder() throws TreeliteError, IOException { - List dmat = new ArrayList() { - { - add(new DataPoint(new int[]{0, 1}, new float[]{10f, 20f})); - add(new DataPoint(new int[]{1, 3}, new float[]{30f, 40f})); - add(new DataPoint(new int[]{2, 3, 4}, new float[]{50f, 60f, 70f})); - add(new DataPoint(new int[]{5}, new float[]{80f})); - } - }; - SparseBatch batch = BatchBuilder.CreateSparseBatch(dmat.iterator()); - - // should get 4-by-6 matrix - long[] out_num_row = new long[1]; - long[] out_num_col = new long[1]; - TreeliteJNI.checkCall(TreeliteJNI.TreeliteBatchGetDimension( - batch.getHandle(), true, out_num_row, out_num_col)); - TestCase.assertEquals(4, out_num_row[0]); - TestCase.assertEquals(6, out_num_col[0]); - } -} diff --git a/runtime/java/treelite4j/src/test/resources/mushroom_example/mushroom.c b/runtime/java/treelite4j/src/test/resources/mushroom_example/mushroom.c index e8276624..6c909a7c 100644 --- a/runtime/java/treelite4j/src/test/resources/mushroom_example/mushroom.c +++ b/runtime/java/treelite4j/src/test/resources/mushroom_example/mushroom.c @@ -29,6 +29,14 @@ float get_global_bias(void) { return -0; } +const char* get_threshold_type(void) { + return "float32"; +} + +const char* get_leaf_output_type(void) { + return "float32"; +} + static inline float pred_transform(float margin) { const float alpha = (float)1; return 1.0f / (1 + expf(-alpha * margin)); diff --git a/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/PredictorTest.scala b/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/PredictorTest.scala index b80a36b7..c255c744 100644 --- a/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/PredictorTest.scala +++ 
b/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/PredictorTest.scala @@ -1,7 +1,7 @@ package ml.dmlc.treelite4j.scala import ml.dmlc.treelite4j.java.PredictorTest.LoadArrayFromText -import ml.dmlc.treelite4j.java.{BatchBuilder, Entry, NativeLibLoader} +import ml.dmlc.treelite4j.java.{DMatrixBuilder, Entry, NativeLibLoader} import org.scalatest.{FunSuite, Matchers} import scala.collection.JavaConverters._ @@ -27,16 +27,16 @@ class PredictorTest extends FunSuite with Matchers { test("PredictBatch") { val predictor = Predictor(mushroomLibLocation) - val dmat = BatchBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation) - val sparseBatch = BatchBuilder.CreateSparseBatch(dmat.iterator()) - val denseBatch = BatchBuilder.CreateDenseBatch(dmat.iterator()) + val dmat = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation) + val sparseDMatrix = DMatrixBuilder.createSparseCSRDMatrixFloat32(dmat.iterator()) + val denseDMatrix = DMatrixBuilder.createDenseDMatrixFloat32(dmat.iterator()) val retProb = LoadArrayFromText(mushroomTestDataPredProbResultLocation) val retMargin = LoadArrayFromText(mushroomTestDataPredMarginResultLocation) - val sparseMargin = predictor.predictSparseBatch(sparseBatch, predMargin = true) - val sparseProb = predictor.predictSparseBatch(sparseBatch) - val denseMargin = predictor.predictDenseBatch(denseBatch, predMargin = true) - val denseProb = predictor.predictDenseBatch(denseBatch) + val sparseMargin = predictor.predictBatch(sparseDMatrix, predMargin = true).toFloatMatrix + val sparseProb = predictor.predictBatch(sparseDMatrix).toFloatMatrix + val denseMargin = predictor.predictBatch(denseDMatrix, predMargin = true).toFloatMatrix + val denseProb = predictor.predictBatch(denseDMatrix).toFloatMatrix retProb.zip(denseProb.zip(sparseProb)).foreach { case (ret, (dense, sparse)) => Seq(dense.length, sparse.length) shouldEqual Seq(1, 1) @@ -47,26 +47,4 @@ class PredictorTest extends FunSuite with Matchers { Seq(dense.head, sparse.head) shouldEqual Seq(ret, ret) } } - - test("PredictInst") { - val predictor = Predictor(mushroomLibLocation) - mushroomLibPredictionTest(predictor) - } - - private def mushroomLibPredictionTest(predictor: Predictor): Unit = { - val instArray = Array.tabulate(predictor.numFeature)(_ => { - val entry = new Entry() - entry.setMissing() - entry - }) - val expectedResult = LoadArrayFromText(mushroomTestDataPredProbResultLocation) - val dataPoints = BatchBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation) - dataPoints.asScala.zipWithIndex.foreach { case (dp, row) => - dp.indices.zip(dp.values).foreach { case (i, v) => instArray(i).setFValue(v) } - val result = predictor.predictInst(instArray) - result.length shouldEqual 1 - result(0) shouldEqual expectedResult(row) - instArray.foreach(_.setMissing()) - } - } } diff --git a/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/TreeLiteModelTest.scala b/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/TreeLiteModelTest.scala index 74f60b40..e1fa91a8 100644 --- a/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/TreeLiteModelTest.scala +++ b/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/TreeLiteModelTest.scala @@ -1,7 +1,7 @@ package ml.dmlc.treelite4j.scala import ml.dmlc.treelite4j.java.PredictorTest.LoadArrayFromText -import ml.dmlc.treelite4j.java.{BatchBuilder, NativeLibLoader} +import ml.dmlc.treelite4j.java.{DMatrixBuilder, NativeLibLoader} import ml.dmlc.treelite4j.scala.spark.TreeLiteModel import 
org.apache.spark.SparkContext import org.apache.spark.ml.linalg._ @@ -62,7 +62,7 @@ class TreeLiteModelTest extends FunSuite with Matchers with BeforeAndAfterEach { private def buildDataFrame(numPartitions: Int = numWorkers): DataFrame = { val probResult = LoadArrayFromText(mushroomTestDataPredProbResultLocation) val marginResult = LoadArrayFromText(mushroomTestDataPredMarginResultLocation) - val dataPoint = BatchBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation).asScala + val dataPoint = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation).asScala val localData = dataPoint.zip(probResult.zip(marginResult)).map { case (dp, (prob, margin)) => diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b91342e5..f7ba7c87 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -116,11 +116,11 @@ target_sources(objtreelite_common PRIVATE c_api/c_api_common.cc c_api/c_api_error.cc - c_api/c_api_error.h data.cc logging.cc typeinfo.cc ${PROJECT_SOURCE_DIR}/include/treelite/c_api_common.h + ${PROJECT_SOURCE_DIR}/include/treelite/c_api_error.h ${PROJECT_SOURCE_DIR}/include/treelite/logging.h ${PROJECT_SOURCE_DIR}/include/treelite/math.h ${PROJECT_SOURCE_DIR}/include/treelite/typeinfo.h diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 4dfe98df..55b11503 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -18,7 +19,6 @@ #include #include #include -#include "./c_api_error.h" using namespace treelite; diff --git a/src/c_api/c_api_common.cc b/src/c_api/c_api_common.cc index 77c1110f..76a36f59 100644 --- a/src/c_api/c_api_common.cc +++ b/src/c_api/c_api_common.cc @@ -9,7 +9,7 @@ #include #include #include -#include "./c_api_error.h" +#include using namespace treelite; diff --git a/src/c_api/c_api_error.cc b/src/c_api/c_api_error.cc index 34515643..8ab205e2 100644 --- a/src/c_api/c_api_error.cc +++ b/src/c_api/c_api_error.cc @@ -5,7 +5,7 @@ * \brief C error handling */ #include -#include "./c_api_error.h" +#include struct TreeliteAPIErrorEntry { std::string last_error; diff --git a/src/c_api/c_api_runtime.cc b/src/c_api/c_api_runtime.cc index e5d5482a..94c4b3c3 100644 --- a/src/c_api/c_api_runtime.cc +++ b/src/c_api/c_api_runtime.cc @@ -7,10 +7,10 @@ #include #include +#include #include #include #include -#include "./c_api_error.h" using namespace treelite; From 1827bdbaf286685f04856b8f52faaa125cc2ec00 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 21:09:09 -0700 Subject: [PATCH 31/38] Fix lint --- include/treelite/c_api_error.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/treelite/c_api_error.h b/include/treelite/c_api_error.h index 6f6b6108..28abd959 100644 --- a/include/treelite/c_api_error.h +++ b/include/treelite/c_api_error.h @@ -4,8 +4,8 @@ * \author Hyunsu Cho * \brief Error handling for C API. 
*/ -#ifndef TREELITE_C_API_C_API_ERROR_H_ -#define TREELITE_C_API_C_API_ERROR_H_ +#ifndef TREELITE_C_API_ERROR_H_ +#define TREELITE_C_API_ERROR_H_ #include #include @@ -46,4 +46,4 @@ inline int TreeliteAPIHandleException(const std::exception &e) { TreeliteAPISetLastError(e.what()); return -1; } -#endif // TREELITE_C_API_C_API_ERROR_H_ +#endif // TREELITE_C_API_ERROR_H_ From f61153369010bdab11a5b520706cd1d69fcc4d44 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 4 Sep 2020 21:24:07 -0700 Subject: [PATCH 32/38] Fix cmake import example --- tests/example_app/CMakeLists.txt | 2 +- tests/example_app/example.cc | 20 +++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/example_app/CMakeLists.txt b/tests/example_app/CMakeLists.txt index 574442b0..bb457c32 100644 --- a/tests/example_app/CMakeLists.txt +++ b/tests/example_app/CMakeLists.txt @@ -14,6 +14,6 @@ add_executable(example example.cc) target_link_libraries(example PRIVATE treelite::treelite_static treelite::treelite_runtime_static) set_target_properties(example PROPERTIES - CXX_STANDARD 11 + CXX_STANDARD 14 CXX_STANDARD_REQUIRED YES ) diff --git a/tests/example_app/example.cc b/tests/example_app/example.cc index 5966ee04..c16e071c 100644 --- a/tests/example_app/example.cc +++ b/tests/example_app/example.cc @@ -8,28 +8,30 @@ using treelite::frontend::TreeBuilder; using treelite::frontend::ModelBuilder; +using treelite::frontend::Value; +using treelite::TypeInfo; int main(void) { - std::unique_ptr tree{new TreeBuilder}; + auto tree = std::make_unique(TypeInfo::kFloat32, TypeInfo::kFloat32); tree->CreateNode(0); tree->CreateNode(1); tree->CreateNode(2); - tree->SetNumericalTestNode(0, 0, "<", 0.0f, true, 1, 2); - tree->SetLeafNode(1, -1.0f); - tree->SetLeafNode(2, 1.0f); + tree->SetNumericalTestNode(0, 0, "<", Value::Create(0), true, 1, 2); + tree->SetLeafNode(1, Value::Create(-1.0)); + tree->SetLeafNode(2, Value::Create(1.0)); tree->SetRootNode(0); - std::unique_ptr builder{new ModelBuilder(2, 1, false)}; + auto builder + = std::make_unique(2, 1, false, TypeInfo::kFloat32, TypeInfo::kFloat32); builder->InsertTree(tree.get()); - treelite::Model model; - builder->CommitModel(&model); - std::cout << model.trees.size() << std::endl; + auto model = builder->CommitModel(); + std::cout << model->GetNumTree() << std::endl; treelite::compiler::CompilerParam param; param.Init(std::map{}); std::unique_ptr compiler{treelite::Compiler::Create("ast_native", param)}; - treelite::compiler::CompiledModel cm = compiler->Compile(model); + treelite::compiler::CompiledModel cm = compiler->Compile(*model.get()); for (const auto& kv : cm.files) { std::cout << "=================" << kv.first << "=================" << std::endl; std::cout << kv.second.content << std::endl; From 973b4a600b1cabf06902d6489b2659fd775ad89a Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Sat, 5 Sep 2020 16:24:14 -0700 Subject: [PATCH 33/38] Remove 'float32' suffix in Java methods --- .../dmlc/treelite4j/java/DMatrixBuilder.java | 7 +++---- .../scala/spark/TreeLiteModel.scala | 4 ++-- .../ml/dmlc/treelite4j/java/DMatrixTest.java | 2 +- .../dmlc/treelite4j/java/PredictorTest.java | 19 ++++++++----------- .../dmlc/treelite4j/scala/PredictorTest.scala | 6 +++--- .../treelite4j/scala/TreeLiteModelTest.scala | 2 +- 6 files changed, 18 insertions(+), 22 deletions(-) diff --git a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrixBuilder.java b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrixBuilder.java index 
b9460772..2c555cac 100644 --- a/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrixBuilder.java +++ b/runtime/java/treelite4j/src/main/java/ml/dmlc/treelite4j/java/DMatrixBuilder.java @@ -26,7 +26,7 @@ public class DMatrixBuilder { * @throws TreeliteError Treelite error * @throws IOException IO error */ - public static DMatrix createSparseCSRDMatrixFloat32(Iterator dIter) + public static DMatrix createSparseCSRDMatrix(Iterator dIter) throws TreeliteError, IOException { ArrayList data = new ArrayList<>(); ArrayList col_ind = new ArrayList<>(); @@ -71,8 +71,7 @@ public static DMatrix createSparseCSRDMatrixFloat32(Iterator dIter) * @throws TreeliteError Treelite error * @throws IOException IO error */ - public static DMatrix createDenseDMatrixFloat32( - Iterator dIter) + public static DMatrix createDenseDMatrix(Iterator dIter) throws TreeliteError, IOException { int num_row = 0; int num_col = 0; @@ -126,7 +125,7 @@ public static DMatrix createDenseDMatrixFloat32( * @throws TreeliteError Treelite error * @throws IOException IO error */ - public static List LoadDatasetFromLibSVMFloat32(String filename) + public static List LoadDatasetFromLibSVM(String filename) throws TreeliteError, IOException { File file = new File(filename); LineIterator it = FileUtils.lineIterator(file, "UTF-8"); diff --git a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/spark/TreeLiteModel.scala b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/spark/TreeLiteModel.scala index 8bf16a61..53a2f357 100644 --- a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/spark/TreeLiteModel.scala +++ b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/scala/spark/TreeLiteModel.scala @@ -70,12 +70,12 @@ class TreeLiteModel private[spark]( } val result = batchRow.head.getAs[Vector]($(featuresCol)) match { case _: SparseVector => - val batch = DMatrixBuilder.createSparseCSRDMatrixFloat32(dataPoints.asJava) + val batch = DMatrixBuilder.createSparseCSRDMatrix(dataPoints.asJava) val ret = broadcastModel.value.predictBatch(batch, $(predictMargin), $(verbose)) batch.dispose() ret.toFloatMatrix.map(Row.apply(_)) case _: DenseVector => - val batch = DMatrixBuilder.createDenseDMatrixFloat32(dataPoints.asJava) + val batch = DMatrixBuilder.createDenseDMatrix(dataPoints.asJava) val ret = broadcastModel.value.predictBatch(batch, $(predictMargin), $(verbose)) batch.dispose() ret.toFloatMatrix.map(Row.apply(_)) diff --git a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DMatrixTest.java b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DMatrixTest.java index 487805fb..38713166 100644 --- a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DMatrixTest.java +++ b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/DMatrixTest.java @@ -78,7 +78,7 @@ public void testSparseDMatrixBuilder() throws TreeliteError, IOException { add(new DataPoint(new int[]{5}, new float[]{80f})); } }; - DMatrix dmat = DMatrixBuilder.createSparseCSRDMatrixFloat32(data_list.iterator()); + DMatrix dmat = DMatrixBuilder.createSparseCSRDMatrix(data_list.iterator()); // should get 4-by-6 matrix TestCase.assertEquals(4, dmat.getNumRow()); diff --git a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/PredictorTest.java b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/PredictorTest.java index 92e8c118..c18a08d8 100644 --- a/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/PredictorTest.java +++ 
b/runtime/java/treelite4j/src/test/java/ml/dmlc/treelite4j/java/PredictorTest.java @@ -48,10 +48,9 @@ public void testPredictorBasic() throws TreeliteError { @Test public void testPredict() throws TreeliteError, IOException { Predictor predictor = new Predictor(mushroomLibLocation, -1, true); - List dmat - = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation); - DMatrix sparse_dmat = DMatrixBuilder.createSparseCSRDMatrixFloat32(dmat.iterator()); - DMatrix dense_dmat = DMatrixBuilder.createDenseDMatrixFloat32(dmat.iterator()); + List dmat = DMatrixBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation); + DMatrix sparse_dmat = DMatrixBuilder.createSparseCSRDMatrix(dmat.iterator()); + DMatrix dense_dmat = DMatrixBuilder.createDenseDMatrix(dmat.iterator()); float[] expected_result = LoadArrayFromText(mushroomTestDataPredProbResultLocation); float[][] result = predictor.predict(sparse_dmat, true, false).toFloatMatrix(); @@ -70,10 +69,9 @@ public void testPredict() throws TreeliteError, IOException { @Test public void testPredictMargin() throws TreeliteError, IOException { Predictor predictor = new Predictor(mushroomLibLocation, -1, true); - List dmat - = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation); - DMatrix sparse_batch = DMatrixBuilder.createSparseCSRDMatrixFloat32(dmat.iterator()); - DMatrix dense_batch = DMatrixBuilder.createDenseDMatrixFloat32(dmat.iterator()); + List dmat = DMatrixBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation); + DMatrix sparse_batch = DMatrixBuilder.createSparseCSRDMatrix(dmat.iterator()); + DMatrix dense_batch = DMatrixBuilder.createDenseDMatrix(dmat.iterator()); float[] expected_result = LoadArrayFromText(mushroomTestDataPredMarginResultLocation); @@ -102,9 +100,8 @@ public void testSerialization() throws TreeliteError, IOException, ClassNotFound TestCase.assertEquals(predictor.GetSigmoidAlpha(), predictor2.GetSigmoidAlpha()); TestCase.assertEquals(predictor.GetGlobalBias(), predictor2.GetGlobalBias()); - List dataset - = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation); - DMatrix dmat = DMatrixBuilder.createSparseCSRDMatrixFloat32(dataset.iterator()); + List dataset = DMatrixBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation); + DMatrix dmat = DMatrixBuilder.createSparseCSRDMatrix(dataset.iterator()); float[] expected_result = LoadArrayFromText(mushroomTestDataPredProbResultLocation); float[][] result = predictor.predict(dmat, true, false).toFloatMatrix(); for (int i = 0; i < result.length; ++i) { diff --git a/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/PredictorTest.scala b/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/PredictorTest.scala index c255c744..f82e3ca0 100644 --- a/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/PredictorTest.scala +++ b/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/PredictorTest.scala @@ -27,9 +27,9 @@ class PredictorTest extends FunSuite with Matchers { test("PredictBatch") { val predictor = Predictor(mushroomLibLocation) - val dmat = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation) - val sparseDMatrix = DMatrixBuilder.createSparseCSRDMatrixFloat32(dmat.iterator()) - val denseDMatrix = DMatrixBuilder.createDenseDMatrixFloat32(dmat.iterator()) + val dmat = DMatrixBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation) + val sparseDMatrix = DMatrixBuilder.createSparseCSRDMatrix(dmat.iterator()) + val denseDMatrix = DMatrixBuilder.createDenseDMatrix(dmat.iterator()) val 
retProb = LoadArrayFromText(mushroomTestDataPredProbResultLocation) val retMargin = LoadArrayFromText(mushroomTestDataPredMarginResultLocation) diff --git a/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/TreeLiteModelTest.scala b/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/TreeLiteModelTest.scala index e1fa91a8..229e7094 100644 --- a/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/TreeLiteModelTest.scala +++ b/runtime/java/treelite4j/src/test/scala/ml/dmlc/treelite4j/scala/TreeLiteModelTest.scala @@ -62,7 +62,7 @@ class TreeLiteModelTest extends FunSuite with Matchers with BeforeAndAfterEach { private def buildDataFrame(numPartitions: Int = numWorkers): DataFrame = { val probResult = LoadArrayFromText(mushroomTestDataPredProbResultLocation) val marginResult = LoadArrayFromText(mushroomTestDataPredMarginResultLocation) - val dataPoint = DMatrixBuilder.LoadDatasetFromLibSVMFloat32(mushroomTestDataLocation).asScala + val dataPoint = DMatrixBuilder.LoadDatasetFromLibSVM(mushroomTestDataLocation).asScala val localData = dataPoint.zip(probResult.zip(marginResult)).map { case (dp, (prob, margin)) => From b1e1e4be96d2884c581dc7edca51780659714f90 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Thu, 10 Sep 2020 00:33:34 -0700 Subject: [PATCH 34/38] Remove a void* handle that's unused --- include/treelite/c_api_runtime.h | 3 --- .../src/main/scala/ml/dmlc/treelite4j/DataPoint.scala | 2 +- .../src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala | 2 +- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/include/treelite/c_api_runtime.h b/include/treelite/c_api_runtime.h index 17a8da4d..113128a5 100644 --- a/include/treelite/c_api_runtime.h +++ b/include/treelite/c_api_runtime.h @@ -21,9 +21,6 @@ */ /*! \brief handle to predictor class */ typedef void* PredictorHandle; -/*! \brief handle to output from predictor */ -typedef void* PredictorOutputHandle; - /*! \} */ /*! 
diff --git a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPoint.scala b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPoint.scala index 644d59ec..3c6b4641 100644 --- a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPoint.scala +++ b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPoint.scala @@ -25,4 +25,4 @@ case class DataPoint( values: Array[Float]) extends Serializable { require(indices == null || indices.length == values.length, "indices and values must have the same number of elements") -} \ No newline at end of file +} diff --git a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala index a91ae4ae..5447a373 100644 --- a/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala +++ b/runtime/java/treelite4j/src/main/scala/ml/dmlc/treelite4j/DataPointFloat64.scala @@ -11,4 +11,4 @@ case class DataPointFloat64( values: Array[Double]) extends Serializable { require(indices == null || indices.length == values.length, "indices and values must have the same number of elements") -} \ No newline at end of file +} From 6421fa21c90f627d60b82263f8412fd0b88c54cf Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Tue, 15 Sep 2020 00:25:17 -0700 Subject: [PATCH 35/38] New prediction runtime C API, to support multiple data types (Part of #196) (#199) * New prediction runtime C API, to support multiple data types * Address reviewer's feedback --- CMakeLists.txt | 2 + include/treelite/c_api_common.h | 68 ++ {src/c_api => include/treelite}/c_api_error.h | 6 +- include/treelite/c_api_runtime.h | 180 ++--- include/treelite/data.h | 153 +++-- include/treelite/entry.h | 20 - include/treelite/predictor.h | 214 +++--- include/treelite/typeinfo.h | 163 +++++ src/CMakeLists.txt | 11 +- src/c_api/c_api_common.cc | 63 +- src/c_api/c_api_runtime.cc | 175 ++--- src/predictor/predictor.cc | 640 +++++++++--------- src/predictor/thread_pool/spsc_queue.h | 6 + src/predictor/thread_pool/thread_pool.h | 10 +- src/typeinfo.cc | 22 + 15 files changed, 1011 insertions(+), 722 deletions(-) rename {src/c_api => include/treelite}/c_api_error.h (92%) delete mode 100644 include/treelite/entry.h create mode 100644 include/treelite/typeinfo.h create mode 100644 src/typeinfo.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index e44a113e..7c766b2d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,8 @@ set (CMAKE_FIND_NO_INSTALL_PREFIX TRUE FORCE) cmake_minimum_required (VERSION 3.14) project(treelite LANGUAGES CXX C VERSION 0.93) +set(CMAKE_CXX_STANDARD 14) + # check MSVC version if(MSVC) if(MSVC_VERSION LESS 1910) diff --git a/include/treelite/c_api_common.h b/include/treelite/c_api_common.h index c598309d..61ae5bdb 100644 --- a/include/treelite/c_api_common.h +++ b/include/treelite/c_api_common.h @@ -26,6 +26,9 @@ #define TREELITE_DLL TREELITE_EXTERN_C #endif +/*! \brief handle to a data matrix */ +typedef void* DMatrixHandle; + /*! * \brief display last error; can be called by multiple threads * Note. Each thread will get the last error occured in its own context. @@ -42,4 +45,69 @@ TREELITE_DLL const char* TreeliteGetLastError(void); */ TREELITE_DLL int TreeliteRegisterLogCallback(void (*callback)(const char*)); +/*! + * \defgroup dmatrix + * Data matrix interface + * \{ + */ +/*! 
+ * \brief create a sparse DMatrix from a file + * \param path file path + * \param format file format + * \param nthread number of threads to use + * \param verbose whether to produce extra messages + * \param out the created DMatrix + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixCreateFromFile( + const char* path, const char* format, const char* data_type, int nthread, int verbose, + DMatrixHandle* out); +/*! + * \brief create DMatrix from a (in-memory) CSR matrix + * \param data feature values + * \param data_type Type of data elements + * \param col_ind feature indices + * \param row_ptr pointer to row headers + * \param num_row number of rows + * \param num_col number of columns + * \param out the created DMatrix + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixCreateFromCSR( + const void* data, const char* data_type, const uint32_t* col_ind, const size_t* row_ptr, + size_t num_row, size_t num_col, DMatrixHandle* out); +/*! + * \brief create DMatrix from a (in-memory) dense matrix + * \param data feature values + * \param data_type Type of data elements + * \param num_row number of rows + * \param num_col number of columns + * \param missing_value value to represent missing value + * \param out the created DMatrix + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixCreateFromMat( + const void* data, const char* data_type, size_t num_row, size_t num_col, + const void* missing_value, DMatrixHandle* out); +/*! + * \brief get dimensions of a DMatrix + * \param handle handle to DMatrix + * \param out_num_row used to set number of rows + * \param out_num_col used to set number of columns + * \param out_nelem used to set number of nonzero entries + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixGetDimension(DMatrixHandle handle, + size_t* out_num_row, + size_t* out_num_col, + size_t* out_nelem); + +/*! + * \brief delete DMatrix from memory + * \param handle handle to DMatrix + * \return 0 for success, -1 for failure + */ +TREELITE_DLL int TreeliteDMatrixFree(DMatrixHandle handle); +/*! \} */ + #endif // TREELITE_C_API_COMMON_H_ diff --git a/src/c_api/c_api_error.h b/include/treelite/c_api_error.h similarity index 92% rename from src/c_api/c_api_error.h rename to include/treelite/c_api_error.h index 6f6b6108..28abd959 100644 --- a/src/c_api/c_api_error.h +++ b/include/treelite/c_api_error.h @@ -4,8 +4,8 @@ * \author Hyunsu Cho * \brief Error handling for C API. */ -#ifndef TREELITE_C_API_C_API_ERROR_H_ -#define TREELITE_C_API_C_API_ERROR_H_ +#ifndef TREELITE_C_API_ERROR_H_ +#define TREELITE_C_API_ERROR_H_ #include #include @@ -46,4 +46,4 @@ inline int TreeliteAPIHandleException(const std::exception &e) { TreeliteAPISetLastError(e.what()); return -1; } -#endif // TREELITE_C_API_C_API_ERROR_H_ +#endif // TREELITE_C_API_ERROR_H_ diff --git a/include/treelite/c_api_runtime.h b/include/treelite/c_api_runtime.h index cbd268ad..162af056 100644 --- a/include/treelite/c_api_runtime.h +++ b/include/treelite/c_api_runtime.h @@ -13,7 +13,6 @@ #define TREELITE_C_API_RUNTIME_H_ #include "c_api_common.h" -#include "entry.h" /*! * \addtogroup opaque_handles @@ -22,10 +21,9 @@ */ /*! \brief handle to predictor class */ typedef void* PredictorHandle; -/*! \brief handle to batch of sparse data rows */ -typedef void* CSRBatchHandle; -/*! \brief handle to batch of dense data rows */ -typedef void* DenseBatchHandle; +/*! 
\brief handle to output from predictor */ +typedef void* PredictorOutputHandle; + /*! \} */ /*! @@ -33,59 +31,6 @@ typedef void* DenseBatchHandle; * Predictor interface * \{ */ -/*! - * \brief assemble a sparse batch - * \param data feature values - * \param col_ind feature indices - * \param row_ptr pointer to row headers - * \param num_row number of data rows in the batch - * \param num_col number of columns (features) in the batch - * \param out handle to sparse batch - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteAssembleSparseBatch(const float* data, - const uint32_t* col_ind, - const size_t* row_ptr, - size_t num_row, size_t num_col, - CSRBatchHandle* out); -/*! - * \brief delete a sparse batch from memory - * \param handle sparse batch - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDeleteSparseBatch(CSRBatchHandle handle); -/*! - * \brief assemble a dense batch - * \param data feature values - * \param missing_value value to represent the missing value - * \param num_row number of data rows in the batch - * \param num_col number of columns (features) in the batch - * \param out handle to sparse batch - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteAssembleDenseBatch(const float* data, - float missing_value, - size_t num_row, size_t num_col, - DenseBatchHandle* out); -/*! - * \brief delete a dense batch from memory - * \param handle dense batch - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteDeleteDenseBatch(DenseBatchHandle handle); - -/*! - * \brief get dimensions of a batch - * \param handle a batch of rows (must be of type SparseBatch or DenseBatch) - * \param batch_sparse whether the batch is sparse (true) or dense (false) - * \param out_num_row used to set number of rows - * \param out_num_col used to set number of columns - * \return 0 for success, -1 for failure - */ -TREELITE_DLL int TreeliteBatchGetDimension(void* handle, - int batch_sparse, - size_t* out_num_row, - size_t* out_num_col); /*! * \brief load prediction code into memory. @@ -96,75 +41,82 @@ TREELITE_DLL int TreeliteBatchGetDimension(void* handle, * \param out handle to predictor * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorLoad(const char* library_path, - int num_worker_thread, - PredictorHandle* out); -/*! - * \brief Make predictions on a batch of data rows (synchronously). This - * function internally divides the workload among all worker threads. +TREELITE_DLL int TreelitePredictorLoad( + const char* library_path, int num_worker_thread, PredictorHandle* out); +/*! + * \brief Make predictions on a batch of data rows (synchronously). This function internally + * divides the workload among all worker threads. + * + * Note. This function does not allocate the result vector. Use + * TreeliteCreatePredictorOutputVector() convenience function to allocate the vector of + * the right length and type. + * + * Note. To access the element values from the output vector, you should cast the opaque + * handle (PredictorOutputHandle type) to an appropriate pointer LeafOutputType*, where + * the type is either float, double, or uint32_t. So carry out the following steps: + * 1. Call TreelitePredictorQueryLeafOutputType() to obtain the type of the leaf output. + * It will return a string ("float32", "float64", or "uint32") representing the type. + * 2. Depending on the type string, cast the output handle to float*, double*, or uint32_t*. + * 3. Now access the array with the casted pointer. 
The array's length is given by + * TreelitePredictorQueryResultSize(). * \param handle predictor - * \param batch a batch of rows (must be of type SparseBatch or DenseBatch) - * \param batch_sparse whether batch is sparse (1) or dense (0) + * \param batch the data matrix containing a batch of rows * \param verbose whether to produce extra messages * \param pred_margin whether to produce raw margin scores instead of * transformed probabilities - * \param out_result resulting output vector; use - * TreelitePredictorQueryResultSize() to allocate sufficient - * space + * \param out_result Resulting output vector. This pointer must point to an array of length + * TreelitePredictorQueryResultSize() and of type + * TreelitePredictorQueryLeafOutputType(). * \param out_result_size used to save length of the output vector, * which is guaranteed to be less than or equal to * TreelitePredictorQueryResultSize() * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorPredictBatch(PredictorHandle handle, - void* batch, - int batch_sparse, - int verbose, - int pred_margin, - float* out_result, - size_t* out_result_size); - -/*! - * \brief Make predictions on a single data row (synchronously). The work - * will be scheduled to the calling thread. +TREELITE_DLL int TreelitePredictorPredictBatch( + PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, + PredictorOutputHandle out_result, size_t* out_result_size); +/*! + * \brief Convenience function to allocate an output vector that is able to hold the prediction + * result for a given data matrix. The vector's length will be identical to + * TreelitePredictorQueryResultSize() and its type will be identical to + * TreelitePredictorQueryLeafOutputType(). To prevent memory leak, make sure to de-allocate + * the vector with TreeliteDeletePredictorOutputVector(). + * + * Note. To access the element values from the output vector, you should cast the opaque + * handle (PredictorOutputHandle type) to an appropriate pointer LeafOutputType*, where + * the type is either float, double, or uint32_t. So carry out the following steps: + * 1. Call TreelitePredictorQueryLeafOutputType() to obtain the type of the leaf output. + * It will return a string ("float32", "float64", or "uint32") representing the type. + * 2. Depending on the type string, cast the output handle to float*, double*, or uint32_t*. + * 3. Now access the array with the casted pointer. The array's length is given by + * TreelitePredictorQueryResultSize(). * \param handle predictor - * \param inst single data row - * \param pred_margin whether to produce raw margin scores instead of - * transformed probabilities - * \param out_result resulting output vector; use - * TreelitePredictorQueryResultSizeSingleInst() to allocate sufficient space - * \param out_result_size used to save length of the output vector, which is - * guaranteed to be at most TreelitePredictorQueryResultSizeSingleInst() + * \param batch the data matrix containing a batch of rows + * \param out_output_vector Handle to the newly allocated output vector. * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorPredictInst(PredictorHandle handle, - union TreelitePredictorEntry* inst, - int pred_margin, float* out_result, - size_t* out_result_size); +TREELITE_DLL int TreeliteCreatePredictorOutputVector( + PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle* out_output_vector); /*! 
- * \brief Given a batch of data rows, query the necessary size of array to - * hold predictions for all data points. + * \brief De-allocate an output vector * \param handle predictor - * \param batch a batch of rows (must be of type SparseBatch or DenseBatch) - * \param batch_sparse whether batch is sparse (1) or dense (0) - * \param out used to store the length of prediction array + * \param output_vector Output vector to delete from memory * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryResultSize(PredictorHandle handle, - void* batch, - int batch_sparse, - size_t* out); +TREELITE_DLL int TreeliteDeletePredictorOutputVector( + PredictorHandle handle, PredictorOutputHandle output_vector); + /*! - * \brief Query the necessary size of array to hold the prediction for a - * single data row + * \brief Given a batch of data rows, query the necessary size of array to + * hold predictions for all data points. * \param handle predictor + * \param batch the data matrix containing a batch of rows * \param out used to store the length of prediction array * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryResultSizeSingleInst( - PredictorHandle handle, - size_t* out); +TREELITE_DLL int TreelitePredictorQueryResultSize( + PredictorHandle handle, DMatrixHandle batch, size_t* out); /*! * \brief Get the number of output groups in the loaded model * The number is 1 for most tasks; @@ -173,8 +125,7 @@ TREELITE_DLL int TreelitePredictorQueryResultSizeSingleInst( * \param out length of prediction array * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, - size_t* out); +TREELITE_DLL int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, size_t* out); /*! * \brief Get the width (number of features) of each instance used to train * the loaded model @@ -182,8 +133,7 @@ TREELITE_DLL int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, * \param out number of features * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryNumFeature(PredictorHandle handle, - size_t* out); +TREELITE_DLL int TreelitePredictorQueryNumFeature(PredictorHandle handle, size_t* out); /*! * \brief Get name of post prediction transformation used to train @@ -192,8 +142,7 @@ TREELITE_DLL int TreelitePredictorQueryNumFeature(PredictorHandle handle, * \param out name of post prediction transformation * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryPredTransform(PredictorHandle handle, - const char** out); +TREELITE_DLL int TreelitePredictorQueryPredTransform(PredictorHandle handle, const char** out); /*! * \brief Get alpha value of sigmoid transformation used to train * the loaded model @@ -201,8 +150,7 @@ TREELITE_DLL int TreelitePredictorQueryPredTransform(PredictorHandle handle, * \param out alpha value of sigmoid transformation * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, - float* out); +TREELITE_DLL int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, float* out); /*! 
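A minimal sketch of the calling sequence implied by the declarations above (error-code checks omitted; the library path "./mushroom.so", the 2x3 input values, and the "float32" type strings are illustrative assumptions, not part of this patch):

#include <treelite/c_api_common.h>
#include <treelite/c_api_runtime.h>
#include <cmath>
#include <cstdio>
#include <cstring>

int main() {
  // Dense 2x3 matrix with float32 elements; NaN marks missing values (placeholder data).
  const float data[] = {1.0f, 0.0f, 2.5f, 0.0f, 3.0f, 0.5f};
  const float missing = std::nanf("");
  DMatrixHandle dmat = nullptr;
  TreeliteDMatrixCreateFromMat(data, "float32", 2, 3, &missing, &dmat);

  PredictorHandle predictor = nullptr;
  TreelitePredictorLoad("./mushroom.so", /*num_worker_thread=*/-1, &predictor);

  // Allocate an output vector of the right length and element type for this batch.
  PredictorOutputHandle out = nullptr;
  TreeliteCreatePredictorOutputVector(predictor, dmat, &out);

  size_t out_len = 0;
  TreelitePredictorPredictBatch(predictor, dmat, /*verbose=*/0, /*pred_margin=*/0, out, &out_len);

  // Query the leaf output type, then cast the opaque handle accordingly.
  const char* leaf_type = nullptr;
  TreelitePredictorQueryLeafOutputType(predictor, &leaf_type);
  if (std::strcmp(leaf_type, "float32") == 0) {
    const float* pred = static_cast<const float*>(out);
    for (size_t i = 0; i < out_len; ++i) {
      std::printf("%f\n", pred[i]);
    }
  }

  TreeliteDeletePredictorOutputVector(predictor, out);
  TreeliteDMatrixFree(dmat);
  return 0;
}
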
* \brief Get global bias which adjusting predicted margin scores @@ -210,8 +158,10 @@ TREELITE_DLL int TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, * \param out global bias value * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreelitePredictorQueryGlobalBias(PredictorHandle handle, - float* out); +TREELITE_DLL int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float* out); + +TREELITE_DLL int TreelitePredictorQueryThresholdType(PredictorHandle handle, const char** out); +TREELITE_DLL int TreelitePredictorQueryLeafOutputType(PredictorHandle handle, const char** out); /*! * \brief delete predictor from memory * \param handle predictor to remove diff --git a/include/treelite/data.h b/include/treelite/data.h index 88234ea9..6ca682b2 100644 --- a/include/treelite/data.h +++ b/include/treelite/data.h @@ -8,56 +8,129 @@ #define TREELITE_DATA_H_ #include +#include #include +#include +#include namespace treelite { -/*! \brief a simple data matrix in CSR (Compressed Sparse Row) storage */ -struct DMatrix { +enum class DMatrixType : uint8_t { + kDense = 0, + kSparseCSR = 1 +}; + +class DMatrix { + public: + virtual size_t GetNumRow() const = 0; + virtual size_t GetNumCol() const = 0; + virtual size_t GetNumElem() const = 0; + virtual DMatrixType GetType() const = 0; + virtual TypeInfo GetElementType() const = 0; + DMatrix() = default; + virtual ~DMatrix() = default; +}; + +class DenseDMatrix : public DMatrix { + private: + TypeInfo element_type_; + public: + template + static std::unique_ptr Create( + std::vector data, ElementType missing_value, size_t num_row, size_t num_col); + template + static std::unique_ptr Create( + const void* data, const void* missing_value, size_t num_row, size_t num_col); + static std::unique_ptr Create( + TypeInfo type, const void* data, const void* missing_value, size_t num_row, size_t num_col); + size_t GetNumRow() const override = 0; + size_t GetNumCol() const override = 0; + size_t GetNumElem() const override = 0; + DMatrixType GetType() const override = 0; + TypeInfo GetElementType() const override; +}; + +template +class DenseDMatrixImpl : public DenseDMatrix { + public: + /*! \brief feature values */ + std::vector data; + /*! \brief value representing the missing value (usually NaN) */ + ElementType missing_value; + /*! \brief number of rows */ + size_t num_row; + /*! \brief number of columns (i.e. 
# of features used) */ + size_t num_col; + + DenseDMatrixImpl() = delete; + DenseDMatrixImpl(std::vector data, ElementType missing_value, size_t num_row, + size_t num_col); + ~DenseDMatrixImpl() = default; + DenseDMatrixImpl(const DenseDMatrixImpl&) = default; + DenseDMatrixImpl(DenseDMatrixImpl&&) noexcept = default; + DenseDMatrixImpl& operator=(const DenseDMatrixImpl&) = default; + DenseDMatrixImpl& operator=(DenseDMatrixImpl&&) noexcept = default; + + size_t GetNumRow() const override; + size_t GetNumCol() const override; + size_t GetNumElem() const override; + DMatrixType GetType() const override; + + friend class DenseDMatrix; +}; + +class CSRDMatrix : public DMatrix { + private: + TypeInfo element_type_; + public: + template + static std::unique_ptr Create( + std::vector data, std::vector col_ind, std::vector row_ptr, + size_t num_row, size_t num_col); + template + static std::unique_ptr Create( + const void* data, const uint32_t* col_ind, const size_t* row_ptr, size_t num_row, + size_t num_col); + static std::unique_ptr Create( + TypeInfo type, const void* data, const uint32_t* col_ind, const size_t* row_ptr, + size_t num_row, size_t num_col); + static std::unique_ptr Create( + const char* filename, const char* format, const char* data_type, int nthread, int verbose); + size_t GetNumRow() const override = 0; + size_t GetNumCol() const override = 0; + size_t GetNumElem() const override = 0; + DMatrixType GetType() const override = 0; + TypeInfo GetElementType() const override; +}; + +template +class CSRDMatrixImpl : public CSRDMatrix { + public: /*! \brief feature values */ - std::vector data; - /*! \brief feature indices */ + std::vector data; + /*! \brief feature indices. col_ind[i] indicates the feature index associated with data[i]. */ std::vector col_ind; - /*! \brief pointer to row headers; length of [num_row] + 1 */ + /*! \brief pointer to row headers; length is [num_row] + 1. */ std::vector row_ptr; /*! \brief number of rows */ size_t num_row; - /*! \brief number of columns */ + /*! \brief number of columns (i.e. # of features used) */ size_t num_col; - /*! \brief number of nonzero entries */ - size_t nelem; - - /*! - * \brief clear all data fields - */ - inline void Clear() { - data.clear(); - row_ptr.clear(); - col_ind.clear(); - row_ptr.resize(1, 0); - num_row = num_col = nelem = 0; - } - /*! - * \brief construct a new DMatrix from a file - * \param filename name of file - * \param format format of file (libsvm/libfm/csv) - * \param nthread number of threads to use - * \param verbose whether to produce extra messages - * \return newly built DMatrix - */ - static DMatrix* Create(const char* filename, const char* format, - int nthread, int verbose); - /*! - * \brief construct a new DMatrix from a data parser. The data parser here - * refers to any iterable object that streams input data in small - * batches. 
- * \param parser pointer to data parser - * \param nthread number of threads to use - * \param verbose whether to produce extra messages - * \return newly built DMatrix - */ - static DMatrix* Create(dmlc::Parser* parser, - int nthread, int verbose); + + CSRDMatrixImpl() = delete; + CSRDMatrixImpl(std::vector data, std::vector col_ind, + std::vector row_ptr, size_t num_row, size_t num_col); + CSRDMatrixImpl(const CSRDMatrixImpl&) = default; + CSRDMatrixImpl(CSRDMatrixImpl&&) noexcept = default; + CSRDMatrixImpl& operator=(const CSRDMatrixImpl&) = default; + CSRDMatrixImpl& operator=(CSRDMatrixImpl&&) noexcept = default; + + size_t GetNumRow() const override; + size_t GetNumCol() const override; + size_t GetNumElem() const override; + DMatrixType GetType() const override; + + friend class CSRDMatrix; }; } // namespace treelite diff --git a/include/treelite/entry.h b/include/treelite/entry.h deleted file mode 100644 index 8a3fb0a5..00000000 --- a/include/treelite/entry.h +++ /dev/null @@ -1,20 +0,0 @@ -/*! - * Copyright (c) 2017-2020 by Contributors - * \file entry.h - * \author Hyunsu Cho - * \brief Entry type for Treelite predictor - */ -#ifndef TREELITE_ENTRY_H_ -#define TREELITE_ENTRY_H_ - -/*! \brief data layout. The value -1 signifies the missing value. - When the "missing" field is set to -1, the "fvalue" field is set to - NaN (Not a Number), so there is no danger for mistaking between - missing values and non-missing values. */ -union TreelitePredictorEntry { - int missing; - float fvalue; - // may contain extra fields later, such as qvalue -}; - -#endif // TREELITE_ENTRY_H_ diff --git a/include/treelite/predictor.h b/include/treelite/predictor.h index 00222298..af3c8cff 100644 --- a/include/treelite/predictor.h +++ b/include/treelite/predictor.h @@ -8,151 +8,126 @@ #define TREELITE_PREDICTOR_H_ #include -#include +#include +#include +#include #include +#include #include namespace treelite { +namespace predictor { + +/*! \brief data layout. The value -1 signifies the missing value. + When the "missing" field is set to -1, the "fvalue" field is set to + NaN (Not a Number), so there is no danger for mistaking between + missing values and non-missing values. */ +template +union Entry { + int missing; + ElementType fvalue; + // may contain extra fields later, such as qvalue +}; + +class SharedLibrary { + public: + using LibraryHandle = void*; + using FunctionHandle = void*; + SharedLibrary(); + ~SharedLibrary(); + void Load(const char* libpath); + FunctionHandle LoadFunction(const char* name) const; + template + HandleType LoadFunctionWithSignature(const char* name) const; -/*! \brief sparse batch in Compressed Sparse Row (CSR) format */ -struct CSRBatch { - /*! \brief feature values */ - const float* data; - /*! \brief feature indices */ - const uint32_t* col_ind; - /*! \brief pointer to row headers; length of [num_row] + 1 */ - const size_t* row_ptr; - /*! \brief number of rows */ - size_t num_row; - /*! \brief number of columns (i.e. # of features used) */ - size_t num_col; + private: + LibraryHandle handle_; + std::string libpath_; }; -/*! \brief dense batch */ -struct DenseBatch { - /*! \brief feature values */ - const float* data; - /*! \brief value representing the missing value (usually nan) */ - float missing_value; - /*! \brief number of rows */ - size_t num_row; - /*! \brief number of columns (i.e. 
# of features used) */ - size_t num_col; +class PredFunction { + public: + static std::unique_ptr Create(TypeInfo threshold_type, TypeInfo leaf_output_type, + const SharedLibrary& library, int num_feature, + int num_output_group); + PredFunction() = default; + virtual ~PredFunction() = default; + virtual TypeInfo GetThresholdType() const = 0; + virtual TypeInfo GetLeafOutputType() const = 0; + virtual size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, + PredictorOutputHandle out_pred) const = 0; +}; + +template +class PredFunctionImpl : public PredFunction { + public: + using PredFuncHandle = void*; + PredFunctionImpl(const SharedLibrary& library, int num_feature, int num_output_group); + TypeInfo GetThresholdType() const override; + TypeInfo GetLeafOutputType() const override; + size_t PredictBatch(const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, + PredictorOutputHandle out_pred) const override; + + private: + PredFuncHandle handle_; + int num_feature_; + int num_output_group_; }; /*! \brief predictor class: wrapper for optimized prediction code */ class Predictor { public: /*! \brief opaque handle types */ - typedef void* QueryFuncHandle; - typedef void* PredFuncHandle; - typedef void* LibraryHandle; typedef void* ThreadPoolHandle; explicit Predictor(int num_worker_thread = -1); ~Predictor(); /*! * \brief load the prediction function from dynamic shared library. - * \param name name of dynamic shared library (.so/.dll/.dylib). + * \param libpath path of dynamic shared library (.so/.dll/.dylib). */ - void Load(const char* name); + void Load(const char* libpath); /*! * \brief unload the prediction function */ void Free(); - /*! * \brief Make predictions on a batch of data rows (synchronously). This * function internally divides the workload among all worker threads. - * \param batch a batch of rows + * \param dmat a batch of rows * \param verbose whether to produce extra messages * \param pred_margin whether to produce raw margin scores instead of * transformed probabilities - * \param out_result resulting output vector; use - * QueryResultSize() to allocate sufficient space + * \param out_result Resulting output vector. This pointer must point to an array of length + * QueryResultSize() and of type QueryLeafOutputType(). * \return length of the output vector, which is guaranteed to be less than * or equal to QueryResultSize() */ - size_t PredictBatch(const CSRBatch* batch, int verbose, - bool pred_margin, float* out_result); - size_t PredictBatch(const DenseBatch* batch, int verbose, - bool pred_margin, float* out_result); - /*! - * \brief Make predictions on a single data row (synchronously). The work - * will be scheduled to the calling thread. - * \param inst single data row - * \param pred_margin whether to produce raw margin scores instead of - * transformed probabilities - * \param out_result resulting output vector; use - * QueryResultSizeSingleInst() to allocate sufficient space - * \return length of the output vector, which is guaranteed to be less than - * or equal to QueryResultSizeSingleInst() - */ - size_t PredictInst(TreelitePredictorEntry* inst, bool pred_margin, - float* out_result); - + size_t PredictBatch( + const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const; /*! * \brief Given a batch of data rows, query the necessary size of array to * hold predictions for all data points. 
- * \param batch a batch of rows + * \param dmat a batch of rows * \return length of prediction array */ - inline size_t QueryResultSize(const CSRBatch* batch) const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - return batch->num_row * num_output_group_; + inline size_t QueryResultSize(const DMatrix* dmat) const { + CHECK(pred_func_) << "A shared library needs to be loaded first using Load()"; + return dmat->GetNumRow() * num_output_group_; } /*! * \brief Given a batch of data rows, query the necessary size of array to * hold predictions for all data points. - * \param batch a batch of rows - * \return length of prediction array - */ - inline size_t QueryResultSize(const DenseBatch* batch) const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - return batch->num_row * num_output_group_; - } - /*! - * \brief Given a batch of data rows, query the necessary size of array to - * hold predictions for all data points. - * \param batch a batch of rows + * \param dmat a batch of rows * \param rbegin beginning of range of rows * \param rend end of range of rows * \return length of prediction array */ - inline size_t QueryResultSize(const CSRBatch* batch, - size_t rbegin, size_t rend) const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - CHECK(rbegin < rend && rend <= batch->num_row); + inline size_t QueryResultSize(const DMatrix* dmat, size_t rbegin, size_t rend) const { + CHECK(pred_func_) << "A shared library needs to be loaded first using Load()"; + CHECK(rbegin < rend && rend <= dmat->GetNumRow()); return (rend - rbegin) * num_output_group_; } - /*! - * \brief Given a batch of data rows, query the necessary size of array to - * hold predictions for all data points. - * \param batch a batch of rows - * \param rbegin beginning of range of rows - * \param rend end of range of rows - * \return length of prediction array - */ - inline size_t QueryResultSize(const DenseBatch* batch, - size_t rbegin, size_t rend) const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - CHECK(rbegin < rend && rend <= batch->num_row); - return (rend - rbegin) * num_output_group_; - } - /*! - * \brief Query the necessary size of array to hold the prediction for a - * single data row - * \return length of prediction array - */ - inline size_t QueryResultSizeSingleInst() const { - CHECK(pred_func_handle_ != nullptr) - << "A shared library needs to be loaded first using Load()"; - return num_output_group_; - } /*! * \brief Get the number of output groups in the loaded model * The number is 1 for most tasks; @@ -162,7 +137,6 @@ class Predictor { inline size_t QueryNumOutputGroup() const { return num_output_group_; } - /*! * \brief Get the width (number of features) of each instance used to train * the loaded model @@ -171,7 +145,6 @@ class Predictor { inline size_t QueryNumFeature() const { return num_feature_; } - /*! * \brief Get name of post prediction transformation used to train the loaded model * \return name of prediction transformation @@ -179,7 +152,6 @@ class Predictor { inline std::string QueryPredTransform() const { return pred_transform_; } - /*! * \brief Get alpha value in sigmoid transformation used to train the loaded model * \return alpha value in sigmoid transformation @@ -187,7 +159,6 @@ class Predictor { inline float QuerySigmoidAlpha() const { return sigmoid_alpha_; } - /*! 
* \brief Get global bias which adjusting predicted margin scores * \return global bias @@ -195,15 +166,35 @@ class Predictor { inline float QueryGlobalBias() const { return global_bias_; } + /*! + * \brief Get the type of the split thresholds + * \return type of the split thresholds + */ + inline TypeInfo QueryThresholdType() const { + return threshold_type_; + } + /*! + * \brief Create an output vector suitable to hold prediction result for a given data matrix + * \param dmat a data matrix + * \return Opaque handle to the allocated output vector + */ + PredictorOutputHandle CreateOutputVector(const DMatrix* dmat) const; + /*! + * \brief Free an output vector from memory + * \param output_vector Opaque handle to the output vector + */ + void DeleteOutputVector(PredictorOutputHandle output_vector) const; + /*! + * \brief Get the type of the leaf outputs + * \return type of the leaf outputs + */ + inline TypeInfo QueryLeafOutputType() const { + return leaf_output_type_; + } private: - LibraryHandle lib_handle_; - QueryFuncHandle num_output_group_query_func_handle_; - QueryFuncHandle num_feature_query_func_handle_; - QueryFuncHandle pred_transform_query_func_handle_; - QueryFuncHandle sigmoid_alpha_query_func_handle_; - QueryFuncHandle global_bias_query_func_handle_; - PredFuncHandle pred_func_handle_; + SharedLibrary lib_; + std::unique_ptr pred_func_; ThreadPoolHandle thread_pool_handle_; size_t num_output_group_; size_t num_feature_; @@ -211,12 +202,13 @@ class Predictor { float sigmoid_alpha_; float global_bias_; int num_worker_thread_; + TypeInfo threshold_type_; + TypeInfo leaf_output_type_; - template - size_t PredictBatchBase_(const BatchType* batch, int verbose, - bool pred_margin, float* out_result); + mutable dmlc::OMPException exception_catcher_; }; +} // namespace predictor } // namespace treelite #endif // TREELITE_PREDICTOR_H_ diff --git a/include/treelite/typeinfo.h b/include/treelite/typeinfo.h new file mode 100644 index 00000000..7b1337ea --- /dev/null +++ b/include/treelite/typeinfo.h @@ -0,0 +1,163 @@ +/*! + * Copyright (c) 2017-2020 by Contributors + * \file typeinfo.h + * \brief Defines TypeInfo class and utilities + * \author Hyunsu Cho + */ + +#ifndef TREELITE_TYPEINFO_H_ +#define TREELITE_TYPEINFO_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace treelite { + +/*! \brief Types used by thresholds and leaf outputs */ +enum class TypeInfo : uint8_t { + kInvalid = 0, + kUInt32 = 1, + kFloat32 = 2, + kFloat64 = 3 +}; +static_assert(std::is_same::type, uint8_t>::value, + "TypeInfo must use uint8_t as underlying type"); + +/*! \brief conversion table from string to TypeInfo, defined in tables.cc */ +extern const std::unordered_map typeinfo_table; + +/*! + * \brief Get string representation of type info + * \param info a type info + * \return string representation + */ +inline std::string TypeInfoToString(treelite::TypeInfo type) { + switch (type) { + case treelite::TypeInfo::kInvalid: + return "invalid"; + case treelite::TypeInfo::kUInt32: + return "uint32"; + case treelite::TypeInfo::kFloat32: + return "float32"; + case treelite::TypeInfo::kFloat64: + return "float64"; + default: + throw std::runtime_error("Unrecognized type"); + return ""; + } +} + +/*! 
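For the C++ layer, a corresponding sketch built directly on the classes declared in data.h and predictor.h above; the shared-library path and the input values are again placeholders, and error handling is left out:

#include <treelite/data.h>
#include <treelite/predictor.h>
#include <treelite/typeinfo.h>
#include <cmath>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

int main() {
  // Dense 2x3 float32 matrix; NaN marks missing values (placeholder data).
  std::vector<float> data{1.0f, 0.0f, 2.5f, 0.0f, 3.0f, 0.5f};
  std::unique_ptr<treelite::DenseDMatrix> dmat
      = treelite::DenseDMatrix::Create<float>(std::move(data), std::nanf(""), 2, 3);

  treelite::predictor::Predictor predictor(/*num_worker_thread=*/-1);
  predictor.Load("./mushroom.so");  // placeholder path to a compiled model library

  // The output buffer is opaque; its length comes from QueryResultSize() and its
  // element type from QueryLeafOutputType().
  PredictorOutputHandle out = predictor.CreateOutputVector(dmat.get());
  const size_t len
      = predictor.PredictBatch(dmat.get(), /*verbose=*/0, /*pred_margin=*/false, out);

  if (predictor.QueryLeafOutputType() == treelite::TypeInfo::kFloat32) {
    const float* pred = static_cast<const float*>(out);
    for (size_t i = 0; i < len; ++i) {
      std::cout << pred[i] << '\n';
    }
  }
  predictor.DeleteOutputVector(out);
  return 0;
}
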
+ * \brief Convert a template type into a type info + * \tparam template type to be converted + * \return TypeInfo corresponding to the template type arg + */ +template +inline TypeInfo InferTypeInfoOf() { + if (std::is_same::value) { + return TypeInfo::kUInt32; + } else if (std::is_same::value) { + return TypeInfo::kFloat32; + } else if (std::is_same::value) { + return TypeInfo::kFloat64; + } else { + throw std::runtime_error(std::string("Unrecognized Value type") + typeid(T).name()); + return TypeInfo::kInvalid; + } +} + +/*! + * \brief Given a TypeInfo, dispatch a function with the corresponding template arg. More precisely, + * we shall call Dispatcher::Dispatch() where the template arg T corresponds to the + * `type` parameter. + * \tparam Dispatcher Function object that takes in one template arg. + * It must have a Dispatch() static function. + * \tparam Parameter pack, to forward an arbitrary number of args to Dispatcher::Dispatch() + * \param type TypeInfo corresponding to the template arg T with which + * Dispatcher::Dispatch() is called. + * \param args Other extra parameters to pass to Dispatcher::Dispatch() + * \return Whatever that's returned by the dispatcher + */ +template class Dispatcher, typename ...Args> +inline auto DispatchWithTypeInfo(TypeInfo type, Args&& ...args) { + switch (type) { + case TypeInfo::kUInt32: + return Dispatcher::Dispatch(std::forward(args)...); + case TypeInfo::kFloat32: + return Dispatcher::Dispatch(std::forward(args)...); + case TypeInfo::kFloat64: + return Dispatcher::Dispatch(std::forward(args)...); + case TypeInfo::kInvalid: + default: + throw std::runtime_error(std::string("Invalid type: ") + TypeInfoToString(type)); + } + return Dispatcher::Dispatch(std::forward(args)...); // avoid missing return error +} + +/*! + * \brief Given the types for thresholds and leaf outputs, validate that they consist of a valid + * combination for a model and then dispatch a function with the corresponding template args. + * More precisely, we shall call Dispatcher::Dispatch() where + * the template args ThresholdType and LeafOutputType correspond to the parameters + * `threshold_type` and `leaf_output_type`, respectively. + * \tparam Dispatcher Function object that takes in two template args. + * It must have a Dispatch() static function. 
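 * For instance, a dispatcher could look like this (the snippet is purely
 * illustrative and is not part of Treelite):
 * \code
 *   template <typename ThresholdType, typename LeafOutputType>
 *   class SizeQuery {
 *    public:
 *     static size_t Dispatch() {
 *       return sizeof(ThresholdType) + sizeof(LeafOutputType);
 *     }
 *   };
 *   // Dispatches to SizeQuery<double, double>::Dispatch()
 *   size_t sz = DispatchWithModelTypes<SizeQuery>(
 *       TypeInfo::kFloat64, TypeInfo::kFloat64);
 * \endcode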
+ * \tparam Parameter pack, to forward an arbitrary number of args to Dispatcher::Dispatch() + * \param threshold_type TypeInfo indicating the type of thresholds + * \param leaf_output_type TypeInfo indicating the type of leaf outputs + * \param args Other extra parameters to pass to Dispatcher::Dispatch() + * \return Whatever that's returned by the dispatcher + */ +template class Dispatcher, typename ...Args> +inline auto DispatchWithModelTypes( + TypeInfo threshold_type, TypeInfo leaf_output_type, Args&& ...args) { + auto error_threshold_type = [threshold_type]() { + std::ostringstream oss; + oss << "Invalid threshold type: " << treelite::TypeInfoToString(threshold_type); + return oss.str(); + }; + auto error_leaf_output_type = [threshold_type, leaf_output_type]() { + std::ostringstream oss; + oss << "Cannot use leaf output type " << treelite::TypeInfoToString(leaf_output_type) + << " with threshold type " << treelite::TypeInfoToString(threshold_type); + return oss.str(); + }; + switch (threshold_type) { + case treelite::TypeInfo::kFloat32: + switch (leaf_output_type) { + case treelite::TypeInfo::kUInt32: + return Dispatcher::Dispatch(std::forward(args)...); + case treelite::TypeInfo::kFloat32: + return Dispatcher::Dispatch(std::forward(args)...); + default: + throw std::runtime_error(error_leaf_output_type()); + break; + } + break; + case treelite::TypeInfo::kFloat64: + switch (leaf_output_type) { + case treelite::TypeInfo::kUInt32: + return Dispatcher::Dispatch(std::forward(args)...); + case treelite::TypeInfo::kFloat64: + return Dispatcher::Dispatch(std::forward(args)...); + default: + throw std::runtime_error(error_leaf_output_type()); + break; + } + break; + default: + throw std::runtime_error(error_threshold_type()); + break; + } + return Dispatcher::Dispatch(std::forward(args)...); + // avoid missing return value warning +} + +} // namespace treelite + +#endif // TREELITE_TYPEINFO_H_ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ea2e77b5..c56c8b92 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -50,7 +50,7 @@ endforeach() set_target_properties(objtreelite objtreelite_runtime objtreelite_common PROPERTIES POSITION_INDEPENDENT_CODE ON - CXX_STANDARD 11 + CXX_STANDARD 14 CXX_STANDARD_REQUIRED ON) target_sources(objtreelite @@ -84,7 +84,6 @@ target_sources(objtreelite frontend/lightgbm.cc frontend/xgboost.cc annotator.cc - data.cc filesystem.cc optable.cc reference_serializer.cc @@ -93,7 +92,6 @@ target_sources(objtreelite ${PROJECT_SOURCE_DIR}/include/treelite/c_api.h ${PROJECT_SOURCE_DIR}/include/treelite/compiler.h ${PROJECT_SOURCE_DIR}/include/treelite/compiler_param.h - ${PROJECT_SOURCE_DIR}/include/treelite/data.h ${PROJECT_SOURCE_DIR}/include/treelite/filesystem.h ${PROJECT_SOURCE_DIR}/include/treelite/frontend.h ${PROJECT_SOURCE_DIR}/include/treelite/omp.h @@ -108,7 +106,6 @@ target_sources(objtreelite_runtime predictor/thread_pool/thread_pool.h predictor/predictor.cc ${PROJECT_SOURCE_DIR}/include/treelite/c_api_runtime.h - ${PROJECT_SOURCE_DIR}/include/treelite/entry.h ${PROJECT_SOURCE_DIR}/include/treelite/predictor.h ) @@ -116,11 +113,15 @@ target_sources(objtreelite_common PRIVATE c_api/c_api_common.cc c_api/c_api_error.cc - c_api/c_api_error.h + data.cc logging.cc + typeinfo.cc ${PROJECT_SOURCE_DIR}/include/treelite/c_api_common.h + ${PROJECT_SOURCE_DIR}/include/treelite/c_api_error.h ${PROJECT_SOURCE_DIR}/include/treelite/logging.h ${PROJECT_SOURCE_DIR}/include/treelite/math.h + ${PROJECT_SOURCE_DIR}/include/treelite/typeinfo.h + 
${PROJECT_SOURCE_DIR}/include/treelite/data.h ) msvc_use_static_runtime() diff --git a/src/c_api/c_api_common.cc b/src/c_api/c_api_common.cc index 040d02ef..76a36f59 100644 --- a/src/c_api/c_api_common.cc +++ b/src/c_api/c_api_common.cc @@ -5,15 +5,76 @@ * \brief C API of treelite (this file is used by both runtime and main package) */ +#include #include +#include #include -#include "./c_api_error.h" +#include using namespace treelite; +/*! \brief entry to to easily hold returning information */ +struct TreeliteAPIThreadLocalEntry { + /*! \brief result holder for returning string */ + std::string ret_str; +}; + +// define threadlocal store for returning information +using TreeliteAPIThreadLocalStore + = dmlc::ThreadLocalStore; + int TreeliteRegisterLogCallback(void (*callback)(const char*)) { API_BEGIN(); LogCallbackRegistry* registry = LogCallbackRegistryStore::Get(); registry->Register(callback); API_END(); } + +int TreeliteDMatrixCreateFromFile( + const char* path, const char* format, const char* data_type, int nthread, int verbose, + DMatrixHandle* out) { + API_BEGIN(); + std::unique_ptr mat = CSRDMatrix::Create(path, format, data_type, nthread, verbose); + *out = static_cast(mat.release()); + API_END(); +} + +int TreeliteDMatrixCreateFromCSR( + const void* data, const char* data_type_str, const uint32_t* col_ind, const size_t* row_ptr, + size_t num_row, size_t num_col, DMatrixHandle* out) { + API_BEGIN(); + TypeInfo data_type = typeinfo_table.at(data_type_str); + std::unique_ptr matrix + = CSRDMatrix::Create(data_type, data, col_ind, row_ptr, num_row, num_col); + *out = static_cast(matrix.release()); + API_END(); +} + +int TreeliteDMatrixCreateFromMat( + const void* data, const char* data_type_str, size_t num_row, size_t num_col, + const void* missing_value, DMatrixHandle* out) { + API_BEGIN(); + TypeInfo data_type = typeinfo_table.at(data_type_str); + std::unique_ptr matrix + = DenseDMatrix::Create(data_type, data, missing_value, num_row, num_col); + *out = static_cast(matrix.release()); + API_END(); +} + +int TreeliteDMatrixGetDimension(DMatrixHandle handle, + size_t* out_num_row, + size_t* out_num_col, + size_t* out_nelem) { + API_BEGIN(); + const DMatrix* dmat = static_cast(handle); + *out_num_row = dmat->GetNumRow(); + *out_num_col = dmat->GetNumCol(); + *out_nelem = dmat->GetNumElem(); + API_END(); +} + +int TreeliteDMatrixFree(DMatrixHandle handle) { + API_BEGIN(); + delete static_cast(handle); + API_END(); +} diff --git a/src/c_api/c_api_runtime.cc b/src/c_api/c_api_runtime.cc index 0cb55ac6..46c3bb87 100644 --- a/src/c_api/c_api_runtime.cc +++ b/src/c_api/c_api_runtime.cc @@ -7,10 +7,10 @@ #include #include +#include #include #include #include -#include "./c_api_error.h" using namespace treelite; @@ -28,155 +28,72 @@ using TreeliteRuntimeAPIThreadLocalStore } // anonymous namespace -int TreeliteAssembleSparseBatch(const float* data, - const uint32_t* col_ind, - const size_t* row_ptr, - size_t num_row, size_t num_col, - CSRBatchHandle* out) { +int TreelitePredictorLoad(const char* library_path, int num_worker_thread, PredictorHandle* out) { API_BEGIN(); - CSRBatch* batch = new CSRBatch(); - batch->data = data; - batch->col_ind = col_ind; - batch->row_ptr = row_ptr; - batch->num_row = num_row; - batch->num_col = num_col; - *out = static_cast(batch); - API_END(); -} - -int TreeliteDeleteSparseBatch(CSRBatchHandle handle) { - API_BEGIN(); - delete static_cast(handle); - API_END(); -} - -int TreeliteAssembleDenseBatch(const float* data, float missing_value, - size_t num_row, 
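// To illustrate the DMatrix C API added in c_api_common.cc above, a minimal
// caller could look like the following sketch (the 2x2 matrix, the NAN missing
// value, and the omitted error-code checks are all illustrative; <math.h> is
// assumed for NAN):
//
//   float data[] = {1.0f, 2.0f, 3.0f, 4.0f};
//   float missing = NAN;
//   DMatrixHandle dmat;
//   TreeliteDMatrixCreateFromMat(data, "float32", 2, 2, &missing, &dmat);
//   size_t num_row, num_col, num_elem;
//   TreeliteDMatrixGetDimension(dmat, &num_row, &num_col, &num_elem);
//   TreeliteDMatrixFree(dmat);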
size_t num_col, - DenseBatchHandle* out) { - API_BEGIN(); - DenseBatch* batch = new DenseBatch(); - batch->data = data; - batch->missing_value = missing_value; - batch->num_row = num_row; - batch->num_col = num_col; - *out = static_cast(batch); - API_END(); -} - -int TreeliteDeleteDenseBatch(DenseBatchHandle handle) { - API_BEGIN(); - delete static_cast(handle); - API_END(); -} - -int TreeliteBatchGetDimension(void* handle, - int batch_sparse, - size_t* out_num_row, - size_t* out_num_col) { - API_BEGIN(); - if (batch_sparse) { - const CSRBatch* batch_ = static_cast(handle); - *out_num_row = batch_->num_row; - *out_num_col = batch_->num_col; - } else { - const DenseBatch* batch_ = static_cast(handle); - *out_num_row = batch_->num_row; - *out_num_col = batch_->num_col; - } - API_END(); -} - -int TreelitePredictorLoad(const char* library_path, - int num_worker_thread, - PredictorHandle* out) { - API_BEGIN(); - Predictor* predictor = new Predictor(num_worker_thread); + auto predictor = std::make_unique(num_worker_thread); predictor->Load(library_path); - *out = static_cast(predictor); + *out = static_cast(predictor.release()); API_END(); } -int TreelitePredictorPredictBatch(PredictorHandle handle, - void* batch, - int batch_sparse, - int verbose, - int pred_margin, - float* out_result, - size_t* out_result_size) { +int TreelitePredictorPredictBatch( + PredictorHandle handle, DMatrixHandle batch, int verbose, int pred_margin, + PredictorOutputHandle out_result, size_t* out_result_size) { API_BEGIN(); - Predictor* predictor_ = static_cast(handle); - const size_t num_feature = predictor_->QueryNumFeature(); + const auto* predictor = static_cast(handle); + const auto* dmat = static_cast(batch); + const size_t num_feature = predictor->QueryNumFeature(); const std::string err_msg = std::string("Too many columns (features) in the given batch. 
" - "Number of features must not exceed ") - + std::to_string(num_feature); - if (batch_sparse) { - const CSRBatch* batch_ = static_cast(batch); - CHECK_LE(batch_->num_col, num_feature) << err_msg; - *out_result_size = predictor_->PredictBatch(batch_, verbose, - (pred_margin != 0), out_result); - } else { - const DenseBatch* batch_ = static_cast(batch); - CHECK_LE(batch_->num_col, num_feature) << err_msg; - *out_result_size = predictor_->PredictBatch(batch_, verbose, - (pred_margin != 0), out_result); - } + "Number of features must not exceed ") + std::to_string(num_feature); + CHECK_LE(dmat->GetNumCol(), num_feature) << err_msg; + *out_result_size = predictor->PredictBatch(dmat, verbose, (pred_margin != 0), out_result); API_END(); } -int TreelitePredictorPredictInst(PredictorHandle handle, - union TreelitePredictorEntry* inst, - int pred_margin, - float* out_result, size_t* out_result_size) { +int TreeliteCreatePredictorOutputVector( + PredictorHandle handle, DMatrixHandle batch, PredictorOutputHandle* out_output_vector) { API_BEGIN(); - Predictor* predictor_ = static_cast(handle); - *out_result_size - = predictor_->PredictInst(inst, (pred_margin != 0), out_result); + const auto* predictor = static_cast(handle); + const auto* dmat = static_cast(batch); + *out_output_vector = predictor->CreateOutputVector(dmat); API_END(); } -int TreelitePredictorQueryResultSize(PredictorHandle handle, - void* batch, - int batch_sparse, - size_t* out) { +int TreeliteDeletePredictorOutputVector( + PredictorHandle handle, PredictorOutputHandle output_vector) { API_BEGIN(); - const Predictor* predictor_ = static_cast(handle); - if (batch_sparse) { - const CSRBatch* batch_ = static_cast(batch); - *out = predictor_->QueryResultSize(batch_); - } else { - const DenseBatch* batch_ = static_cast(batch); - *out = predictor_->QueryResultSize(batch_); - } + const auto* predictor = static_cast(handle); + predictor->DeleteOutputVector(output_vector); API_END(); } -int TreelitePredictorQueryResultSizeSingleInst(PredictorHandle handle, - size_t* out) { +int TreelitePredictorQueryResultSize(PredictorHandle handle, DMatrixHandle batch, size_t* out) { API_BEGIN(); - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QueryResultSizeSingleInst(); + const auto* predictor = static_cast(handle); + const auto* dmat = static_cast(batch); + *out = predictor->QueryResultSize(dmat); API_END(); } int TreelitePredictorQueryNumOutputGroup(PredictorHandle handle, size_t* out) { API_BEGIN(); - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QueryNumOutputGroup(); + const auto* predictor = static_cast(handle); + *out = predictor->QueryNumOutputGroup(); API_END(); } int TreelitePredictorQueryNumFeature(PredictorHandle handle, size_t* out) { API_BEGIN(); - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QueryNumFeature(); + const auto* predictor = static_cast(handle); + *out = predictor->QueryNumFeature(); API_END(); } int TreelitePredictorQueryPredTransform(PredictorHandle handle, const char** out) { API_BEGIN() - const Predictor* predictor_ = static_cast(handle); - auto pred_transform = predictor_->QueryPredTransform(); + const auto* predictor = static_cast(handle); + auto pred_transform = predictor->QueryPredTransform(); std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str; ret_str = pred_transform; *out = ret_str.c_str(); @@ -185,20 +102,38 @@ int TreelitePredictorQueryPredTransform(PredictorHandle handle, const char** out int 
TreelitePredictorQuerySigmoidAlpha(PredictorHandle handle, float* out) { API_BEGIN() - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QuerySigmoidAlpha(); + const auto* predictor = static_cast(handle); + *out = predictor->QuerySigmoidAlpha(); API_END(); } int TreelitePredictorQueryGlobalBias(PredictorHandle handle, float* out) { API_BEGIN() - const Predictor* predictor_ = static_cast(handle); - *out = predictor_->QueryGlobalBias(); + const auto* predictor = static_cast(handle); + *out = predictor->QueryGlobalBias(); + API_END(); +} + +int TreelitePredictorQueryThresholdType(PredictorHandle handle, const char** out) { + API_BEGIN() + const auto* predictor = static_cast(handle); + std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str; + ret_str = TypeInfoToString(predictor->QueryThresholdType()); + *out = ret_str.c_str(); + API_END(); +} + +int TreelitePredictorQueryLeafOutputType(PredictorHandle handle, const char** out) { + API_BEGIN() + const auto* predictor = static_cast(handle); + std::string& ret_str = TreeliteRuntimeAPIThreadLocalStore::Get()->ret_str; + ret_str = TypeInfoToString(predictor->QueryLeafOutputType()); + *out = ret_str.c_str(); API_END(); } int TreelitePredictorFree(PredictorHandle handle) { API_BEGIN(); - delete static_cast(handle); + delete static_cast(handle); API_END(); } diff --git a/src/predictor/predictor.cc b/src/predictor/predictor.cc index b8739386..a9fce4ad 100644 --- a/src/predictor/predictor.cc +++ b/src/predictor/predictor.cc @@ -7,6 +7,8 @@ #include #include +#include +#include #include #include #include @@ -27,82 +29,38 @@ namespace { -enum class InputType : uint8_t { - kSparseBatch = 0, kDenseBatch = 1 -}; - struct InputToken { - InputType input_type; - const void* data; // pointer to input data + const treelite::DMatrix* dmat; // input data bool pred_margin; // whether to store raw margin or transformed scores - size_t num_feature; - // # features (columns) accepted by the tree ensemble model - size_t num_output_group; - // size of output per instance (row) - treelite::Predictor::PredFuncHandle pred_func_handle; - size_t rbegin, rend; - // range of instances (rows) assigned to each worker - float* out_pred; - // buffer to store output from each worker + const treelite::predictor::PredFunction* pred_func_; + size_t rbegin, rend; // range of instances (rows) assigned to each worker + PredictorOutputHandle out_pred; // buffer to store output from each worker }; struct OutputToken { size_t query_result_size; }; -using PredThreadPool = treelite::ThreadPool; - -inline treelite::Predictor::LibraryHandle OpenLibrary(const char* name) { -#ifdef _WIN32 - HMODULE handle = LoadLibraryA(name); -#else - void* handle = dlopen(name, RTLD_LAZY | RTLD_LOCAL); -#endif - return static_cast(handle); -} - -inline void CloseLibrary(treelite::Predictor::LibraryHandle handle) { -#ifdef _WIN32 - FreeLibrary(static_cast(handle)); -#else - dlclose(static_cast(handle)); -#endif -} - -template -inline HandleType LoadFunction(treelite::Predictor::LibraryHandle lib_handle, - const char* name) { -#ifdef _WIN32 - FARPROC func_handle = GetProcAddress(static_cast(lib_handle), name); -#else - void* func_handle = dlsym(static_cast(lib_handle), name); -#endif - return static_cast(func_handle); -} - -template -inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature, - size_t rbegin, size_t rend, - float* out_pred, PredFunc func) { - CHECK_LE(batch->num_col, num_feature); - std::vector inst( - std::max(batch->num_col, num_feature), 
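// Picking up the DMatrix sketch from earlier, the refactored runtime C API
// shown above could be driven like this (the library path and the -1 worker
// count, which falls back to std::thread::hardware_concurrency(), are
// illustrative; return codes are again ignored):
//
//   PredictorHandle predictor;
//   TreelitePredictorLoad("./compiled_model.so", -1, &predictor);
//   PredictorOutputHandle out;
//   TreeliteCreatePredictorOutputVector(predictor, dmat, &out);
//   size_t out_len;
//   TreelitePredictorPredictBatch(predictor, dmat, /*verbose=*/0,
//                                 /*pred_margin=*/0, out, &out_len);
//   TreeliteDeletePredictorOutputVector(predictor, out);
//   TreelitePredictorFree(predictor);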
{-1}); - CHECK(rbegin < rend && rend <= batch->num_row); - CHECK(sizeof(size_t) < sizeof(int64_t) - || (rbegin <= static_cast(std::numeric_limits::max()) - && rend <= static_cast(std::numeric_limits::max()))); - const int64_t rbegin_ = static_cast(rbegin); - const int64_t rend_ = static_cast(rend); - const size_t num_col = batch->num_col; - const float* data = batch->data; - const uint32_t* col_ind = batch->col_ind; - const size_t* row_ptr = batch->row_ptr; +using PredThreadPool + = treelite::predictor::ThreadPool; + +template +inline size_t PredLoop(const treelite::CSRDMatrixImpl* dmat, int num_feature, + size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { + CHECK_LE(dmat->num_col, static_cast(num_feature)); + std::vector> inst( + std::max(dmat->num_col, static_cast(num_feature)), {-1}); + CHECK(rbegin < rend && rend <= dmat->num_row); + const size_t num_col = dmat->num_col; + const ElementType* data = dmat->data.data(); + const uint32_t* col_ind = dmat->col_ind.data(); + const size_t* row_ptr = dmat->row_ptr.data(); size_t total_output_size = 0; - for (int64_t rid = rbegin_; rid < rend_; ++rid) { + for (size_t rid = rbegin; rid < rend; ++rid) { const size_t ibegin = row_ptr[rid]; const size_t iend = row_ptr[rid + 1]; for (size_t i = ibegin; i < iend; ++i) { - inst[col_ind[i]].fvalue = data[i]; + inst[col_ind[i]].fvalue = static_cast(data[i]); } total_output_size += func(rid, &inst[0], out_pred); for (size_t i = ibegin; i < iend; ++i) { @@ -112,34 +70,27 @@ inline size_t PredLoop(const treelite::CSRBatch* batch, size_t num_feature, return total_output_size; } -template -inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature, - size_t rbegin, size_t rend, - float* out_pred, PredFunc func) { - const bool nan_missing = treelite::math::CheckNAN(batch->missing_value); - CHECK_LE(batch->num_col, num_feature); - std::vector inst( - std::max(batch->num_col, num_feature), {-1}); - CHECK(rbegin < rend && rend <= batch->num_row); - CHECK(sizeof(size_t) < sizeof(int64_t) - || (rbegin <= static_cast(std::numeric_limits::max()) - && rend <= static_cast(std::numeric_limits::max()))); - const int64_t rbegin_ = static_cast(rbegin); - const int64_t rend_ = static_cast(rend); - const size_t num_col = batch->num_col; - const float missing_value = batch->missing_value; - const float* data = batch->data; - const float* row; +template +inline size_t PredLoop(const treelite::DenseDMatrixImpl* dmat, int num_feature, + size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { + const bool nan_missing = treelite::math::CheckNAN(dmat->missing_value); + CHECK_LE(dmat->num_col, static_cast(num_feature)); + std::vector> inst( + std::max(dmat->num_col, static_cast(num_feature)), {-1}); + CHECK(rbegin < rend && rend <= dmat->num_row); + const size_t num_col = dmat->num_col; + const ElementType missing_value = dmat->missing_value; + const ElementType* data = dmat->data.data(); + const ElementType* row = nullptr; size_t total_output_size = 0; - for (int64_t rid = rbegin_; rid < rend_; ++rid) { + for (size_t rid = rbegin; rid < rend; ++rid) { row = &data[rid * num_col]; for (size_t j = 0; j < num_col; ++j) { if (treelite::math::CheckNAN(row[j])) { CHECK(nan_missing) - << "The missing_value argument must be set to NaN if there is any " - << "NaN in the matrix."; + << "The missing_value argument must be set to NaN if there is any NaN in the matrix."; } else if (nan_missing || row[j] != missing_value) { - inst[j].fvalue = row[j]; + inst[j].fvalue = static_cast(row[j]); } } 
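// At this point `inst` holds a dense copy of row `rid`: entries that were not
// assigned above keep the missing-value sentinel they were initialized with.
// The compiled prediction function is invoked on this buffer next.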
total_output_size += func(rid, &inst[0], out_pred); @@ -150,219 +101,281 @@ inline size_t PredLoop(const treelite::DenseBatch* batch, size_t num_feature, return total_output_size; } -template -inline size_t PredictBatch_(const BatchType* batch, bool pred_margin, - size_t num_feature, size_t num_output_group, - treelite::Predictor::PredFuncHandle pred_func_handle, - size_t rbegin, size_t rend, - size_t expected_query_result_size, float* out_pred) { - CHECK(pred_func_handle != nullptr) - << "A shared library needs to be loaded first using Load()"; - /* Pass the correct prediction function to PredLoop. - We also need to specify how the function should be called. */ - size_t query_result_size; - // Dimension of output vector: - // can be either [num_data] or [num_class]*[num_data]. - // Note that size of prediction may be smaller than out_pred (this occurs - // when pred_function is set to "max_index"). - if (num_output_group > 1) { // multi-class classification task - using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*); - PredFunc pred_func = reinterpret_cast(pred_func_handle); - query_result_size = - PredLoop(batch, num_feature, rbegin, rend, out_pred, - [pred_func, num_output_group, pred_margin] - (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t { - return pred_func(inst, static_cast(pred_margin), - &out_pred[rid * num_output_group]); - }); - } else { // every other task - using PredFunc = float (*)(TreelitePredictorEntry*, int); - PredFunc pred_func = reinterpret_cast(pred_func_handle); - query_result_size = - PredLoop(batch, num_feature, rbegin, rend, out_pred, - [pred_func, pred_margin] - (int64_t rid, TreelitePredictorEntry* inst, float* out_pred) -> size_t { - out_pred[rid] = pred_func(inst, static_cast(pred_margin)); - return 1; - }); +template +class PredLoopDispatcherWithDenseDMatrix { + public: + template + inline static size_t Dispatch( + const treelite::DMatrix* dmat, ThresholdType test_val, + int num_feature, size_t rbegin, size_t rend, + LeafOutputType* out_pred, PredFunc func) { + const auto* dmat_ = static_cast*>(dmat); + return PredLoop( + dmat_, num_feature, rbegin, rend, out_pred, func); } - return query_result_size; -} +}; -inline size_t PredictInst_(TreelitePredictorEntry* inst, - bool pred_margin, size_t num_output_group, - treelite::Predictor::PredFuncHandle pred_func_handle, - size_t expected_query_result_size, float* out_pred) { - CHECK(pred_func_handle != nullptr) - << "A shared library needs to be loaded first using Load()"; - size_t query_result_size; // Dimention of output vector - if (num_output_group > 1) { // multi-class classification task - using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*); - PredFunc pred_func = reinterpret_cast(pred_func_handle); - query_result_size = pred_func(inst, static_cast(pred_margin), out_pred); - } else { // every other task - using PredFunc = float (*)(TreelitePredictorEntry*, int); - PredFunc pred_func = reinterpret_cast(pred_func_handle); - out_pred[0] = pred_func(inst, static_cast(pred_margin)); - query_result_size = 1; +template +class PredLoopDispatcherWithCSRDMatrix { + public: + template + inline static size_t Dispatch( + const treelite::DMatrix* dmat, ThresholdType test_val, + int num_feature, size_t rbegin, size_t rend, + LeafOutputType* out_pred, PredFunc func) { + const auto* dmat_ = static_cast*>(dmat); + return PredLoop( + dmat_, num_feature, rbegin, rend, out_pred, func); + } +}; + +template +inline size_t PredLoop(const treelite::DMatrix* dmat, ThresholdType 
test_val, int num_feature, + size_t rbegin, size_t rend, LeafOutputType* out_pred, PredFunc func) { + treelite::DMatrixType dmat_type = dmat->GetType(); + switch (dmat_type) { + case treelite::DMatrixType::kDense: { + return treelite::DispatchWithTypeInfo( + dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func); + } + case treelite::DMatrixType::kSparseCSR: { + return treelite::DispatchWithTypeInfo( + dmat->GetElementType(), dmat, test_val, num_feature, rbegin, rend, out_pred, func); + } + default: + LOG(FATAL) << "Unrecognized data matrix type: " << static_cast(dmat_type); + return 0; } - return query_result_size; } } // anonymous namespace namespace treelite { +namespace predictor { + +SharedLibrary::SharedLibrary() : handle_(nullptr), libpath_() {} + +SharedLibrary::~SharedLibrary() { + if (handle_) { +#ifdef _WIN32 + FreeLibrary(static_cast(handle_)); +#else + dlclose(static_cast(handle_)); +#endif + } +} + +void +SharedLibrary::Load(const char* libpath) { +#ifdef _WIN32 + HMODULE handle = LoadLibraryA(libpath); +#else + void* handle = dlopen(libpath, RTLD_LAZY | RTLD_LOCAL); +#endif + CHECK(handle) << "Failed to load dynamic shared library `" << libpath << "'"; + handle_ = static_cast(handle); + libpath_ = std::string(libpath); +} + +SharedLibrary::FunctionHandle +SharedLibrary::LoadFunction(const char* name) const { +#ifdef _WIN32 + FARPROC func_handle = GetProcAddress(static_cast(handle_), name); +#else + void* func_handle = dlsym(static_cast(handle_), name); +#endif + CHECK(func_handle) + << "Dynamic shared library `" << libpath_ << "' does not contain a function " << name << "()."; + return static_cast(func_handle); +} + +template +HandleType +SharedLibrary::LoadFunctionWithSignature(const char* name) const { + auto func_handle = reinterpret_cast(LoadFunction(name)); + CHECK(func_handle) << "Dynamic shared library `" << libpath_ << "' does not contain a function " + << name << "() with the requested signature"; + return func_handle; +} + +template +class PredFunctionInitDispatcher { + public: + inline static std::unique_ptr Dispatch( + const SharedLibrary& library, int num_feature, int num_output_group) { + return std::make_unique>( + library, num_feature, num_output_group); + } +}; + +std::unique_ptr +PredFunction::Create( + TypeInfo threshold_type, TypeInfo leaf_output_type, const SharedLibrary& library, + int num_feature, int num_output_group) { + return DispatchWithModelTypes( + threshold_type, leaf_output_type, library, num_feature, num_output_group); +} + +template +PredFunctionImpl::PredFunctionImpl( + const SharedLibrary& library, int num_feature, int num_output_group) { + CHECK_GT(num_output_group, 0) << "num_output_group cannot be zero"; + if (num_output_group > 1) { // multi-class classification + handle_ = library.LoadFunction("predict_multiclass"); + } else { // everything else + handle_ = library.LoadFunction("predict"); + } + num_feature_ = num_feature; + num_output_group_ = num_output_group; +} + +template +TypeInfo +PredFunctionImpl::GetThresholdType() const { + return InferTypeInfoOf(); +} + +template +TypeInfo +PredFunctionImpl::GetLeafOutputType() const { + return InferTypeInfoOf(); +} + +template +size_t +PredFunctionImpl::PredictBatch( + const DMatrix* dmat, size_t rbegin, size_t rend, bool pred_margin, + PredictorOutputHandle out_pred) const { + /* Pass the correct prediction function to PredLoop. + We also need to specify how the function should be called. 
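     (The code below picks between the two entry points exported by the shared
     library: multi-class models go through predict_multiclass(), whose wrapper
     writes one score per output group for each row, while every other model
     goes through predict(), which yields a single score per row.)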
*/ + size_t result_size; + // Dimension of output vector: + // can be either [num_data] or [num_class]*[num_data]. + // Note that size of prediction may be smaller than out_pred (this occurs + // when pred_function is set to "max_index"). + CHECK(rbegin < rend && rend <= dmat->GetNumRow()); + size_t num_row = rend - rbegin; + if (num_output_group_ > 1) { // multi-class classification + using PredFunc = size_t (*)(Entry*, int, LeafOutputType*); + auto pred_func = reinterpret_cast(handle_); + CHECK(pred_func) << "The predict_multiclass() function has incorrect signature."; + auto pred_func_wrapper + = [pred_func, num_output_group = num_output_group_, pred_margin] + (int64_t rid, Entry* inst, LeafOutputType* out_pred) -> size_t { + return pred_func(inst, static_cast(pred_margin), + &out_pred[rid * num_output_group]); + }; + result_size = PredLoop(dmat, static_cast(0), num_feature_, rbegin, rend, + static_cast(out_pred), pred_func_wrapper); + } else { // everything else + using PredFunc = LeafOutputType (*)(Entry*, int); + auto pred_func = reinterpret_cast(handle_); + CHECK(pred_func) << "The predict() function has incorrect signature."; + auto pred_func_wrapper + = [pred_func, pred_margin] + (int64_t rid, Entry* inst, LeafOutputType* out_pred) -> size_t { + out_pred[rid] = pred_func(inst, static_cast(pred_margin)); + return 1; + }; + result_size = PredLoop(dmat, static_cast(0), num_feature_, rbegin, rend, + static_cast(out_pred), pred_func_wrapper); + } + return result_size; +} Predictor::Predictor(int num_worker_thread) - : lib_handle_(nullptr), - num_output_group_query_func_handle_(nullptr), - num_feature_query_func_handle_(nullptr), - pred_func_handle_(nullptr), + : pred_func_(nullptr), thread_pool_handle_(nullptr), - num_worker_thread_(num_worker_thread) {} + num_output_group_(0), + num_feature_(0), + sigmoid_alpha_(std::numeric_limits::quiet_NaN()), + global_bias_(std::numeric_limits::quiet_NaN()), + num_worker_thread_(num_worker_thread), + threshold_type_(TypeInfo::kInvalid), + leaf_output_type_(TypeInfo::kInvalid) {} Predictor::~Predictor() { - Free(); + if (thread_pool_handle_) { + Free(); + } } void -Predictor::Load(const char* name) { - lib_handle_ = OpenLibrary(name); - if (lib_handle_ == nullptr) { - LOG(FATAL) << "Failed to load dynamic shared library `" << name << "'"; - } +Predictor::Load(const char* libpath) { + lib_.Load(libpath); + + using UnsignedQueryFunc = size_t (*)(); + using StringQueryFunc = const char* (*)(); + using FloatQueryFunc = float (*)(); /* 1. query # of output groups */ - num_output_group_query_func_handle_ - = LoadFunction(lib_handle_, "get_num_output_group"); - using UnsignedQueryFunc = size_t (*)(void); - auto uint_query_func - = reinterpret_cast(num_output_group_query_func_handle_); - CHECK(uint_query_func != nullptr) - << "Dynamic shared library `" << name - << "' does not contain valid get_num_output_group() function"; - num_output_group_ = uint_query_func(); + auto* num_output_group_query_func + = lib_.LoadFunctionWithSignature("get_num_output_group"); + num_output_group_ = num_output_group_query_func(); /* 2. 
query # of features */ - num_feature_query_func_handle_ - = LoadFunction(lib_handle_, "get_num_feature"); - uint_query_func = reinterpret_cast(num_feature_query_func_handle_); - CHECK(uint_query_func != nullptr) - << "Dynamic shared library `" << name - << "' does not contain valid get_num_feature() function"; - num_feature_ = uint_query_func(); + auto* num_feature_query_func + = lib_.LoadFunctionWithSignature("get_num_feature"); + num_feature_ = num_feature_query_func(); CHECK_GT(num_feature_, 0) << "num_feature cannot be zero"; /* 3. query # of pred_transform name */ - pred_transform_query_func_handle_ - = LoadFunction(lib_handle_, "get_pred_transform"); - using StringQueryFunc = const char* (*)(void); - auto str_query_func = - reinterpret_cast(pred_transform_query_func_handle_); - if (str_query_func == nullptr) { - LOG(INFO) << "Dynamic shared library `" << name - << "' does not contain valid get_pred_transform() function"; - pred_transform_ = "unknown"; - } else { - pred_transform_ = str_query_func(); - } + auto* pred_transform_query_func + = lib_.LoadFunctionWithSignature("get_pred_transform"); + pred_transform_ = pred_transform_query_func(); /* 4. query # of sigmoid_alpha */ - sigmoid_alpha_query_func_handle_ - = LoadFunction(lib_handle_, "get_sigmoid_alpha"); - using FloatQueryFunc = float (*)(void); - auto float_query_func = - reinterpret_cast(sigmoid_alpha_query_func_handle_); - if (float_query_func == nullptr) { - LOG(INFO) << "Dynamic shared library `" << name - << "' does not contain valid get_sigmoid_alpha() function"; - sigmoid_alpha_ = NAN; - } else { - sigmoid_alpha_ = float_query_func(); - } + auto* sigmoid_alpha_query_func + = lib_.LoadFunctionWithSignature("get_sigmoid_alpha"); + sigmoid_alpha_ = sigmoid_alpha_query_func(); /* 5. query # of global_bias */ - global_bias_query_func_handle_ - = LoadFunction(lib_handle_, "get_global_bias"); - float_query_func = reinterpret_cast(global_bias_query_func_handle_); - if (float_query_func == nullptr) { - LOG(INFO) << "Dynamic shared library `" << name - << "' does not contain valid get_global_bias() function"; - global_bias_ = NAN; - } else { - global_bias_ = float_query_func(); - } - - /* 6. load appropriate function for margin prediction */ + auto* global_bias_query_func = lib_.LoadFunctionWithSignature("get_global_bias"); + global_bias_ = global_bias_query_func(); + + /* 6. Query the data type for thresholds and leaf outputs */ + auto* threshold_type_query_func + = lib_.LoadFunctionWithSignature("get_threshold_type"); + threshold_type_ = typeinfo_table.at(threshold_type_query_func()); + auto* leaf_output_type_query_func + = lib_.LoadFunctionWithSignature("get_leaf_output_type"); + leaf_output_type_ = typeinfo_table.at(leaf_output_type_query_func()); + + /* 7. 
load appropriate function for margin prediction */ CHECK_GT(num_output_group_, 0) << "num_output_group cannot be zero"; - if (num_output_group_ > 1) { // multi-class classification - pred_func_handle_ = LoadFunction(lib_handle_, - "predict_multiclass"); - using PredFunc = size_t (*)(TreelitePredictorEntry*, int, float*); - PredFunc pred_func = reinterpret_cast(pred_func_handle_); - CHECK(pred_func != nullptr) - << "Dynamic shared library `" << name - << "' does not contain valid predict_multiclass() function"; - } else { // everything else - pred_func_handle_ = LoadFunction(lib_handle_, "predict"); - using PredFunc = float (*)(TreelitePredictorEntry*, int); - PredFunc pred_func = reinterpret_cast(pred_func_handle_); - CHECK(pred_func != nullptr) - << "Dynamic shared library `" << name - << "' does not contain valid predict() function"; - } + pred_func_ = PredFunction::Create( + threshold_type_, leaf_output_type_, lib_, + static_cast(num_feature_), static_cast(num_output_group_)); if (num_worker_thread_ == -1) { - num_worker_thread_ = std::thread::hardware_concurrency(); + num_worker_thread_ = static_cast(std::thread::hardware_concurrency()); } thread_pool_handle_ = static_cast( new PredThreadPool(num_worker_thread_ - 1, this, [](SpscQueue* incoming_queue, SpscQueue* outgoing_queue, const Predictor* predictor) { - InputToken input; - while (incoming_queue->Pop(&input)) { - size_t query_result_size; - const size_t rbegin = input.rbegin; - const size_t rend = input.rend; - switch (input.input_type) { - case InputType::kSparseBatch: - { - const CSRBatch* batch = static_cast(input.data); - query_result_size - = PredictBatch_(batch, input.pred_margin, input.num_feature, - input.num_output_group, input.pred_func_handle, - rbegin, rend, - predictor->QueryResultSize(batch, rbegin, rend), - input.out_pred); - } - break; - case InputType::kDenseBatch: - { - const DenseBatch* batch = static_cast(input.data); - query_result_size - = PredictBatch_(batch, input.pred_margin, input.num_feature, - input.num_output_group, input.pred_func_handle, - rbegin, rend, - predictor->QueryResultSize(batch, rbegin, rend), - input.out_pred); - } - break; + predictor->exception_catcher_.Run([&]() { + InputToken input; + while (incoming_queue->Pop(&input)) { + const size_t rbegin = input.rbegin; + const size_t rend = input.rend; + size_t query_result_size + = predictor->pred_func_->PredictBatch( + input.dmat, rbegin, rend, input.pred_margin, input.out_pred); + outgoing_queue->Push(OutputToken{query_result_size}); } - outgoing_queue->Push(OutputToken{query_result_size}); - } + }); })); } void Predictor::Free() { - CloseLibrary(lib_handle_); delete static_cast(thread_pool_handle_); } -template static inline -std::vector SplitBatch(const BatchType* batch, size_t split_factor) { - const size_t num_row = batch->num_row; +std::vector SplitBatch(const DMatrix* dmat, size_t split_factor) { + const size_t num_row = dmat->GetNumRow(); CHECK_LE(split_factor, num_row); const size_t portion = num_row / split_factor; const size_t remainder = num_row % split_factor; @@ -379,26 +392,38 @@ std::vector SplitBatch(const BatchType* batch, size_t split_factor) { return row_ptr; } -template -inline size_t -Predictor::PredictBatchBase_(const BatchType* batch, int verbose, - bool pred_margin, float* out_result) { - static_assert(std::is_same::value - || std::is_same::value, - "PredictBatchBase_: unrecognized batch type"); +template +class ShrinkResultToFit { + public: + inline static void Dispatch( + size_t num_row, size_t 
query_size_per_instance, size_t num_output_group, + PredictorOutputHandle out_result); +}; + +template +class AllocateOutputVector { + public: + inline static PredictorOutputHandle Dispatch(size_t size); +}; + +template +class DeallocateOutputVector { + public: + inline static void Dispatch(PredictorOutputHandle output_vector); +}; + +size_t +Predictor::PredictBatch( + const DMatrix* dmat, int verbose, bool pred_margin, PredictorOutputHandle out_result) const { const double tstart = dmlc::GetTime(); - PredThreadPool* pool = static_cast(thread_pool_handle_); - const InputType input_type - = std::is_same::value - ? InputType::kSparseBatch : InputType::kDenseBatch; - InputToken request{input_type, static_cast(batch), pred_margin, - num_feature_, num_output_group_, pred_func_handle_, - 0, batch->num_row, out_result}; + + const size_t num_row = dmat->GetNumRow(); + auto* pool = static_cast(thread_pool_handle_); + InputToken request{dmat, pred_margin, pred_func_.get(), 0, num_row, out_result}; OutputToken response; - CHECK_GT(batch->num_row, 0); - const int nthread = std::min(num_worker_thread_, - static_cast(batch->num_row)); - const std::vector row_ptr = SplitBatch(batch, nthread); + CHECK_GT(num_row, 0); + const int nthread = std::min(num_worker_thread_, static_cast(num_row)); + const std::vector row_ptr = SplitBatch(dmat, nthread); for (int tid = 0; tid < nthread - 1; ++tid) { request.rbegin = row_ptr[tid]; request.rend = row_ptr[tid + 1]; @@ -406,14 +431,11 @@ Predictor::PredictBatchBase_(const BatchType* batch, int verbose, } size_t total_size = 0; { - // assign work to master + // assign work to the main thread const size_t rbegin = row_ptr[nthread - 1]; const size_t rend = row_ptr[nthread]; const size_t query_result_size - = PredictBatch_(batch, pred_margin, num_feature_, num_output_group_, - pred_func_handle_, - rbegin, rend, QueryResultSize(batch, rbegin, rend), - out_result); + = pred_func_->PredictBatch(dmat, rbegin, rend, pred_margin, out_result); total_size += query_result_size; } for (int tid = 0; tid < nthread - 1; ++tid) { @@ -422,47 +444,57 @@ Predictor::PredictBatchBase_(const BatchType* batch, int verbose, } } // re-shape output if total_size < dimension of out_result - if (total_size < QueryResultSize(batch, 0, batch->num_row)) { + if (total_size < QueryResultSize(dmat, 0, num_row)) { CHECK_GT(num_output_group_, 1); - CHECK_EQ(total_size % batch->num_row, 0); - const size_t query_size_per_instance = total_size / batch->num_row; + CHECK_EQ(total_size % num_row, 0); + const size_t query_size_per_instance = total_size / num_row; CHECK_GT(query_size_per_instance, 0); CHECK_LT(query_size_per_instance, num_output_group_); - for (size_t rid = 0; rid < batch->num_row; ++rid) { - for (size_t k = 0; k < query_size_per_instance; ++k) { - out_result[rid * query_size_per_instance + k] - = out_result[rid * num_output_group_ + k]; - } - } + DispatchWithTypeInfo( + leaf_output_type_, num_row, query_size_per_instance, num_output_group_, out_result); } const double tend = dmlc::GetTime(); if (verbose > 0) { - LOG(INFO) << "Treelite: Finished prediction in " - << tend - tstart << " sec"; + LOG(INFO) << "Treelite: Finished prediction in " << tend - tstart << " sec"; } return total_size; } -size_t -Predictor::PredictBatch(const CSRBatch* batch, int verbose, - bool pred_margin, float* out_result) { - return PredictBatchBase_(batch, verbose, pred_margin, out_result); +PredictorOutputHandle +Predictor::CreateOutputVector(const DMatrix* dmat) const { + const size_t output_vector_size = 
this->QueryResultSize(dmat); + return DispatchWithTypeInfo(leaf_output_type_, output_vector_size); } -size_t -Predictor::PredictBatch(const DenseBatch* batch, int verbose, - bool pred_margin, float* out_result) { - return PredictBatchBase_(batch, verbose, pred_margin, out_result); +void +Predictor::DeleteOutputVector(PredictorOutputHandle output_vector) const { + DispatchWithTypeInfo(leaf_output_type_, output_vector); } -size_t -Predictor::PredictInst(TreelitePredictorEntry* inst, bool pred_margin, - float* out_result) { - size_t total_size; - total_size = PredictInst_(inst, pred_margin, num_output_group_, - pred_func_handle_, - QueryResultSizeSingleInst(), out_result); - return total_size; +template +void +ShrinkResultToFit::Dispatch( + size_t num_row, size_t query_size_per_instance, size_t num_output_group, + PredictorOutputHandle out_result) { + auto* out_result_ = static_cast(out_result); + for (size_t rid = 0; rid < num_row; ++rid) { + for (size_t k = 0; k < query_size_per_instance; ++k) { + out_result_[rid * query_size_per_instance + k] = out_result_[rid * num_output_group + k]; + } + } +} + +template +PredictorOutputHandle +AllocateOutputVector::Dispatch(size_t size) { + return static_cast(new LeafOutputType[size]); +} + +template +void +DeallocateOutputVector::Dispatch(PredictorOutputHandle output_vector) { + delete[] (static_cast(output_vector)); } +} // namespace predictor } // namespace treelite diff --git a/src/predictor/thread_pool/spsc_queue.h b/src/predictor/thread_pool/spsc_queue.h index 1fa10416..539c4bd8 100644 --- a/src/predictor/thread_pool/spsc_queue.h +++ b/src/predictor/thread_pool/spsc_queue.h @@ -14,6 +14,9 @@ #include #include +namespace treelite { +namespace predictor { + const constexpr int kL1CacheBytes = 64; /*! \brief Lock-free single-producer-single-consumer queue for each thread */ @@ -117,4 +120,7 @@ class SpscQueue { std::condition_variable cv_; }; +} // namespace predictor +} // namespace treelite + #endif // TREELITE_PREDICTOR_THREAD_POOL_SPSC_QUEUE_H_ diff --git a/src/predictor/thread_pool/thread_pool.h b/src/predictor/thread_pool/thread_pool.h index c58ed4a7..16f07c1e 100644 --- a/src/predictor/thread_pool/thread_pool.h +++ b/src/predictor/thread_pool/thread_pool.h @@ -7,6 +7,7 @@ #ifndef TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_ #define TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_ +#include #include #include #include @@ -18,6 +19,7 @@ #include "spsc_queue.h" namespace treelite { +namespace predictor { template class ThreadPool { @@ -27,7 +29,8 @@ class ThreadPool { ThreadPool(int num_worker, const TaskContext* context, TaskFunc task) : num_worker_(num_worker), context_(context), task_(task) { - CHECK(num_worker_ >= 0 && num_worker_ < std::thread::hardware_concurrency()) + CHECK(num_worker_ >= 0 + && static_cast(num_worker_) < std::thread::hardware_concurrency()) << "Number of worker threads must be between 0 and " << (std::thread::hardware_concurrency() - 1); for (int i = 0; i < num_worker_; ++i) { @@ -42,7 +45,7 @@ class ThreadPool { } /* bind threads to cores */ const char* bind_flag = getenv("TREELITE_BIND_THREADS"); - if (bind_flag == nullptr || std::atoi(bind_flag) == 1) { + if (bind_flag == nullptr || std::stoi(bind_flag) == 1) { SetAffinity(); } } @@ -76,7 +79,7 @@ class ThreadPool { SetThreadAffinityMask(GetCurrentThread(), 0x1); for (int i = 0; i < num_worker_; ++i) { const int core_id = i + 1; - SetThreadAffinityMask(thread_[i].native_handle(), (1 << core_id)); + SetThreadAffinityMask(thread_[i].native_handle(), (1ULL << core_id)); } 
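// (In the Windows branch above, the calling thread is pinned to core 0 and
//  worker thread i to core i + 1; switching the shift from 1 to 1ULL keeps the
//  affinity mask well-defined for core IDs of 31 and above.)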
#elif defined(__APPLE__) && defined(__MACH__) #include @@ -122,6 +125,7 @@ class ThreadPool { } }; +} // namespace predictor } // namespace treelite #endif // TREELITE_PREDICTOR_THREAD_POOL_THREAD_POOL_H_ diff --git a/src/typeinfo.cc b/src/typeinfo.cc new file mode 100644 index 00000000..048c0e79 --- /dev/null +++ b/src/typeinfo.cc @@ -0,0 +1,22 @@ +/*! + * Copyright (c) 2017-2020 by Contributors + * \file typeinfo.cc + * \author Hyunsu Cho + * \brief Conversion tables to obtain TypeInfo from string + */ + +// Do not include other Treelite headers here, to minimize cross-header dependencies + +#include +#include +#include + +namespace treelite { + +const std::unordered_map typeinfo_table{ + {"uint32", TypeInfo::kUInt32}, + {"float32", TypeInfo::kFloat32}, + {"float64", TypeInfo::kFloat64} +}; + +} // namespace treelite From dc188113c8dae052c60d9a7638b02564c60d6631 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Wed, 16 Sep 2020 19:36:48 -0700 Subject: [PATCH 36/38] Refactor struct Model -> class Model + class ModelImpl (Part of #196) (#201) * Upgrade C++ standard to C++14. * Split struct Model into class Model and class ModelImpl. The ModelImpl class will soon become a template class in order to hold Tree objects with uint32, float32, or float64 type. The Model class will become an abstract class so as to avoid exposing ModelImpl to external interface. (It's very hard to pass template classes through a FFI boundary.) * Change signature of methods that return Model, since Model is now an abstract class. These functions now return std::unique_ptr. * Move bodies of tiny methods from tree_impl.h to tree.h. This will reduce verbosity once ModelImpl becomes a template class. --- include/treelite/c_api.h | 6 +- include/treelite/compiler.h | 2 +- include/treelite/frontend.h | 18 +- include/treelite/tree.h | 184 +++++++++++++----- include/treelite/tree_impl.h | 278 +++++++-------------------- src/annotator.cc | 7 +- src/c_api/c_api.cc | 37 ++-- src/compiler/ast/build.cc | 2 +- src/compiler/ast/builder.h | 2 +- src/compiler/ast_native.cc | 2 +- src/compiler/failsafe.cc | 11 +- src/compiler/native/pred_transform.h | 4 +- src/frontend/builder.cc | 9 +- src/frontend/lightgbm.cc | 54 +++--- src/frontend/xgboost.cc | 48 ++--- src/reference_serializer.cc | 2 +- tests/cpp/CMakeLists.txt | 2 +- tests/cpp/test_serializer.cc | 18 +- tests/example_app/CMakeLists.txt | 2 +- tests/example_app/example.cc | 7 +- 20 files changed, 324 insertions(+), 371 deletions(-) diff --git a/include/treelite/c_api.h b/include/treelite/c_api.h index 8f0f5285..5d87743f 100644 --- a/include/treelite/c_api.h +++ b/include/treelite/c_api.h @@ -229,8 +229,7 @@ TREELITE_DLL int TreeliteCompilerFree(CompilerHandle handle); * \param out loaded model * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreeliteLoadLightGBMModel(const char* filename, - ModelHandle* out); +TREELITE_DLL int TreeliteLoadLightGBMModel(const char* filename, ModelHandle* out); /*! * \brief load a model file generated by XGBoost (dmlc/xgboost). The model file * must contain a decision tree ensemble. @@ -238,8 +237,7 @@ TREELITE_DLL int TreeliteLoadLightGBMModel(const char* filename, * \param out loaded model * \return 0 for success, -1 for failure */ -TREELITE_DLL int TreeliteLoadXGBoostModel(const char* filename, - ModelHandle* out); +TREELITE_DLL int TreeliteLoadXGBoostModel(const char* filename, ModelHandle* out); /*! * \brief load an XGBoost model from a memory buffer. 
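// To make the ownership change described in the commit message above concrete,
// a caller-side sketch (the file name is illustrative; LoadXGBoostModel and
// GetNumTree are the signatures introduced elsewhere in this patch):
//
//   std::unique_ptr<treelite::Model> model
//       = treelite::frontend::LoadXGBoostModel("xgboost_model.bin");
//   size_t num_tree = model->GetNumTree();  // resolved virtually by ModelImpl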
* \param buf memory buffer diff --git a/include/treelite/compiler.h b/include/treelite/compiler.h index a5dae46f..c4157746 100644 --- a/include/treelite/compiler.h +++ b/include/treelite/compiler.h @@ -17,7 +17,7 @@ namespace treelite { -struct Model; // forward declaration +class Model; // forward declaration namespace compiler { diff --git a/include/treelite/frontend.h b/include/treelite/frontend.h index 67272ce5..ca974b42 100644 --- a/include/treelite/frontend.h +++ b/include/treelite/frontend.h @@ -14,7 +14,7 @@ namespace treelite { -struct Model; // forward declaration +class Model; // forward declaration namespace frontend { @@ -25,23 +25,23 @@ namespace frontend { * \brief load a model file generated by LightGBM (Microsoft/LightGBM). The * model file must contain a decision tree ensemble. * \param filename name of model file - * \param out reference to loaded model + * \return loaded model */ -void LoadLightGBMModel(const char *filename, Model* out); +std::unique_ptr LoadLightGBMModel(const char *filename); /*! * \brief load a model file generated by XGBoost (dmlc/xgboost). The model file * must contain a decision tree ensemble. * \param filename name of model file - * \param out reference to loaded model + * \return loaded model */ -void LoadXGBoostModel(const char* filename, Model* out); +std::unique_ptr LoadXGBoostModel(const char* filename); /*! * \brief load an XGBoost model from a memory buffer. * \param buf memory buffer * \param len size of memory buffer - * \param out reference to loaded model + * \return loaded model */ -void LoadXGBoostModel(const void* buf, size_t len, Model* out); +std::unique_ptr LoadXGBoostModel(const void* buf, size_t len); //-------------------------------------------------------------------------- // model builder interface: build trees incrementally @@ -185,9 +185,9 @@ class ModelBuilder { void DeleteTree(int index); /*! * \brief finalize the model and produce the in-memory representation - * \param out_model place to store in-memory representation of the finished model + * \return the finished model */ - void CommitModel(Model* out_model); + std::unique_ptr CommitModel(); private: std::unique_ptr pimpl; // Pimpl pattern diff --git a/include/treelite/tree.h b/include/treelite/tree.h index 6d9a085f..ba1d89c0 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -133,8 +134,6 @@ class Tree { bool sum_hess_present_; /*! \brief whether gain_present_ field is present */ bool gain_present_; - // padding - uint16_t pad_; }; static_assert(std::is_pod::value, "Node must be a POD type"); @@ -144,12 +143,12 @@ class Tree { ~Tree() = default; Tree(const Tree&) = delete; Tree& operator=(const Tree&) = delete; - Tree(Tree&&) = default; - Tree& operator=(Tree&&) = default; - inline Tree Clone() const; + Tree(Tree&&) noexcept = default; + Tree& operator=(Tree&&) noexcept = default; - inline std::vector GetPyBuffer(); - inline void InitFromPyBuffer(std::vector frames); + inline void GetPyBuffer(std::vector* dest); + inline void InitFromPyBuffer(std::vector::iterator begin, + std::vector::iterator end); private: // vector of nodes @@ -184,57 +183,86 @@ class Tree { * \brief index of the node's left child * \param nid ID of node being queried */ - inline int LeftChild(int nid) const; + inline int LeftChild(int nid) const { + return nodes_[nid].cleft_; + } /*! 
* \brief index of the node's right child * \param nid ID of node being queried */ - inline int RightChild(int nid) const; + inline int RightChild(int nid) const { + return nodes_[nid].cright_; + } /*! * \brief index of the node's "default" child, used when feature is missing * \param nid ID of node being queried */ - inline int DefaultChild(int nid) const; + inline int DefaultChild(int nid) const { + return DefaultLeft(nid) ? LeftChild(nid) : RightChild(nid); + } /*! * \brief feature index of the node's split condition * \param nid ID of node being queried */ - inline uint32_t SplitIndex(int nid) const; + inline uint32_t SplitIndex(int nid) const { + return (nodes_[nid].sindex_ & ((1U << 31U) - 1U)); + } /*! * \brief whether to use the left child node, when the feature in the split condition is missing * \param nid ID of node being queried */ - inline bool DefaultLeft(int nid) const; + inline bool DefaultLeft(int nid) const { + return (nodes_[nid].sindex_ >> 31U) != 0; + } /*! * \brief whether the node is leaf node * \param nid ID of node being queried */ - inline bool IsLeaf(int nid) const; + inline bool IsLeaf(int nid) const { + return nodes_[nid].cleft_ == -1; + } /*! * \brief get leaf value of the leaf node * \param nid ID of node being queried */ - inline tl_float LeafValue(int nid) const; + inline tl_float LeafValue(int nid) const { + return (nodes_[nid].info_).leaf_value; + } /*! * \brief get leaf vector of the leaf node; useful for multi-class random forest classifier * \param nid ID of node being queried */ - inline std::vector LeafVector(int nid) const; + inline std::vector LeafVector(int nid) const { + if (nid > leaf_vector_offset_.Size()) { + throw std::runtime_error("nid too large"); + } + return std::vector(&leaf_vector_[leaf_vector_offset_[nid]], + &leaf_vector_[leaf_vector_offset_[nid + 1]]); + } /*! * \brief tests whether the leaf node has a non-empty leaf vector * \param nid ID of node being queried */ - inline bool HasLeafVector(int nid) const; + inline bool HasLeafVector(int nid) const { + if (nid > leaf_vector_offset_.Size()) { + throw std::runtime_error("nid too large"); + } + return leaf_vector_offset_[nid] != leaf_vector_offset_[nid + 1]; + } /*! * \brief get threshold of the node * \param nid ID of node being queried */ - inline tl_float Threshold(int nid) const; + inline tl_float Threshold(int nid) const { + return (nodes_[nid].info_).threshold; + } /*! * \brief get comparison operator * \param nid ID of node being queried */ - inline Operator ComparisonOp(int nid) const; + inline Operator ComparisonOp(int nid) const { + return nodes_[nid].cmp_; + } /*! * \brief Get list of all categories belonging to the left child node. Categories not in this * list will belong to the right child node. Categories are integers ranging from 0 to @@ -242,48 +270,71 @@ class Tree { * assumed to be in ascending order. * \param nid ID of node being queried */ - inline std::vector LeftCategories(int nid) const; + inline std::vector LeftCategories(int nid) const { + if (nid > left_categories_offset_.Size()) { + throw std::runtime_error("nid too large"); + } + return std::vector(&left_categories_[left_categories_offset_[nid]], + &left_categories_[left_categories_offset_[nid + 1]]); + } /*! * \brief get feature split type * \param nid ID of node being queried */ - inline SplitFeatureType SplitType(int nid) const; + inline SplitFeatureType SplitType(int nid) const { + return nodes_[nid].split_type_; + } /*! 
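// (A note on the inlined accessors above: sindex_ packs the split feature index
//  into its lower 31 bits and the default-left flag into the top bit, which is
//  why SplitIndex() masks with (1U << 31U) - 1U and DefaultLeft() shifts right
//  by 31.)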
* \brief test whether this node has data count * \param nid ID of node being queried */ - inline bool HasDataCount(int nid) const; + inline bool HasDataCount(int nid) const { + return nodes_[nid].data_count_present_; + } /*! * \brief get data count * \param nid ID of node being queried */ - inline uint64_t DataCount(int nid) const; + inline uint64_t DataCount(int nid) const { + return nodes_[nid].data_count_; + } + /*! * \brief test whether this node has hessian sum * \param nid ID of node being queried */ - inline bool HasSumHess(int nid) const; + inline bool HasSumHess(int nid) const { + return nodes_[nid].sum_hess_present_; + } /*! * \brief get hessian sum * \param nid ID of node being queried */ - inline double SumHess(int nid) const; + inline double SumHess(int nid) const { + return nodes_[nid].sum_hess_; + } /*! * \brief test whether this node has gain value * \param nid ID of node being queried */ - inline bool HasGain(int nid) const; + inline bool HasGain(int nid) const { + return nodes_[nid].gain_present_; + } /*! * \brief get gain value * \param nid ID of node being queried */ - inline double Gain(int nid) const; + inline double Gain(int nid) const { + return nodes_[nid].gain_; + } /*! * \brief test whether missing values should be converted into zero; only applicable for * categorical splits * \param nid ID of node being queried */ - inline bool MissingCategoryToZero(int nid) const; + inline bool MissingCategoryToZero(int nid) const { + return nodes_[nid].missing_category_to_zero_; + } /** Setters **/ /*! @@ -326,19 +377,31 @@ class Tree { * \param nid ID of node being updated * \param sum_hess hessian sum */ - inline void SetSumHess(int nid, double sum_hess); + inline void SetSumHess(int nid, double sum_hess) { + Node& node = nodes_[nid]; + node.sum_hess_ = sum_hess; + node.sum_hess_present_ = true; + } /*! * \brief set the data count of the node * \param nid ID of node being updated * \param data_count data count */ - inline void SetDataCount(int nid, uint64_t data_count); + inline void SetDataCount(int nid, uint64_t data_count) { + Node& node = nodes_[nid]; + node.data_count_ = data_count; + node.data_count_present_ = true; + } /*! * \brief set the gain value of the node * \param nid ID of node being updated * \param gain gain value */ - inline void SetGain(int nid, double gain); + inline void SetGain(int nid, double gain) { + Node& node = nodes_[nid]; + node.gain_ = gain; + node.gain_present_ = true; + } void ReferenceSerialize(dmlc::Stream* fo) const; }; @@ -406,9 +469,24 @@ inline void InitParamAndCheck(ModelParam* param, const std::vector>& cfg); /*! \brief thin wrapper for tree ensemble model */ -struct Model { - /*! \brief member trees */ - std::vector trees; +class Model { + public: + /*! \brief disable copy; use default move */ + Model() = default; + virtual ~Model() = default; + inline static std::unique_ptr Create(); + Model(const Model&) = delete; + Model& operator=(const Model&) = delete; + Model(Model&&) = default; + Model& operator=(Model&&) = default; + + virtual size_t GetNumTree() const = 0; + virtual void SetTreeLimit(size_t limit) = 0; + virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0; + + inline std::vector GetPyBuffer(); + inline static std::unique_ptr CreateFromPyBuffer(std::vector frames); + /*! * \brief number of features used for the model. * It is assumed that all feature indices are between 0 and [num_feature]-1. @@ -423,19 +501,37 @@ struct Model { /*! \brief extra parameters */ ModelParam param; - /*! 
\brief disable copy; use default move */ - Model() = default; - ~Model() = default; - Model(const Model&) = delete; - Model& operator=(const Model&) = delete; - Model(Model&&) = default; - Model& operator=(Model&&) = default; + private: + // Internal functions for serialization + virtual void GetPyBuffer(std::vector* dest) = 0; + virtual void InitFromPyBuffer(std::vector::iterator begin, + std::vector::iterator end) = 0; +}; - void ReferenceSerialize(dmlc::Stream* fo) const; +class ModelImpl : public Model { + public: + /*! \brief member trees */ + std::vector trees; - inline std::vector GetPyBuffer(); - inline void InitFromPyBuffer(std::vector frames); - inline Model Clone() const; + /*! \brief disable copy; use default move */ + ModelImpl() = default; + ~ModelImpl() override = default; + ModelImpl(const ModelImpl&) = delete; + ModelImpl& operator=(const ModelImpl&) = delete; + ModelImpl(ModelImpl&&) noexcept = default; + ModelImpl& operator=(ModelImpl&&) noexcept = default; + + void ReferenceSerialize(dmlc::Stream* fo) const override; + inline size_t GetNumTree() const override { + return trees.size(); + } + void SetTreeLimit(size_t limit) override { + return trees.resize(limit); + } + + inline void GetPyBuffer(std::vector* dest) override; + inline void InitFromPyBuffer(std::vector::iterator begin, + std::vector::iterator end) override; }; } // namespace treelite diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 374523fe..85ac40dc 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -362,73 +363,31 @@ inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) { constexpr size_t kNumFramePerTree = 6; -inline std::vector -Tree::GetPyBuffer() { - return { - GetPyBufferFromScalar(&num_nodes), - GetPyBufferFromArray(&nodes_, "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=H}"), - GetPyBufferFromArray(&leaf_vector_), - GetPyBufferFromArray(&leaf_vector_offset_), - GetPyBufferFromArray(&left_categories_), - GetPyBufferFromArray(&left_categories_offset_) - }; -} - inline void -Tree::InitFromPyBuffer(std::vector frames) { - size_t frame_id = 0; - InitScalarFromPyBuffer(&num_nodes, frames[frame_id++]); - InitArrayFromPyBuffer(&nodes_, frames[frame_id++]); - if (num_nodes != nodes_.Size()) { - throw std::runtime_error("Could not load the correct number of nodes"); - } - InitArrayFromPyBuffer(&leaf_vector_, frames[frame_id++]); - InitArrayFromPyBuffer(&leaf_vector_offset_, frames[frame_id++]); - InitArrayFromPyBuffer(&left_categories_, frames[frame_id++]); - InitArrayFromPyBuffer(&left_categories_offset_, frames[frame_id++]); - if (frame_id != kNumFramePerTree) { - throw std::runtime_error("Wrong number of frames loaded"); - } -} - -inline std::vector -Model::GetPyBuffer() { - /* Header */ - std::vector frames{ - GetPyBufferFromScalar(&num_feature), - GetPyBufferFromScalar(&num_output_group), - GetPyBufferFromScalar(&random_forest_flag), - GetPyBufferFromScalar(¶m, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}") - }; - - /* Body */ - for (auto& tree : trees) { - auto tree_frames = tree.GetPyBuffer(); - frames.insert(frames.end(), tree_frames.begin(), tree_frames.end()); - } - return frames; +Tree::GetPyBuffer(std::vector* dest) { + dest->push_back(GetPyBufferFromScalar(&num_nodes)); + dest->push_back(GetPyBufferFromArray(&nodes_, "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=H}")); + dest->push_back(GetPyBufferFromArray(&leaf_vector_)); + 
dest->push_back(GetPyBufferFromArray(&leaf_vector_offset_)); + dest->push_back(GetPyBufferFromArray(&left_categories_)); + dest->push_back(GetPyBufferFromArray(&left_categories_offset_)); } inline void -Model::InitFromPyBuffer(std::vector frames) { - /* Header */ - size_t frame_id = 0; - InitScalarFromPyBuffer(&num_feature, frames[frame_id++]); - InitScalarFromPyBuffer(&num_output_group, frames[frame_id++]); - InitScalarFromPyBuffer(&random_forest_flag, frames[frame_id++]); - InitScalarFromPyBuffer(¶m, frames[frame_id++]); - /* Body */ - const size_t num_frame = frames.size(); - if ((num_frame - frame_id) % kNumFramePerTree != 0) { - throw std::runtime_error("Wrong number of frames"); +Tree::InitFromPyBuffer(std::vector::iterator begin, + std::vector::iterator end) { + if (std::distance(begin, end) != kNumFramePerTree) { + throw std::runtime_error("Wrong number of frames specified"); } - trees.clear(); - for (; frame_id < num_frame; frame_id += kNumFramePerTree) { - std::vector tree_frames(frames.begin() + frame_id, - frames.begin() + frame_id + kNumFramePerTree); - trees.emplace_back(); - trees.back().InitFromPyBuffer(tree_frames); + InitScalarFromPyBuffer(&num_nodes, *begin++); + InitArrayFromPyBuffer(&nodes_, *begin++); + if (num_nodes != nodes_.Size()) { + throw std::runtime_error("Could not load the correct number of nodes"); } + InitArrayFromPyBuffer(&leaf_vector_, *begin++); + InitArrayFromPyBuffer(&leaf_vector_offset_, *begin++); + InitArrayFromPyBuffer(&left_categories_, *begin++); + InitArrayFromPyBuffer(&left_categories_offset_, *begin++); } inline void Tree::Node::Init() { @@ -442,7 +401,6 @@ inline void Tree::Node::Init() { data_count_present_ = sum_hess_present_ = gain_present_ = false; split_type_ = SplitFeatureType::kNone; cmp_ = Operator::kNone; - pad_ = 0; } inline int @@ -460,18 +418,6 @@ Tree::AllocNode() { return nd; } -inline Tree -Tree::Clone() const { - Tree tree; - tree.num_nodes = num_nodes; - tree.nodes_ = nodes_.Clone(); - tree.leaf_vector_ = leaf_vector_.Clone(); - tree.leaf_vector_offset_ = leaf_vector_offset_.Clone(); - tree.left_categories_ = left_categories_.Clone(); - tree.left_categories_offset_ = left_categories_offset_.Clone(); - return tree; -} - inline void Tree::Init() { num_nodes = 1; @@ -520,117 +466,6 @@ Tree::GetCategoricalFeatures() const { return result; } -inline int -Tree::LeftChild(int nid) const { - return nodes_[nid].cleft_; -} - -inline int -Tree::RightChild(int nid) const { - return nodes_[nid].cright_; -} - -inline int -Tree::DefaultChild(int nid) const { - return DefaultLeft(nid) ? 
LeftChild(nid) : RightChild(nid); -} - -inline uint32_t -Tree::SplitIndex(int nid) const { - return (nodes_[nid].sindex_ & ((1U << 31U) - 1U)); -} - -inline bool -Tree::DefaultLeft(int nid) const { - return (nodes_[nid].sindex_ >> 31U) != 0; -} - -inline bool -Tree::IsLeaf(int nid) const { - return nodes_[nid].cleft_ == -1; -} - -inline tl_float -Tree::LeafValue(int nid) const { - return (nodes_[nid].info_).leaf_value; -} - -inline std::vector -Tree::LeafVector(int nid) const { - if (nid > leaf_vector_offset_.Size()) { - throw std::runtime_error("nid too large"); - } - return std::vector(&leaf_vector_[leaf_vector_offset_[nid]], - &leaf_vector_[leaf_vector_offset_[nid + 1]]); -} - -inline bool -Tree::HasLeafVector(int nid) const { - if (nid > leaf_vector_offset_.Size()) { - throw std::runtime_error("nid too large"); - } - return leaf_vector_offset_[nid] != leaf_vector_offset_[nid + 1]; -} - -inline tl_float -Tree::Threshold(int nid) const { - return (nodes_[nid].info_).threshold; -} - -inline Operator -Tree::ComparisonOp(int nid) const { - return nodes_[nid].cmp_; -} - -inline std::vector -Tree::LeftCategories(int nid) const { - if (nid > left_categories_offset_.Size()) { - throw std::runtime_error("nid too large"); - } - return std::vector(&left_categories_[left_categories_offset_[nid]], - &left_categories_[left_categories_offset_[nid + 1]]); -} - -inline SplitFeatureType -Tree::SplitType(int nid) const { - return nodes_[nid].split_type_; -} - -inline bool -Tree::HasDataCount(int nid) const { - return nodes_[nid].data_count_present_; -} - -inline uint64_t -Tree::DataCount(int nid) const { - return nodes_[nid].data_count_; -} - -inline bool -Tree::HasSumHess(int nid) const { - return nodes_[nid].sum_hess_present_; -} - -inline double -Tree::SumHess(int nid) const { - return nodes_[nid].sum_hess_; -} - -inline bool -Tree::HasGain(int nid) const { - return nodes_[nid].gain_present_; -} - -inline double -Tree::Gain(int nid) const { - return nodes_[nid].gain_; -} - -inline bool -Tree::MissingCategoryToZero(int nid) const { - return nodes_[nid].missing_category_to_zero_; -} - inline void Tree::SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, bool default_left, Operator cmp) { @@ -712,38 +547,63 @@ Tree::SetLeafVector(int nid, const std::vector& node_leaf_vector) { node.split_type_ = SplitFeatureType::kNone; } -inline void -Tree::SetSumHess(int nid, double sum_hess) { - Node& node = nodes_[nid]; - node.sum_hess_ = sum_hess; - node.sum_hess_present_ = true; +inline std::unique_ptr +Model::Create() { + std::unique_ptr model = std::make_unique(); + return model; } -inline void -Tree::SetDataCount(int nid, uint64_t data_count) { - Node& node = nodes_[nid]; - node.data_count_ = data_count; - node.data_count_present_ = true; +inline std::vector +Model::GetPyBuffer() { + std::vector buffer; + this->GetPyBuffer(&buffer); + return buffer; +} + +inline std::unique_ptr +Model::CreateFromPyBuffer(std::vector frames) { + std::unique_ptr model = Model::Create(); + model->InitFromPyBuffer(frames.begin(), frames.end()); + return model; } inline void -Tree::SetGain(int nid, double gain) { - Node& node = nodes_[nid]; - node.gain_ = gain; - node.gain_present_ = true; +ModelImpl::GetPyBuffer(std::vector* dest) { + /* Header */ + dest->push_back(GetPyBufferFromScalar(&num_feature)); + dest->push_back(GetPyBufferFromScalar(&num_output_group)); + dest->push_back(GetPyBufferFromScalar(&random_forest_flag)); + dest->push_back(GetPyBufferFromScalar( + ¶m, "T{" 
_TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}")); + + /* Body */ + for (Tree& tree : trees) { + tree.GetPyBuffer(dest); + } } -inline Model -Model::Clone() const { - Model model; - for (const Tree& t : trees) { - model.trees.push_back(t.Clone()); +inline void +ModelImpl::InitFromPyBuffer( + std::vector::iterator begin, std::vector::iterator end) { + const size_t num_frame = std::distance(begin, end); + /* Header */ + constexpr size_t kNumFrameInHeader = 4; + if (num_frame < kNumFrameInHeader) { + throw std::runtime_error("Wrong number of frames"); + } + InitScalarFromPyBuffer(&num_feature, *begin++); + InitScalarFromPyBuffer(&num_output_group, *begin++); + InitScalarFromPyBuffer(&random_forest_flag, *begin++); + InitScalarFromPyBuffer(¶m, *begin++); + /* Body */ + if ((num_frame - kNumFrameInHeader) % kNumFramePerTree != 0) { + throw std::runtime_error("Wrong number of frames"); + } + trees.clear(); + for (; begin < end; begin += kNumFramePerTree) { + trees.emplace_back(); + trees.back().InitFromPyBuffer(begin, begin + kNumFramePerTree); } - model.num_feature = num_feature; - model.num_output_group = num_output_group; - model.random_forest_flag = random_forest_flag; - model.param = param; - return model; } inline void InitParamAndCheck(ModelParam* param, diff --git a/src/annotator.cc b/src/annotator.cc index 36b90d22..6a43b9b9 100644 --- a/src/annotator.cc +++ b/src/annotator.cc @@ -54,11 +54,12 @@ void Traverse(const treelite::Tree& tree, const Entry* data, Traverse_(tree, data, 0, out_counts); } -inline void ComputeBranchLoop(const treelite::Model& model, +inline void ComputeBranchLoop(const treelite::Model& model_ptr, const treelite::DMatrix* dmat, size_t rbegin, size_t rend, int nthread, const size_t* count_row_ptr, size_t* counts_tloc, Entry* inst) { + const treelite::ModelImpl& model = dynamic_cast(model_ptr); const size_t ntree = model.trees.size(); CHECK_LE(rbegin, rend); CHECK_LT(static_cast(rend), std::numeric_limits::max()); @@ -89,8 +90,10 @@ inline void ComputeBranchLoop(const treelite::Model& model, namespace treelite { void -BranchAnnotator::Annotate(const Model& model, const DMatrix* dmat, +BranchAnnotator::Annotate(const Model& model_ptr, const DMatrix* dmat, int nthread, int verbose) { + const ModelImpl& model = dynamic_cast(model_ptr); + std::vector new_counts; std::vector counts_tloc; std::vector count_row_ptr; diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 1850ef2e..8c452204 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -307,29 +307,23 @@ int TreeliteCompilerFree(CompilerHandle handle) { API_END(); } -int TreeliteLoadLightGBMModel(const char* filename, - ModelHandle* out) { +int TreeliteLoadLightGBMModel(const char* filename, ModelHandle* out) { API_BEGIN(); - std::unique_ptr model{new Model()}; - frontend::LoadLightGBMModel(filename, model.get()); + std::unique_ptr model = frontend::LoadLightGBMModel(filename); *out = static_cast(model.release()); API_END(); } -int TreeliteLoadXGBoostModel(const char* filename, - ModelHandle* out) { +int TreeliteLoadXGBoostModel(const char* filename, ModelHandle* out) { API_BEGIN(); - std::unique_ptr model{new Model()}; - frontend::LoadXGBoostModel(filename, model.get()); + std::unique_ptr model = frontend::LoadXGBoostModel(filename); *out = static_cast(model.release()); API_END(); } -int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, - ModelHandle* out) { +int TreeliteLoadXGBoostModelFromMemoryBuffer(const void* buf, size_t len, ModelHandle* out) { API_BEGIN(); - 
std::unique_ptr model{new Model()}; - frontend::LoadXGBoostModel(buf, len, model.get()); + std::unique_ptr model = frontend::LoadXGBoostModel(buf, len); *out = static_cast(model.release()); API_END(); } @@ -342,21 +336,21 @@ int TreeliteFreeModel(ModelHandle handle) { int TreeliteQueryNumTree(ModelHandle handle, size_t* out) { API_BEGIN(); - auto model_ = static_cast(handle); - *out = model_->trees.size(); + const auto* model_ = static_cast(handle); + *out = model_->GetNumTree(); API_END(); } int TreeliteQueryNumFeature(ModelHandle handle, size_t* out) { API_BEGIN(); - auto model_ = static_cast(handle); + const auto* model_ = static_cast(handle); *out = static_cast(model_->num_feature); API_END(); } int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t* out) { API_BEGIN(); - auto model_ = static_cast(handle); + const auto* model_ = static_cast(handle); *out = static_cast(model_->num_output_group); API_END(); } @@ -364,10 +358,10 @@ int TreeliteQueryNumOutputGroups(ModelHandle handle, size_t* out) { int TreeliteSetTreeLimit(ModelHandle handle, size_t limit) { API_BEGIN(); CHECK_GT(limit, 0) << "limit should be greater than 0!"; - auto model_ = static_cast(handle); - CHECK_GE(model_->trees.size(), limit) - << "Model contains less trees(" << model_->trees.size() << ") than limit"; - model_->trees.resize(limit); + auto* model_ = static_cast(handle); + const size_t num_tree = model_->GetNumTree(); + CHECK_GE(num_tree, limit) << "Model contains less trees(" << num_tree << ") than limit"; + model_->SetTreeLimit(limit); API_END(); } @@ -531,8 +525,7 @@ int TreeliteModelBuilderCommitModel(ModelBuilderHandle handle, API_BEGIN(); auto builder = static_cast(handle); CHECK(builder) << "Detected dangling reference to deleted ModelBuilder object"; - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); *out = static_cast(model.release()); API_END(); } diff --git a/src/compiler/ast/build.cc b/src/compiler/ast/build.cc index 19244d77..bf55bd4e 100644 --- a/src/compiler/ast/build.cc +++ b/src/compiler/ast/build.cc @@ -11,7 +11,7 @@ namespace compiler { DMLC_REGISTRY_FILE_TAG(build); -void ASTBuilder::BuildAST(const Model& model) { +void ASTBuilder::BuildAST(const ModelImpl& model) { this->output_vector_flag = (model.num_output_group > 1 && model.random_forest_flag); this->num_feature = model.num_feature; diff --git a/src/compiler/ast/builder.h b/src/compiler/ast/builder.h index 8046b15d..424939ba 100644 --- a/src/compiler/ast/builder.h +++ b/src/compiler/ast/builder.h @@ -30,7 +30,7 @@ class ASTBuilder { quantize_threshold_flag(false) {} /* \brief initially build AST from model */ - void BuildAST(const Model& model); + void BuildAST(const ModelImpl& model); /* \brief generate is_categorical[] array, which tells whether each feature is categorical or numerical */ std::vector GenerateIsCategoricalArray(); diff --git a/src/compiler/ast_native.cc b/src/compiler/ast_native.cc index 4a4175e2..7091a02b 100644 --- a/src/compiler/ast_native.cc +++ b/src/compiler/ast_native.cc @@ -62,7 +62,7 @@ class ASTNativeCompiler : public Compiler { files_.clear(); ASTBuilder builder; - builder.BuildAST(model); + builder.BuildAST(dynamic_cast(model)); if (builder.FoldCode(param.code_folding_req) || param.quantize > 0) { // is_categorical[i] : is i-th feature categorical? 
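The C API and compiler changes above alter how calling code obtains and queries a model: the frontend loaders now return std::unique_ptr<treelite::Model> instead of filling an out-parameter, and the tree count and tree limit are reached through the new virtual methods rather than by touching Model::trees directly. The following is a minimal illustrative sketch, not part of the patch; the include paths and the model file name are assumptions, and error handling is omitted.

#include <treelite/tree.h>
#include <treelite/frontend.h>   // assumed location of the loader declarations

#include <iostream>
#include <memory>
#include <vector>

int main() {
  // Frontends now hand back ownership of an abstract Model.
  std::unique_ptr<treelite::Model> model =
      treelite::frontend::LoadXGBoostModel("xgboost_model.bin");  // hypothetical file name

  // Tree count and tree limit go through the virtual interface added in this patch.
  std::cout << "Number of trees: " << model->GetNumTree() << std::endl;
  if (model->GetNumTree() > 100) {
    model->SetTreeLimit(100);  // keep only the first 100 trees
  }

  // Round-trip through the Python buffer frames, mirroring TestRoundTrip() in
  // tests/cpp/test_serializer.cc: callers no longer need a concrete ModelImpl.
  std::vector<treelite::PyBufferFrame> frames = model->GetPyBuffer();
  std::unique_ptr<treelite::Model> restored = treelite::Model::CreateFromPyBuffer(frames);

  return 0;
}

Keeping Model as an abstract base with ModelImpl behind it is what later allows the final patch in this series to parameterize the implementation on threshold and leaf output types without breaking code written against the interface above.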
diff --git a/src/compiler/failsafe.cc b/src/compiler/failsafe.cc index 2ab16bcd..f1e95628 100644 --- a/src/compiler/failsafe.cc +++ b/src/compiler/failsafe.cc @@ -136,7 +136,7 @@ const char* arrays_template = R"TREELITETEMPLATE( // nodes[]: stores nodes from all decision trees // nodes_row_ptr[]: marks bounaries between decision trees. The nodes belonging to Tree [i] are // found in nodes[nodes_row_ptr[i]:nodes_row_ptr[i+1]] -inline std::pair FormatNodesArray(const treelite::Model& model) { +inline std::pair FormatNodesArray(const treelite::ModelImpl& model) { treelite::compiler::common_util::ArrayFormatter nodes(100, 2); treelite::compiler::common_util::ArrayFormatter nodes_row_ptr(100, 2); int node_count = 0; @@ -172,7 +172,8 @@ inline std::pair FormatNodesArray(const treelite::Mode } // Variant of FormatNodesArray(), where nodes[] array is dumped as an ELF binary -inline std::pair, std::string> FormatNodesArrayELF(const treelite::Model& model) { +inline std::pair, std::string> FormatNodesArrayELF( + const treelite::ModelImpl& model) { std::vector nodes_elf; treelite::compiler::AllocateELFHeader(&nodes_elf); @@ -208,7 +209,7 @@ inline std::pair, std::string> FormatNodesArrayELF(const treel // Get the comparison op used in the tree ensemble model // If splits have more than one op, throw an error -inline std::string GetCommonOp(const treelite::Model& model) { +inline std::string GetCommonOp(const treelite::ModelImpl& model) { std::set ops; for (const auto& tree : model.trees) { for (int nid = 0; nid < tree.num_nodes; ++nid) { @@ -262,7 +263,9 @@ class FailSafeCompiler : public Compiler { } } - CompiledModel Compile(const Model& model) override { + CompiledModel Compile(const Model& model_ptr) override { + const auto& model = dynamic_cast(model_ptr); + CompiledModel cm; cm.backend = "native"; diff --git a/src/compiler/native/pred_transform.h b/src/compiler/native/pred_transform.h index 9703b044..9a2b2fbf 100644 --- a/src/compiler/native/pred_transform.h +++ b/src/compiler/native/pred_transform.h @@ -115,7 +115,7 @@ inline std::string multiclass_ova(const Model& model) { const float alpha = model.param.sigmoid_alpha; CHECK_GT(alpha, 0.0f) << "multiclass_ova: alpha must be strictly positive"; return fmt::format( -R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ + R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ const float alpha = (float){alpha}; const int num_class = {num_class}; for (int k = 0; k < num_class; ++k) {{ @@ -123,7 +123,7 @@ R"TREELITETEMPLATE(static inline size_t pred_transform(float* pred) {{ }} return (size_t)num_class; }})TREELITETEMPLATE", - "num_class"_a = model.num_output_group, "alpha"_a = alpha); + "num_class"_a = model.num_output_group, "alpha"_a = alpha); } } // namespace pred_transform diff --git a/src/frontend/builder.cc b/src/frontend/builder.cc index 5b7ff227..b571a3db 100644 --- a/src/frontend/builder.cc +++ b/src/frontend/builder.cc @@ -310,9 +310,10 @@ ModelBuilder::DeleteTree(int index) { trees.erase(trees.begin() + index); } -void -ModelBuilder::CommitModel(Model* out_model) { - Model model; +std::unique_ptr +ModelBuilder::CommitModel() { + std::unique_ptr model_ptr = Model::Create(); + ModelImpl& model = *dynamic_cast(model_ptr.get()); model.num_feature = pimpl->num_feature; model.num_output_group = pimpl->num_output_group; model.random_forest_flag = pimpl->random_forest_flag; @@ -405,7 +406,7 @@ ModelBuilder::CommitModel(Model* out_model) { } else { LOG(FATAL) << "Impossible thing happened: model has no leaf 
node!"; } - *out_model = std::move(model); + return model_ptr; } } // namespace frontend diff --git a/src/frontend/lightgbm.cc b/src/frontend/lightgbm.cc index f907ef20..e14a3174 100644 --- a/src/frontend/lightgbm.cc +++ b/src/frontend/lightgbm.cc @@ -14,7 +14,7 @@ namespace { -treelite::Model ParseStream(dmlc::Stream* fi); +inline std::unique_ptr ParseStream(dmlc::Stream* fi); } // anonymous namespace @@ -23,9 +23,9 @@ namespace frontend { DMLC_REGISTRY_FILE_TAG(lightgbm); -void LoadLightGBMModel(const char *filename, Model* out) { +std::unique_ptr LoadLightGBMModel(const char *filename) { std::unique_ptr fi(dmlc::Stream::Create(filename, "r")); - *out = std::move(ParseStream(fi.get())); + return ParseStream(fi.get()); } } // namespace frontend @@ -253,7 +253,7 @@ inline std::vector LoadText(dmlc::Stream* fi) { return lines; } -inline treelite::Model ParseStream(dmlc::Stream* fi) { +inline std::unique_ptr ParseStream(dmlc::Stream* fi) { std::vector lgb_trees_; int max_feature_idx_; int num_tree_per_iteration_; @@ -436,17 +436,18 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } /* 2. Export model */ - treelite::Model model; - model.num_feature = max_feature_idx_ + 1; - model.num_output_group = num_tree_per_iteration_; - if (model.num_output_group > 1) { + std::unique_ptr model_ptr = treelite::Model::Create(); + auto* model = dynamic_cast(model_ptr.get()); + model->num_feature = max_feature_idx_ + 1; + model->num_output_group = num_tree_per_iteration_; + if (model->num_output_group > 1) { // multiclass classification with gradient boosted trees CHECK(!average_output_) << "Ill-formed LightGBM model file: cannot use random forest mode " << "for multi-class classification"; - model.random_forest_flag = false; + model->random_forest_flag = false; } else { - model.random_forest_flag = average_output_; + model->random_forest_flag = average_output_; } // set correct prediction transform function, depending on objective function @@ -462,10 +463,10 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { break; } } - CHECK(num_class >= 0 && num_class == model.num_output_group) + CHECK(num_class >= 0 && num_class == model->num_output_group) << "Ill-formed LightGBM model file: not a valid multiclass objective"; - std::strncpy(model.param.pred_transform, "softmax", sizeof(model.param.pred_transform)); + std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform)); } else if (obj_name_ == "multiclassova") { // validate num_class and alpha parameters int num_class = -1; @@ -484,12 +485,13 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } } } - CHECK(num_class >= 0 && num_class == model.num_output_group + CHECK(num_class >= 0 && num_class == model->num_output_group && alpha > 0.0f) << "Ill-formed LightGBM model file: not a valid multiclassova objective"; - std::strncpy(model.param.pred_transform, "multiclass_ova", sizeof(model.param.pred_transform)); - model.param.sigmoid_alpha = alpha; + std::strncpy(model->param.pred_transform, "multiclass_ova", + sizeof(model->param.pred_transform)); + model->param.sigmoid_alpha = alpha; } else if (obj_name_ == "binary") { // validate alpha parameter float alpha = -1.0f; @@ -505,22 +507,23 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { CHECK_GT(alpha, 0.0f) << "Ill-formed LightGBM model file: not a valid binary objective"; - std::strncpy(model.param.pred_transform, "sigmoid", sizeof(model.param.pred_transform)); - model.param.sigmoid_alpha = alpha; + std::strncpy(model->param.pred_transform, "sigmoid", 
sizeof(model->param.pred_transform)); + model->param.sigmoid_alpha = alpha; } else if (obj_name_ == "xentropy" || obj_name_ == "cross_entropy") { - std::strncpy(model.param.pred_transform, "sigmoid", sizeof(model.param.pred_transform)); - model.param.sigmoid_alpha = 1.0f; + std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform)); + model->param.sigmoid_alpha = 1.0f; } else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") { - std::strncpy(model.param.pred_transform, "logarithm_one_plus_exp", - sizeof(model.param.pred_transform)); + std::strncpy(model->param.pred_transform, "logarithm_one_plus_exp", + sizeof(model->param.pred_transform)); } else { - std::strncpy(model.param.pred_transform, "identity", sizeof(model.param.pred_transform)); + std::strncpy(model->param.pred_transform, "identity", + sizeof(model->param.pred_transform)); } // traverse trees for (const auto& lgb_tree : lgb_trees_) { - model.trees.emplace_back(); - treelite::Tree& tree = model.trees.back(); + model->trees.emplace_back(); + treelite::Tree& tree = model->trees.back(); tree.Init(); // assign node ID's so that a breadth-wise traversal would yield @@ -583,8 +586,7 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { } } } - LOG(INFO) << "model.num_tree = " << model.trees.size(); - return model; + return model_ptr; } } // anonymous namespace diff --git a/src/frontend/xgboost.cc b/src/frontend/xgboost.cc index aec5679b..e28f0dd4 100644 --- a/src/frontend/xgboost.cc +++ b/src/frontend/xgboost.cc @@ -16,7 +16,7 @@ namespace { -treelite::Model ParseStream(dmlc::Stream* fi); +inline std::unique_ptr ParseStream(dmlc::Stream* fi); } // anonymous namespace @@ -25,14 +25,14 @@ namespace frontend { DMLC_REGISTRY_FILE_TAG(xgboost); -void LoadXGBoostModel(const char* filename, Model* out) { +std::unique_ptr LoadXGBoostModel(const char* filename) { std::unique_ptr fi(dmlc::Stream::Create(filename, "r")); - *out = std::move(ParseStream(fi.get())); + return ParseStream(fi.get()); } -void LoadXGBoostModel(const void* buf, size_t len, Model* out) { +std::unique_ptr LoadXGBoostModel(const void* buf, size_t len) { dmlc::MemoryFixedSizeStream fs(const_cast(buf), len); - *out = std::move(ParseStream(&fs)); + return ParseStream(&fs); } } // namespace frontend @@ -330,7 +330,7 @@ class XGBTree { } }; -inline treelite::Model ParseStream(dmlc::Stream* fi) { +inline std::unique_ptr ParseStream(dmlc::Stream* fi) { std::vector xgb_trees_; LearnerModelParam mparam_; // model parameter GBTreeModelParam gbm_param_; // GBTree training parameter @@ -392,44 +392,48 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { bool need_transform_to_margin = mparam_.major_version >= 1; /* 2. 
Export model */ - treelite::Model model; - model.num_feature = mparam_.num_feature; - model.num_output_group = std::max(mparam_.num_class, 1); - model.random_forest_flag = false; + std::unique_ptr model_ptr = treelite::Model::Create(); + auto* model = dynamic_cast(model_ptr.get()); + model->num_feature = mparam_.num_feature; + model->num_output_group = std::max(mparam_.num_class, 1); + model->random_forest_flag = false; // set global bias - model.param.global_bias = static_cast(mparam_.base_score); + model->param.global_bias = static_cast(mparam_.base_score); std::vector exponential_family { "count:poisson", "reg:gamma", "reg:tweedie" }; if (need_transform_to_margin) { if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") { - model.param.global_bias = ProbToMargin::Sigmoid(model.param.global_bias); + model->param.global_bias = ProbToMargin::Sigmoid(model->param.global_bias); } else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_) != exponential_family.cend()) { - model.param.global_bias = ProbToMargin::Exponential(model.param.global_bias); + model->param.global_bias = ProbToMargin::Exponential(model->param.global_bias); } } // set correct prediction transform function, depending on objective function if (name_obj_ == "multi:softmax") { - std::strncpy(model.param.pred_transform, "max_index", sizeof(model.param.pred_transform)); + std::strncpy(model->param.pred_transform, "max_index", + sizeof(model->param.pred_transform)); } else if (name_obj_ == "multi:softprob") { - std::strncpy(model.param.pred_transform, "softmax", sizeof(model.param.pred_transform)); + std::strncpy(model->param.pred_transform, "softmax", sizeof(model->param.pred_transform)); } else if (name_obj_ == "reg:logistic" || name_obj_ == "binary:logistic") { - std::strncpy(model.param.pred_transform, "sigmoid", sizeof(model.param.pred_transform)); - model.param.sigmoid_alpha = 1.0f; + std::strncpy(model->param.pred_transform, "sigmoid", sizeof(model->param.pred_transform)); + model->param.sigmoid_alpha = 1.0f; } else if (std::find(exponential_family.cbegin() , exponential_family.cend(), name_obj_) != exponential_family.cend()) { - std::strncpy(model.param.pred_transform, "exponential", sizeof(model.param.pred_transform)); + std::strncpy(model->param.pred_transform, "exponential", + sizeof(model->param.pred_transform)); } else { - std::strncpy(model.param.pred_transform, "identity", sizeof(model.param.pred_transform)); + std::strncpy(model->param.pred_transform, "identity", + sizeof(model->param.pred_transform)); } // traverse trees for (const auto& xgb_tree : xgb_trees_) { - model.trees.emplace_back(); - treelite::Tree& tree = model.trees.back(); + model->trees.emplace_back(); + treelite::Tree& tree = model->trees.back(); tree.Init(); // assign node ID's so that a breadth-wise traversal would yield @@ -458,7 +462,7 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) { tree.SetSumHess(new_id, stat.sum_hess); } } - return model; + return model_ptr; } } // anonymous namespace diff --git a/src/reference_serializer.cc b/src/reference_serializer.cc index 6d60fb9b..5ecaa568 100644 --- a/src/reference_serializer.cc +++ b/src/reference_serializer.cc @@ -54,7 +54,7 @@ void Tree::ReferenceSerialize(dmlc::Stream* fo) const { CHECK_EQ(left_categories_offset_.Back(), left_categories_.Size()); } -void Model::ReferenceSerialize(dmlc::Stream* fo) const { +void ModelImpl::ReferenceSerialize(dmlc::Stream* fo) const { fo->Write(num_feature); fo->Write(num_output_group); 
fo->Write(random_forest_flag); diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 5b0d4685..5a8d5ab4 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -1,7 +1,7 @@ add_executable(treelite_cpp_test) set_target_properties(treelite_cpp_test PROPERTIES - CXX_STANDARD 11 + CXX_STANDARD 14 CXX_STANDARD_REQUIRED ON) target_link_libraries(treelite_cpp_test PRIVATE objtreelite objtreelite_runtime objtreelite_common GTest::GTest) diff --git a/tests/cpp/test_serializer.cc b/tests/cpp/test_serializer.cc index 3dc494fb..a3bf1dbf 100644 --- a/tests/cpp/test_serializer.cc +++ b/tests/cpp/test_serializer.cc @@ -23,8 +23,7 @@ inline std::string TreeliteToBytes(treelite::Model* model) { inline void TestRoundTrip(treelite::Model* model) { auto buffer = model->GetPyBuffer(); - std::unique_ptr received_model{new treelite::Model()}; - received_model->InitFromPyBuffer(buffer); + std::unique_ptr received_model = treelite::Model::CreateFromPyBuffer(buffer); ASSERT_EQ(TreeliteToBytes(model), TreeliteToBytes(received_model.get())); } @@ -47,8 +46,7 @@ TEST(PyBufferInterfaceRoundTrip, TreeStump) { tree->SetLeafNode(2, 1.0f); builder->InsertTree(tree.get()); - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } @@ -66,8 +64,7 @@ TEST(PyBufferInterfaceRoundTrip, TreeStumpLeafVec) { tree->SetLeafVectorNode(2, {1.0f, -1.0f}); builder->InsertTree(tree.get()); - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } @@ -85,8 +82,7 @@ TEST(PyBufferInterfaceRoundTrip, TreeStumpCategoricalSplit) { tree->SetLeafNode(2, 1.0f); builder->InsertTree(tree.get()); - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } @@ -112,8 +108,7 @@ TEST(PyBufferInterfaceRoundTrip, TreeDepth2) { builder->InsertTree(tree.get()); } - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } @@ -144,8 +139,7 @@ TEST(PyBufferInterfaceRoundTrip, DeepFullTree) { tree->SetRootNode(0); builder->InsertTree(tree.get()); - std::unique_ptr model{new Model()}; - builder->CommitModel(model.get()); + std::unique_ptr model = builder->CommitModel(); TestRoundTrip(model.get()); } diff --git a/tests/example_app/CMakeLists.txt b/tests/example_app/CMakeLists.txt index 574442b0..bb457c32 100644 --- a/tests/example_app/CMakeLists.txt +++ b/tests/example_app/CMakeLists.txt @@ -14,6 +14,6 @@ add_executable(example example.cc) target_link_libraries(example PRIVATE treelite::treelite_static treelite::treelite_runtime_static) set_target_properties(example PROPERTIES - CXX_STANDARD 11 + CXX_STANDARD 14 CXX_STANDARD_REQUIRED YES ) diff --git a/tests/example_app/example.cc b/tests/example_app/example.cc index 5966ee04..44b27318 100644 --- a/tests/example_app/example.cc +++ b/tests/example_app/example.cc @@ -22,14 +22,13 @@ int main(void) { std::unique_ptr builder{new ModelBuilder(2, 1, false)}; builder->InsertTree(tree.get()); - treelite::Model model; - builder->CommitModel(&model); - std::cout << model.trees.size() << std::endl; + std::unique_ptr model = builder->CommitModel(); + std::cout << model->GetNumTree() << std::endl; treelite::compiler::CompilerParam param; param.Init(std::map{}); std::unique_ptr 
compiler{treelite::Compiler::Create("ast_native", param)}; - treelite::compiler::CompiledModel cm = compiler->Compile(model); + treelite::compiler::CompiledModel cm = compiler->Compile(*model.get()); for (const auto& kv : cm.files) { std::cout << "=================" << kv.first << "=================" << std::endl; std::cout << kv.second.content << std::endl; From 8138e130d3e9d6d71ba5c569d2742a8388f6a876 Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Fri, 18 Sep 2020 15:08:59 -0700 Subject: [PATCH 37/38] Inline GetThresholdType() and GetLeafOutputType() --- include/treelite/tree.h | 8 ++++++-- include/treelite/tree_impl.h | 10 ---------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/include/treelite/tree.h b/include/treelite/tree.h index 9d877961..9cf9d530 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -496,8 +496,12 @@ class Model { template inline static std::unique_ptr Create(); inline static std::unique_ptr Create(TypeInfo threshold_type, TypeInfo leaf_output_type); - inline TypeInfo GetThresholdType() const; - inline TypeInfo GetLeafOutputType() const; + inline TypeInfo GetThresholdType() const { + return threshold_type_; + } + inline TypeInfo GetLeafOutputType() const { + return leaf_output_type_; + } template inline auto Dispatch(Func func); template diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 7faee181..7697b449 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -587,16 +587,6 @@ Tree::SetLeafVector( node.split_type_ = SplitFeatureType::kNone; } -inline TypeInfo -Model::GetThresholdType() const { - return threshold_type_; -} - -inline TypeInfo -Model::GetLeafOutputType() const { - return leaf_output_type_; -} - template inline std::unique_ptr Model::Create() { From 269dd2ec4342f846762165b9cd071170f6cde9b5 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Fri, 25 Sep 2020 01:04:03 -0700 Subject: [PATCH 38/38] Create template classes for model representations to support multiple data types (Part of #196) (#198) * Create template classes for model representations to support multiple data types * Update include/treelite/tree.h Co-authored-by: William Hicks * Address review comments from @canonizer * Address more comments from @canonizer Co-authored-by: William Hicks Co-authored-by: William Hicks Co-authored-by: Andy Adinets --- include/treelite/base.h | 7 +- include/treelite/tree.h | 58 ++++++++++---- include/treelite/tree_impl.h | 145 +++++++++++++++++++++++++++++------ include/treelite/typeinfo.h | 2 +- 4 files changed, 170 insertions(+), 42 deletions(-) diff --git a/include/treelite/base.h b/include/treelite/base.h index f9f352a0..f9f0e72d 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -8,9 +8,11 @@ #define TREELITE_BASE_H_ #include +#include #include #include #include +#include "./typeinfo.h" namespace treelite { @@ -33,7 +35,7 @@ enum class Operator : int8_t { extern const std::unordered_map optable; /*! 
- * \brief get string representation of comparsion operator + * \brief get string representation of comparison operator * \param op comparison operator * \return string representation */ @@ -56,7 +58,8 @@ inline std::string OpName(Operator op) { * \param rhs float on the right hand side * \return whether [lhs] [op] [rhs] is true or not */ -inline bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs) { +template +inline bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs) { switch (op) { case Operator::kEQ: return lhs == rhs; case Operator::kLT: return lhs < rhs; diff --git a/include/treelite/tree.h b/include/treelite/tree.h index ba1d89c0..26999d8f 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -78,6 +78,7 @@ class ContiguousArray { }; /*! \brief in-memory representation of a decision tree */ +template class Tree { public: /*! \brief tree node */ @@ -87,8 +88,8 @@ class Tree { inline void Init(); /*! \brief store either leaf value or decision threshold */ union Info { - tl_float leaf_value; // for leaf nodes - tl_float threshold; // for non-leaf nodes + LeafOutputType leaf_value; // for leaf nodes + ThresholdType threshold; // for non-leaf nodes }; /*! \brief pointer to left and right children */ int32_t cleft_, cright_; @@ -137,7 +138,19 @@ class Tree { }; static_assert(std::is_pod::value, "Node must be a POD type"); - static_assert(sizeof(Node) == 48, "Node must be 48 bytes"); + static_assert(std::is_same::value + || std::is_same::value, + "ThresholdType must be either float32 or float64"); + static_assert(std::is_same::value + || std::is_same::value + || std::is_same::value, + "LeafOutputType must be one of uint32_t, float32 or float64"); + static_assert(std::is_same::value + || std::is_same::value, + "Unsupported combination of ThresholdType and LeafOutputType"); + static_assert((std::is_same::value && sizeof(Node) == 48) + || (std::is_same::value && sizeof(Node) == 56), + "Node size incorrect"); Tree() = default; ~Tree() = default; @@ -146,6 +159,7 @@ class Tree { Tree(Tree&&) noexcept = default; Tree& operator=(Tree&&) noexcept = default; + inline const char* GetFormatStringForNode(); inline void GetPyBuffer(std::vector* dest); inline void InitFromPyBuffer(std::vector::iterator begin, std::vector::iterator end); @@ -153,7 +167,7 @@ class Tree { private: // vector of nodes ContiguousArray nodes_; - ContiguousArray leaf_vector_; + ContiguousArray leaf_vector_; ContiguousArray leaf_vector_offset_; ContiguousArray left_categories_; ContiguousArray left_categories_offset_; @@ -225,19 +239,19 @@ class Tree { * \brief get leaf value of the leaf node * \param nid ID of node being queried */ - inline tl_float LeafValue(int nid) const { + inline LeafOutputType LeafValue(int nid) const { return (nodes_[nid].info_).leaf_value; } /*! * \brief get leaf vector of the leaf node; useful for multi-class random forest classifier * \param nid ID of node being queried */ - inline std::vector LeafVector(int nid) const { + inline std::vector LeafVector(int nid) const { if (nid > leaf_vector_offset_.Size()) { throw std::runtime_error("nid too large"); } - return std::vector(&leaf_vector_[leaf_vector_offset_[nid]], - &leaf_vector_[leaf_vector_offset_[nid + 1]]); + return std::vector(&leaf_vector_[leaf_vector_offset_[nid]], + &leaf_vector_[leaf_vector_offset_[nid + 1]]); } /*! 
* \brief tests whether the leaf node has a non-empty leaf vector @@ -253,7 +267,7 @@ class Tree { * \brief get threshold of the node * \param nid ID of node being queried */ - inline tl_float Threshold(int nid) const { + inline ThresholdType Threshold(int nid) const { return (nodes_[nid].info_).threshold; } /*! @@ -346,7 +360,7 @@ class Tree { * \param cmp comparison operator to compare between feature value and * threshold */ - inline void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, + inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp); /*! * \brief create a categorical split @@ -365,13 +379,13 @@ class Tree { * \param nid ID of node being updated * \param value leaf value */ - inline void SetLeaf(int nid, tl_float value); + inline void SetLeaf(int nid, LeafOutputType value); /*! * \brief set the leaf vector of the node; useful for multi-class random forest classifier * \param nid ID of node being updated * \param leaf_vector leaf vector */ - inline void SetLeafVector(int nid, const std::vector& leaf_vector); + inline void SetLeafVector(int nid, const std::vector& leaf_vector); /*! * \brief set the hessian sum of the node * \param nid ID of node being updated @@ -474,12 +488,25 @@ class Model { /*! \brief disable copy; use default move */ Model() = default; virtual ~Model() = default; - inline static std::unique_ptr Create(); Model(const Model&) = delete; Model& operator=(const Model&) = delete; Model(Model&&) = default; Model& operator=(Model&&) = default; + template + inline static std::unique_ptr Create(); + inline static std::unique_ptr Create(TypeInfo threshold_type, TypeInfo leaf_output_type); + inline TypeInfo GetThresholdType() const { + return threshold_type_; + } + inline TypeInfo GetLeafOutputType() const { + return leaf_output_type_; + } + template + inline auto Dispatch(Func func); + template + inline auto Dispatch(Func func) const; + virtual size_t GetNumTree() const = 0; virtual void SetTreeLimit(size_t limit) = 0; virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0; @@ -502,16 +529,19 @@ class Model { ModelParam param; private: + TypeInfo threshold_type_; + TypeInfo leaf_output_type_; // Internal functions for serialization virtual void GetPyBuffer(std::vector* dest) = 0; virtual void InitFromPyBuffer(std::vector::iterator begin, std::vector::iterator end) = 0; }; +template class ModelImpl : public Model { public: /*! \brief member trees */ - std::vector trees; + std::vector> trees; /*! 
\brief disable copy; use default move */ ModelImpl() = default; diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 85ac40dc..ab348885 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -333,6 +334,11 @@ inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) { return GetPyBufferFromScalar(static_cast(scalar), format, sizeof(T)); } +inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) { + using T = std::underlying_type::type; + return GetPyBufferFromScalar(reinterpret_cast(scalar), InferFormatString()); +} + template inline PyBufferFrame GetPyBufferFromScalar(T* scalar) { static_assert(std::is_arithmetic::value, @@ -349,6 +355,18 @@ inline void InitArrayFromPyBuffer(ContiguousArray* vec, PyBufferFrame buffer) vec->UseForeignBuffer(buffer.buf, buffer.nitem); } +inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) { + using T = std::underlying_type::type; + if (sizeof(T) != buffer.itemsize) { + throw std::runtime_error("Incorrect itemsize"); + } + if (buffer.nitem != 1) { + throw std::runtime_error("nitem must be 1 for a scalar"); + } + T* t = static_cast(buffer.buf); + *scalar = static_cast(*t); +} + template inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) { if (sizeof(T) != buffer.itemsize) { @@ -361,21 +379,33 @@ inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) { *scalar = *t; } +template +inline const char* +Tree::GetFormatStringForNode() { + if (std::is_same::value) { + return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}"; + } else { + return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}"; + } +} + constexpr size_t kNumFramePerTree = 6; +template inline void -Tree::GetPyBuffer(std::vector* dest) { +Tree::GetPyBuffer(std::vector* dest) { dest->push_back(GetPyBufferFromScalar(&num_nodes)); - dest->push_back(GetPyBufferFromArray(&nodes_, "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=H}")); + dest->push_back(GetPyBufferFromArray(&nodes_, GetFormatStringForNode())); dest->push_back(GetPyBufferFromArray(&leaf_vector_)); dest->push_back(GetPyBufferFromArray(&leaf_vector_offset_)); dest->push_back(GetPyBufferFromArray(&left_categories_)); dest->push_back(GetPyBufferFromArray(&left_categories_offset_)); } +template inline void -Tree::InitFromPyBuffer(std::vector::iterator begin, - std::vector::iterator end) { +Tree::InitFromPyBuffer( + std::vector::iterator begin, std::vector::iterator end) { if (std::distance(begin, end) != kNumFramePerTree) { throw std::runtime_error("Wrong number of frames specified"); } @@ -390,11 +420,12 @@ Tree::InitFromPyBuffer(std::vector::iterator begin, InitArrayFromPyBuffer(&left_categories_offset_, *begin++); } -inline void Tree::Node::Init() { +template +inline void Tree::Node::Init() { cleft_ = cright_ = -1; sindex_ = 0; - info_.leaf_value = 0.0f; - info_.threshold = 0.0f; + info_.leaf_value = static_cast(0); + info_.threshold = static_cast(0); data_count_ = 0; sum_hess_ = gain_ = 0.0; missing_category_to_zero_ = false; @@ -403,8 +434,9 @@ inline void Tree::Node::Init() { cmp_ = Operator::kNone; } +template inline int -Tree::AllocNode() { +Tree::AllocNode() { int nd = num_nodes++; if (nodes_.Size() != static_cast(nd)) { throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes"); @@ -418,8 +450,9 @@ Tree::AllocNode() { return nd; } +template inline void -Tree::Init() { +Tree::Init() { num_nodes = 1; leaf_vector_.Clear(); leaf_vector_offset_.Resize(2, 0); @@ 
-427,19 +460,21 @@ Tree::Init() { left_categories_offset_.Resize(2, 0); nodes_.Resize(1); nodes_[0].Init(); - SetLeaf(0, 0.0f); + SetLeaf(0, static_cast(0)); } +template inline void -Tree::AddChilds(int nid) { +Tree::AddChilds(int nid) { const int cleft = this->AllocNode(); const int cright = this->AllocNode(); nodes_[nid].cleft_ = cleft; nodes_[nid].cright_ = cright; } +template inline std::vector -Tree::GetCategoricalFeatures() const { +Tree::GetCategoricalFeatures() const { std::unordered_map tmp; for (int nid = 0; nid < num_nodes; ++nid) { const SplitFeatureType type = SplitType(nid); @@ -466,9 +501,10 @@ Tree::GetCategoricalFeatures() const { return result; } +template inline void -Tree::SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, - bool default_left, Operator cmp) { +Tree::SetNumericalSplit( + int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) { Node& node = nodes_[nid]; if (split_index >= ((1U << 31U) - 1)) { throw std::runtime_error("split_index too big"); @@ -480,10 +516,11 @@ Tree::SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, node.split_type_ = SplitFeatureType::kNumerical; } +template inline void -Tree::SetCategoricalSplit(int nid, unsigned split_index, bool default_left, - bool missing_category_to_zero, - const std::vector& node_left_categories) { +Tree::SetCategoricalSplit( + int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, + const std::vector& node_left_categories) { if (split_index >= ((1U << 31U) - 1)) { throw std::runtime_error("split_index too big"); } @@ -513,8 +550,9 @@ Tree::SetCategoricalSplit(int nid, unsigned split_index, bool default_left, node.missing_category_to_zero_ = missing_category_to_zero; } +template inline void -Tree::SetLeaf(int nid, tl_float value) { +Tree::SetLeaf(int nid, LeafOutputType value) { Node& node = nodes_[nid]; (node.info_).leaf_value = value; node.cleft_ = -1; @@ -522,8 +560,10 @@ Tree::SetLeaf(int nid, tl_float value) { node.split_type_ = SplitFeatureType::kNone; } +template inline void -Tree::SetLeafVector(int nid, const std::vector& node_leaf_vector) { +Tree::SetLeafVector( + int nid, const std::vector& node_leaf_vector) { const size_t end_oft = leaf_vector_offset_.Back(); const size_t new_end_oft = end_oft + node_leaf_vector.size(); if (end_oft != leaf_vector_.Size()) { @@ -547,28 +587,82 @@ Tree::SetLeafVector(int nid, const std::vector& node_leaf_vector) { node.split_type_ = SplitFeatureType::kNone; } +template inline std::unique_ptr Model::Create() { - std::unique_ptr model = std::make_unique(); + std::unique_ptr model = std::make_unique>(); + model->threshold_type_ = TypeToInfo(); + model->leaf_output_type_ = TypeToInfo(); return model; } +template +class ModelCreateImpl { + public: + inline static std::unique_ptr Dispatch() { + return Model::Create(); + } +}; + +inline std::unique_ptr +Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { + return DispatchWithModelTypes(threshold_type, leaf_output_type); +} + +template +class ModelDispatchImpl { + public: + template + inline static auto Dispatch(Model* model, Func func) { + return func(*dynamic_cast*>(model)); + } + + template + inline static auto Dispatch(const Model* model, Func func) { + return func(*dynamic_cast*>(model)); + } +}; + +template +inline auto +Model::Dispatch(Func func) { + return DispatchWithModelTypes(threshold_type_, leaf_output_type_, this, func); +} + +template +inline auto +Model::Dispatch(Func func) const { + return 
DispatchWithModelTypes(threshold_type_, leaf_output_type_, this, func); +} + inline std::vector Model::GetPyBuffer() { std::vector buffer; + buffer.push_back(GetPyBufferFromScalar(&threshold_type_)); + buffer.push_back(GetPyBufferFromScalar(&leaf_output_type_)); this->GetPyBuffer(&buffer); return buffer; } inline std::unique_ptr Model::CreateFromPyBuffer(std::vector frames) { - std::unique_ptr model = Model::Create(); - model->InitFromPyBuffer(frames.begin(), frames.end()); + using TypeInfoInt = std::underlying_type::type; + TypeInfo threshold_type, leaf_output_type; + if (frames.size() < 2) { + throw std::runtime_error("Insufficient number of frames: there must be at least two"); + } + InitScalarFromPyBuffer(&threshold_type, frames[0]); + InitScalarFromPyBuffer(&leaf_output_type, frames[1]); + + std::unique_ptr model = Model::Create(threshold_type, leaf_output_type); + model->InitFromPyBuffer(frames.begin() + 2, frames.end()); return model; } + +template inline void -ModelImpl::GetPyBuffer(std::vector* dest) { +ModelImpl::GetPyBuffer(std::vector* dest) { /* Header */ dest->push_back(GetPyBufferFromScalar(&num_feature)); dest->push_back(GetPyBufferFromScalar(&num_output_group)); @@ -577,13 +671,14 @@ ModelImpl::GetPyBuffer(std::vector* dest) { ¶m, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}")); /* Body */ - for (Tree& tree : trees) { + for (Tree& tree : trees) { tree.GetPyBuffer(dest); } } +template inline void -ModelImpl::InitFromPyBuffer( +ModelImpl::InitFromPyBuffer( std::vector::iterator begin, std::vector::iterator end) { const size_t num_frame = std::distance(begin, end); /* Header */ diff --git a/include/treelite/typeinfo.h b/include/treelite/typeinfo.h index 7b1337ea..01ea7cfa 100644 --- a/include/treelite/typeinfo.h +++ b/include/treelite/typeinfo.h @@ -58,7 +58,7 @@ inline std::string TypeInfoToString(treelite::TypeInfo type) { * \return TypeInfo corresponding to the template type arg */ template -inline TypeInfo InferTypeInfoOf() { +inline TypeInfo TypeToInfo() { if (std::is_same::value) { return TypeInfo::kUInt32; } else if (std::is_same::value) {