Skip to content

Commit

Permalink
Create template classes for model representations to support multiple…
Browse files Browse the repository at this point in the history
… data types (Part of #196) (#198)

* Create template classes for model representations to support multiple data types

* Update include/treelite/tree.h

Co-authored-by: William Hicks <[email protected]>

* Address review comments from @canonizer

* Address more comments from @canonizer

Co-authored-by: William Hicks <[email protected]>
Co-authored-by: William Hicks <[email protected]>
Co-authored-by: Andy Adinets <[email protected]>
  • Loading branch information
3 people authored Sep 25, 2020
1 parent dc18811 commit 269dd2e
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 42 deletions.
7 changes: 5 additions & 2 deletions include/treelite/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
#define TREELITE_BASE_H_

#include <cstdint>
#include <typeinfo>
#include <string>
#include <unordered_map>
#include <stdexcept>
#include "./typeinfo.h"

namespace treelite {

Expand All @@ -33,7 +35,7 @@ enum class Operator : int8_t {
extern const std::unordered_map<std::string, Operator> optable;

/*!
* \brief get string representation of comparsion operator
* \brief get string representation of comparison operator
* \param op comparison operator
* \return string representation
*/
Expand All @@ -56,7 +58,8 @@ inline std::string OpName(Operator op) {
* \param rhs float on the right hand side
* \return whether [lhs] [op] [rhs] is true or not
*/
inline bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs) {
template <typename ElementType, typename ThresholdType>
inline bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs) {
switch (op) {
case Operator::kEQ: return lhs == rhs;
case Operator::kLT: return lhs < rhs;
Expand Down
58 changes: 44 additions & 14 deletions include/treelite/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ContiguousArray {
};

/*! \brief in-memory representation of a decision tree */
template <typename ThresholdType, typename LeafOutputType>
class Tree {
public:
/*! \brief tree node */
Expand All @@ -87,8 +88,8 @@ class Tree {
inline void Init();
/*! \brief store either leaf value or decision threshold */
union Info {
tl_float leaf_value; // for leaf nodes
tl_float threshold; // for non-leaf nodes
LeafOutputType leaf_value; // for leaf nodes
ThresholdType threshold; // for non-leaf nodes
};
/*! \brief pointer to left and right children */
int32_t cleft_, cright_;
Expand Down Expand Up @@ -137,7 +138,19 @@ class Tree {
};

static_assert(std::is_pod<Node>::value, "Node must be a POD type");
static_assert(sizeof(Node) == 48, "Node must be 48 bytes");
static_assert(std::is_same<ThresholdType, float>::value
|| std::is_same<ThresholdType, double>::value,
"ThresholdType must be either float32 or float64");
static_assert(std::is_same<LeafOutputType, uint32_t>::value
|| std::is_same<LeafOutputType, float>::value
|| std::is_same<LeafOutputType, double>::value,
"LeafOutputType must be one of uint32_t, float32 or float64");
static_assert(std::is_same<ThresholdType, LeafOutputType>::value
|| std::is_same<LeafOutputType, uint32_t>::value,
"Unsupported combination of ThresholdType and LeafOutputType");
static_assert((std::is_same<ThresholdType, float>::value && sizeof(Node) == 48)
|| (std::is_same<ThresholdType, double>::value && sizeof(Node) == 56),
"Node size incorrect");

Tree() = default;
~Tree() = default;
Expand All @@ -146,14 +159,15 @@ class Tree {
Tree(Tree&&) noexcept = default;
Tree& operator=(Tree&&) noexcept = default;

inline const char* GetFormatStringForNode();
inline void GetPyBuffer(std::vector<PyBufferFrame>* dest);
inline void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
std::vector<PyBufferFrame>::iterator end);

private:
// vector of nodes
ContiguousArray<Node> nodes_;
ContiguousArray<tl_float> leaf_vector_;
ContiguousArray<LeafOutputType> leaf_vector_;
ContiguousArray<size_t> leaf_vector_offset_;
ContiguousArray<uint32_t> left_categories_;
ContiguousArray<size_t> left_categories_offset_;
Expand Down Expand Up @@ -225,19 +239,19 @@ class Tree {
* \brief get leaf value of the leaf node
* \param nid ID of node being queried
*/
inline tl_float LeafValue(int nid) const {
inline LeafOutputType LeafValue(int nid) const {
return (nodes_[nid].info_).leaf_value;
}
/*!
* \brief get leaf vector of the leaf node; useful for multi-class random forest classifier
* \param nid ID of node being queried
*/
inline std::vector<tl_float> LeafVector(int nid) const {
inline std::vector<LeafOutputType> LeafVector(int nid) const {
if (nid > leaf_vector_offset_.Size()) {
throw std::runtime_error("nid too large");
}
return std::vector<tl_float>(&leaf_vector_[leaf_vector_offset_[nid]],
&leaf_vector_[leaf_vector_offset_[nid + 1]]);
return std::vector<LeafOutputType>(&leaf_vector_[leaf_vector_offset_[nid]],
&leaf_vector_[leaf_vector_offset_[nid + 1]]);
}
/*!
* \brief tests whether the leaf node has a non-empty leaf vector
Expand All @@ -253,7 +267,7 @@ class Tree {
* \brief get threshold of the node
* \param nid ID of node being queried
*/
inline tl_float Threshold(int nid) const {
inline ThresholdType Threshold(int nid) const {
return (nodes_[nid].info_).threshold;
}
/*!
Expand Down Expand Up @@ -346,7 +360,7 @@ class Tree {
* \param cmp comparison operator to compare between feature value and
* threshold
*/
inline void SetNumericalSplit(int nid, unsigned split_index, tl_float threshold,
inline void SetNumericalSplit(int nid, unsigned split_index, ThresholdType threshold,
bool default_left, Operator cmp);
/*!
* \brief create a categorical split
Expand All @@ -365,13 +379,13 @@ class Tree {
* \param nid ID of node being updated
* \param value leaf value
*/
inline void SetLeaf(int nid, tl_float value);
inline void SetLeaf(int nid, LeafOutputType value);
/*!
* \brief set the leaf vector of the node; useful for multi-class random forest classifier
* \param nid ID of node being updated
* \param leaf_vector leaf vector
*/
inline void SetLeafVector(int nid, const std::vector<tl_float>& leaf_vector);
inline void SetLeafVector(int nid, const std::vector<LeafOutputType>& leaf_vector);
/*!
* \brief set the hessian sum of the node
* \param nid ID of node being updated
Expand Down Expand Up @@ -474,12 +488,25 @@ class Model {
/*! \brief disable copy; use default move */
Model() = default;
virtual ~Model() = default;
inline static std::unique_ptr<Model> Create();
Model(const Model&) = delete;
Model& operator=(const Model&) = delete;
Model(Model&&) = default;
Model& operator=(Model&&) = default;

template <typename ThresholdType, typename LeafOutputType>
inline static std::unique_ptr<Model> Create();
inline static std::unique_ptr<Model> Create(TypeInfo threshold_type, TypeInfo leaf_output_type);
inline TypeInfo GetThresholdType() const {
return threshold_type_;
}
inline TypeInfo GetLeafOutputType() const {
return leaf_output_type_;
}
template <typename Func>
inline auto Dispatch(Func func);
template <typename Func>
inline auto Dispatch(Func func) const;

virtual size_t GetNumTree() const = 0;
virtual void SetTreeLimit(size_t limit) = 0;
virtual void ReferenceSerialize(dmlc::Stream* fo) const = 0;
Expand All @@ -502,16 +529,19 @@ class Model {
ModelParam param;

private:
TypeInfo threshold_type_;
TypeInfo leaf_output_type_;
// Internal functions for serialization
virtual void GetPyBuffer(std::vector<PyBufferFrame>* dest) = 0;
virtual void InitFromPyBuffer(std::vector<PyBufferFrame>::iterator begin,
std::vector<PyBufferFrame>::iterator end) = 0;
};

template <typename ThresholdType, typename LeafOutputType>
class ModelImpl : public Model {
public:
/*! \brief member trees */
std::vector<Tree> trees;
std::vector<Tree<ThresholdType, LeafOutputType>> trees;

/*! \brief disable copy; use default move */
ModelImpl() = default;
Expand Down
Loading

0 comments on commit 269dd2e

Please sign in to comment.