From 269dd2ec4342f846762165b9cd071170f6cde9b5 Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Fri, 25 Sep 2020 01:04:03 -0700 Subject: [PATCH] Create template classes for model representations to support multiple data types (Part of #196) (#198) * Create template classes for model representations to support multiple data types * Update include/treelite/tree.h Co-authored-by: William Hicks * Address review comments from @canonizer * Address more comments from @canonizer Co-authored-by: William Hicks Co-authored-by: William Hicks Co-authored-by: Andy Adinets --- include/treelite/base.h | 7 +- include/treelite/tree.h | 58 ++++++++++---- include/treelite/tree_impl.h | 145 +++++++++++++++++++++++++++++------ include/treelite/typeinfo.h | 2 +- 4 files changed, 170 insertions(+), 42 deletions(-) diff --git a/include/treelite/base.h b/include/treelite/base.h index f9f352a0..f9f0e72d 100644 --- a/include/treelite/base.h +++ b/include/treelite/base.h @@ -8,9 +8,11 @@ #define TREELITE_BASE_H_ #include +#include #include #include #include +#include "./typeinfo.h" namespace treelite { @@ -33,7 +35,7 @@ enum class Operator : int8_t { extern const std::unordered_map optable; /*! - * \brief get string representation of comparsion operator + * \brief get string representation of comparison operator * \param op comparison operator * \return string representation */ @@ -56,7 +58,8 @@ inline std::string OpName(Operator op) { * \param rhs float on the right hand side * \return whether [lhs] [op] [rhs] is true or not */ -inline bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs) { +template +inline bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs) { switch (op) { case Operator::kEQ: return lhs == rhs; case Operator::kLT: return lhs < rhs; diff --git a/include/treelite/tree.h b/include/treelite/tree.h index ba1d89c0..26999d8f 100644 --- a/include/treelite/tree.h +++ b/include/treelite/tree.h @@ -78,6 +78,7 @@ class ContiguousArray { }; /*! \brief in-memory representation of a decision tree */ +template class Tree { public: /*! \brief tree node */ @@ -87,8 +88,8 @@ class Tree { inline void Init(); /*! \brief store either leaf value or decision threshold */ union Info { - tl_float leaf_value; // for leaf nodes - tl_float threshold; // for non-leaf nodes + LeafOutputType leaf_value; // for leaf nodes + ThresholdType threshold; // for non-leaf nodes }; /*! \brief pointer to left and right children */ int32_t cleft_, cright_; @@ -137,7 +138,19 @@ class Tree { }; static_assert(std::is_pod::value, "Node must be a POD type"); - static_assert(sizeof(Node) == 48, "Node must be 48 bytes"); + static_assert(std::is_same::value + || std::is_same::value, + "ThresholdType must be either float32 or float64"); + static_assert(std::is_same::value + || std::is_same::value + || std::is_same::value, + "LeafOutputType must be one of uint32_t, float32 or float64"); + static_assert(std::is_same::value + || std::is_same::value, + "Unsupported combination of ThresholdType and LeafOutputType"); + static_assert((std::is_same::value && sizeof(Node) == 48) + || (std::is_same::value && sizeof(Node) == 56), + "Node size incorrect"); Tree() = default; ~Tree() = default; @@ -146,6 +159,7 @@ class Tree { Tree(Tree&&) noexcept = default; Tree& operator=(Tree&&) noexcept = default; + inline const char* GetFormatStringForNode(); inline void GetPyBuffer(std::vector* dest); inline void InitFromPyBuffer(std::vector::iterator begin, std::vector::iterator end); @@ -153,7 +167,7 @@ class Tree { private: // vector of nodes ContiguousArray nodes_; - ContiguousArray leaf_vector_; + ContiguousArray leaf_vector_; ContiguousArray leaf_vector_offset_; ContiguousArray left_categories_; ContiguousArray left_categories_offset_; @@ -225,19 +239,19 @@ class Tree { * \brief get leaf value of the leaf node * \param nid ID of node being queried */ - inline tl_float LeafValue(int nid) const { + inline LeafOutputType LeafValue(int nid) const { return (nodes_[nid].info_).leaf_value; } /*! * \brief get leaf vector of the leaf node; useful for multi-class random forest classifier * \param nid ID of node being queried */ - inline std::vector LeafVector(int nid) const { + inline std::vector LeafVector(int nid) const { if (nid > leaf_vector_offset_.Size()) { throw std::runtime_error("nid too large"); } - return std::vector(&leaf_vector_[leaf_vector_offset_[nid]], - &leaf_vector_[leaf_vector_offset_[nid + 1]]); + return std::vector(&leaf_vector_[leaf_vector_offset_[nid]], + &leaf_vector_[leaf_vector_offset_[nid + 1]]); } /*! * \brief tests whether the leaf node has a non-empty leaf vector @@ -253,7 +267,7 @@ class Tree { * \brief get threshold of the node * \param nid ID of node being queried */ - inline tl_float Threshold(int nid) const { + inline ThresholdType Threshold(int nid) const { return (nodes_[nid].info_).threshold; } /*! @@ -346,7 +360,7 @@ class Tree { * \param cmp comparison operator to compare between feature value and * threshold */ - inline void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, + inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp); /*! * \brief create a categorical split @@ -365,13 +379,13 @@ class Tree { * \param nid ID of node being updated * \param value leaf value */ - inline void SetLeaf(int nid, tl_float value); + inline void SetLeaf(int nid, LeafOutputType value); /*! * \brief set the leaf vector of the node; useful for multi-class random forest classifier * \param nid ID of node being updated * \param leaf_vector leaf vector */ - inline void SetLeafVector(int nid, const std::vector& leaf_vector); + inline void SetLeafVector(int nid, const std::vector& leaf_vector); /*! * \brief set the hessian sum of the node * \param nid ID of node being updated @@ -474,12 +488,25 @@ class Model { /*! \brief disable copy; use default move */ Model() = default; virtual ~Model() = default; - inline static std::unique_ptr Create(); Model(const Model&) = delete; Model& operator=(const Model&) = delete; Model(Model&&) = default; Model& operator=(Model&&) = default; + template + inline static std::unique_ptr Create(); + inline static std::unique_ptr Create(TypeInfo threshold_type, TypeInfo leaf_output_type); + inline TypeInfo GetThresholdType() const { + return threshold_type_; + } + inline TypeInfo GetLeafOutputType() const { + return leaf_output_type_; + } + template + inline auto Dispatch(Func func); + template + inline auto Dispatch(Func func) const; + virtual size_t GetNumTree() const = 0; virtual void SetTreeLimit(size_t limit) = 0; virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0; @@ -502,16 +529,19 @@ class Model { ModelParam param; private: + TypeInfo threshold_type_; + TypeInfo leaf_output_type_; // Internal functions for serialization virtual void GetPyBuffer(std::vector* dest) = 0; virtual void InitFromPyBuffer(std::vector::iterator begin, std::vector::iterator end) = 0; }; +template class ModelImpl : public Model { public: /*! \brief member trees */ - std::vector trees; + std::vector> trees; /*! \brief disable copy; use default move */ ModelImpl() = default; diff --git a/include/treelite/tree_impl.h b/include/treelite/tree_impl.h index 85ac40dc..ab348885 100644 --- a/include/treelite/tree_impl.h +++ b/include/treelite/tree_impl.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -333,6 +334,11 @@ inline PyBufferFrame GetPyBufferFromScalar(T* scalar, const char* format) { return GetPyBufferFromScalar(static_cast(scalar), format, sizeof(T)); } +inline PyBufferFrame GetPyBufferFromScalar(TypeInfo* scalar) { + using T = std::underlying_type::type; + return GetPyBufferFromScalar(reinterpret_cast(scalar), InferFormatString()); +} + template inline PyBufferFrame GetPyBufferFromScalar(T* scalar) { static_assert(std::is_arithmetic::value, @@ -349,6 +355,18 @@ inline void InitArrayFromPyBuffer(ContiguousArray* vec, PyBufferFrame buffer) vec->UseForeignBuffer(buffer.buf, buffer.nitem); } +inline void InitScalarFromPyBuffer(TypeInfo* scalar, PyBufferFrame buffer) { + using T = std::underlying_type::type; + if (sizeof(T) != buffer.itemsize) { + throw std::runtime_error("Incorrect itemsize"); + } + if (buffer.nitem != 1) { + throw std::runtime_error("nitem must be 1 for a scalar"); + } + T* t = static_cast(buffer.buf); + *scalar = static_cast(*t); +} + template inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) { if (sizeof(T) != buffer.itemsize) { @@ -361,21 +379,33 @@ inline void InitScalarFromPyBuffer(T* scalar, PyBufferFrame buffer) { *scalar = *t; } +template +inline const char* +Tree::GetFormatStringForNode() { + if (std::is_same::value) { + return "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?xx}"; + } else { + return "T{=l=l=Lxxxx=d=Q=d=d=b=b=?=?=?=?xx}"; + } +} + constexpr size_t kNumFramePerTree = 6; +template inline void -Tree::GetPyBuffer(std::vector* dest) { +Tree::GetPyBuffer(std::vector* dest) { dest->push_back(GetPyBufferFromScalar(&num_nodes)); - dest->push_back(GetPyBufferFromArray(&nodes_, "T{=l=l=L=f=Q=d=d=b=b=?=?=?=?=H}")); + dest->push_back(GetPyBufferFromArray(&nodes_, GetFormatStringForNode())); dest->push_back(GetPyBufferFromArray(&leaf_vector_)); dest->push_back(GetPyBufferFromArray(&leaf_vector_offset_)); dest->push_back(GetPyBufferFromArray(&left_categories_)); dest->push_back(GetPyBufferFromArray(&left_categories_offset_)); } +template inline void -Tree::InitFromPyBuffer(std::vector::iterator begin, - std::vector::iterator end) { +Tree::InitFromPyBuffer( + std::vector::iterator begin, std::vector::iterator end) { if (std::distance(begin, end) != kNumFramePerTree) { throw std::runtime_error("Wrong number of frames specified"); } @@ -390,11 +420,12 @@ Tree::InitFromPyBuffer(std::vector::iterator begin, InitArrayFromPyBuffer(&left_categories_offset_, *begin++); } -inline void Tree::Node::Init() { +template +inline void Tree::Node::Init() { cleft_ = cright_ = -1; sindex_ = 0; - info_.leaf_value = 0.0f; - info_.threshold = 0.0f; + info_.leaf_value = static_cast(0); + info_.threshold = static_cast(0); data_count_ = 0; sum_hess_ = gain_ = 0.0; missing_category_to_zero_ = false; @@ -403,8 +434,9 @@ inline void Tree::Node::Init() { cmp_ = Operator::kNone; } +template inline int -Tree::AllocNode() { +Tree::AllocNode() { int nd = num_nodes++; if (nodes_.Size() != static_cast(nd)) { throw std::runtime_error("Invariant violated: nodes_ contains incorrect number of nodes"); @@ -418,8 +450,9 @@ Tree::AllocNode() { return nd; } +template inline void -Tree::Init() { +Tree::Init() { num_nodes = 1; leaf_vector_.Clear(); leaf_vector_offset_.Resize(2, 0); @@ -427,19 +460,21 @@ Tree::Init() { left_categories_offset_.Resize(2, 0); nodes_.Resize(1); nodes_[0].Init(); - SetLeaf(0, 0.0f); + SetLeaf(0, static_cast(0)); } +template inline void -Tree::AddChilds(int nid) { +Tree::AddChilds(int nid) { const int cleft = this->AllocNode(); const int cright = this->AllocNode(); nodes_[nid].cleft_ = cleft; nodes_[nid].cright_ = cright; } +template inline std::vector -Tree::GetCategoricalFeatures() const { +Tree::GetCategoricalFeatures() const { std::unordered_map tmp; for (int nid = 0; nid < num_nodes; ++nid) { const SplitFeatureType type = SplitType(nid); @@ -466,9 +501,10 @@ Tree::GetCategoricalFeatures() const { return result; } +template inline void -Tree::SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, - bool default_left, Operator cmp) { +Tree::SetNumericalSplit( + int nid, unsigned split_index, ThresholdType threshold, bool default_left, Operator cmp) { Node& node = nodes_[nid]; if (split_index >= ((1U << 31U) - 1)) { throw std::runtime_error("split_index too big"); @@ -480,10 +516,11 @@ Tree::SetNumericalSplit(int nid, unsigned split_index, tl_float threshold, node.split_type_ = SplitFeatureType::kNumerical; } +template inline void -Tree::SetCategoricalSplit(int nid, unsigned split_index, bool default_left, - bool missing_category_to_zero, - const std::vector& node_left_categories) { +Tree::SetCategoricalSplit( + int nid, unsigned split_index, bool default_left, bool missing_category_to_zero, + const std::vector& node_left_categories) { if (split_index >= ((1U << 31U) - 1)) { throw std::runtime_error("split_index too big"); } @@ -513,8 +550,9 @@ Tree::SetCategoricalSplit(int nid, unsigned split_index, bool default_left, node.missing_category_to_zero_ = missing_category_to_zero; } +template inline void -Tree::SetLeaf(int nid, tl_float value) { +Tree::SetLeaf(int nid, LeafOutputType value) { Node& node = nodes_[nid]; (node.info_).leaf_value = value; node.cleft_ = -1; @@ -522,8 +560,10 @@ Tree::SetLeaf(int nid, tl_float value) { node.split_type_ = SplitFeatureType::kNone; } +template inline void -Tree::SetLeafVector(int nid, const std::vector& node_leaf_vector) { +Tree::SetLeafVector( + int nid, const std::vector& node_leaf_vector) { const size_t end_oft = leaf_vector_offset_.Back(); const size_t new_end_oft = end_oft + node_leaf_vector.size(); if (end_oft != leaf_vector_.Size()) { @@ -547,28 +587,82 @@ Tree::SetLeafVector(int nid, const std::vector& node_leaf_vector) { node.split_type_ = SplitFeatureType::kNone; } +template inline std::unique_ptr Model::Create() { - std::unique_ptr model = std::make_unique(); + std::unique_ptr model = std::make_unique>(); + model->threshold_type_ = TypeToInfo(); + model->leaf_output_type_ = TypeToInfo(); return model; } +template +class ModelCreateImpl { + public: + inline static std::unique_ptr Dispatch() { + return Model::Create(); + } +}; + +inline std::unique_ptr +Model::Create(TypeInfo threshold_type, TypeInfo leaf_output_type) { + return DispatchWithModelTypes(threshold_type, leaf_output_type); +} + +template +class ModelDispatchImpl { + public: + template + inline static auto Dispatch(Model* model, Func func) { + return func(*dynamic_cast*>(model)); + } + + template + inline static auto Dispatch(const Model* model, Func func) { + return func(*dynamic_cast*>(model)); + } +}; + +template +inline auto +Model::Dispatch(Func func) { + return DispatchWithModelTypes(threshold_type_, leaf_output_type_, this, func); +} + +template +inline auto +Model::Dispatch(Func func) const { + return DispatchWithModelTypes(threshold_type_, leaf_output_type_, this, func); +} + inline std::vector Model::GetPyBuffer() { std::vector buffer; + buffer.push_back(GetPyBufferFromScalar(&threshold_type_)); + buffer.push_back(GetPyBufferFromScalar(&leaf_output_type_)); this->GetPyBuffer(&buffer); return buffer; } inline std::unique_ptr Model::CreateFromPyBuffer(std::vector frames) { - std::unique_ptr model = Model::Create(); - model->InitFromPyBuffer(frames.begin(), frames.end()); + using TypeInfoInt = std::underlying_type::type; + TypeInfo threshold_type, leaf_output_type; + if (frames.size() < 2) { + throw std::runtime_error("Insufficient number of frames: there must be at least two"); + } + InitScalarFromPyBuffer(&threshold_type, frames[0]); + InitScalarFromPyBuffer(&leaf_output_type, frames[1]); + + std::unique_ptr model = Model::Create(threshold_type, leaf_output_type); + model->InitFromPyBuffer(frames.begin() + 2, frames.end()); return model; } + +template inline void -ModelImpl::GetPyBuffer(std::vector* dest) { +ModelImpl::GetPyBuffer(std::vector* dest) { /* Header */ dest->push_back(GetPyBufferFromScalar(&num_feature)); dest->push_back(GetPyBufferFromScalar(&num_output_group)); @@ -577,13 +671,14 @@ ModelImpl::GetPyBuffer(std::vector* dest) { ¶m, "T{" _TREELITE_STR(TREELITE_MAX_PRED_TRANSFORM_LENGTH) "s=f=f}")); /* Body */ - for (Tree& tree : trees) { + for (Tree& tree : trees) { tree.GetPyBuffer(dest); } } +template inline void -ModelImpl::InitFromPyBuffer( +ModelImpl::InitFromPyBuffer( std::vector::iterator begin, std::vector::iterator end) { const size_t num_frame = std::distance(begin, end); /* Header */ diff --git a/include/treelite/typeinfo.h b/include/treelite/typeinfo.h index 7b1337ea..01ea7cfa 100644 --- a/include/treelite/typeinfo.h +++ b/include/treelite/typeinfo.h @@ -58,7 +58,7 @@ inline std::string TypeInfoToString(treelite::TypeInfo type) { * \return TypeInfo corresponding to the template type arg */ template -inline TypeInfo InferTypeInfoOf() { +inline TypeInfo TypeToInfo() { if (std::is_same::value) { return TypeInfo::kUInt32; } else if (std::is_same::value) {