Add support for multiple data types (will break up into smaller pull requests) #196

Closed
wants to merge 41 commits into from
2e823eb
Make Model and Tree classes into template classes
hcho3 Aug 29, 2020
edcd1f4
Add back size check for Node class; define format string for Node cla…
hcho3 Aug 30, 2020
c0d1abd
Update ModelBuilder interface + zero-copy serializer
hcho3 Aug 31, 2020
acc20be
Update C interface for model builder API
hcho3 Aug 31, 2020
34c2231
Fix lint checks
hcho3 Aug 31, 2020
86104d6
Add a missing header
hcho3 Aug 31, 2020
77477e7
Perform round-trip serialization test with all possible types
hcho3 Aug 31, 2020
b846f6a
Update native code template
hcho3 Aug 31, 2020
0f88e96
Move TypeInfo to a separate header + objtreelite_common
hcho3 Sep 1, 2020
81e9c7c
Move data.cc to objtreelite_common; start new DMatrix design
hcho3 Sep 1, 2020
e4bb3cb
New DMatrix design
hcho3 Sep 1, 2020
f64ddc8
Use new DMatrix in branch annotator; move DMatrix to objtreelite_common
hcho3 Sep 2, 2020
1c74cc5
Make ModelImpl a derived class of Model, simplifying dispatch
hcho3 Sep 2, 2020
e8f874f
Add functions to query data types used in a compiled predictor
hcho3 Sep 2, 2020
71d43aa
Remove single-instance prediction feature from the runtime
hcho3 Sep 2, 2020
914469b
Refactor the runtime API to use DMatrix
hcho3 Sep 3, 2020
13ee2b6
Get model builder and runtime API working end-to-end
hcho3 Sep 3, 2020
7aef6db
Emit correct data type for leaf outputs
hcho3 Sep 4, 2020
f2715d9
Remove Model Type; consolidate type dispatching logic
hcho3 Sep 4, 2020
03498f0
Fix all scikit-learn tests: input must be float32 + thresholds are fl…
hcho3 Sep 4, 2020
88080ec
Fix lint
hcho3 Sep 4, 2020
7676edf
Expose type info as property of Predictor
hcho3 Sep 4, 2020
30ef99a
Fix build with CentOS 6 + GCC 5
hcho3 Sep 4, 2020
d0659f2
Move data/data.cc -> data.cc
hcho3 Sep 4, 2020
d2d4e3f
Fix build on MSVC
hcho3 Sep 4, 2020
c1119ff
Define _CRT_SECURE_NO_WARNINGS to remove unneeded warnings in MSVC
hcho3 Sep 4, 2020
0056ebe
Fix MSVC warnings
hcho3 Sep 4, 2020
40225be
In CMake pkg test, specify x64 for MSVC
hcho3 Sep 4, 2020
53bdf64
Fix lint
hcho3 Sep 4, 2020
01a6a9c
Fix Java/Scala runtime
hcho3 Sep 5, 2020
1827bdb
Fix lint
hcho3 Sep 5, 2020
f611533
Fix cmake import example
hcho3 Sep 5, 2020
973b4a6
Remove 'float32' suffix in Java methods
hcho3 Sep 5, 2020
b1e1e4b
Remove a void* handle that's unused
hcho3 Sep 10, 2020
e149dc4
Merge remote-tracking branch 'origin/mainline' into multi_type_support2
hcho3 Sep 10, 2020
6421fa2
New prediction runtime C API, to support multiple data types (Part of…
hcho3 Sep 15, 2020
dc18811
Refactor struct Model -> class Model + class ModelImpl (Part of #196)…
hcho3 Sep 17, 2020
fdcb519
Merge remote-tracking branch 'origin/multi_type_refactor_breakup' int…
hcho3 Sep 18, 2020
8138e13
Inline GetThresholdType() and GetLeafOutputType()
hcho3 Sep 18, 2020
269dd2e
Create template classes for model representations to support multiple…
hcho3 Sep 25, 2020
d9173df
Merge remote-tracking branch 'origin/multi_type_refactor_breakup' int…
hcho3 Sep 25, 2020
2 changes: 2 additions & 0 deletions CMakeLists.txt
Expand Up @@ -2,6 +2,8 @@ set (CMAKE_FIND_NO_INSTALL_PREFIX TRUE FORCE)
cmake_minimum_required (VERSION 3.14)
project(treelite LANGUAGES CXX C VERSION 0.93)

set(CMAKE_CXX_STANDARD 14)

# check MSVC version
if(MSVC)
if(MSVC_VERSION LESS 1910)
Expand Down
6 changes: 6 additions & 0 deletions cmake/ExternalLibs.cmake
Expand Up @@ -6,6 +6,12 @@ FetchContent_Declare(
GIT_TAG v0.4
)
FetchContent_MakeAvailable(dmlccore)
target_compile_options(dmlc PRIVATE
-D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE)
if (TARGET dmlc_unit_tests)
target_compile_options(dmlc_unit_tests PRIVATE
-D_CRT_SECURE_NO_WARNINGS -D_CRT_SECURE_NO_DEPRECATE)
endif (TARGET dmlc_unit_tests)

FetchContent_Declare(
fmtlib
Expand Down
3 changes: 1 addition & 2 deletions include/treelite/annotator.h
Expand Up @@ -24,8 +24,7 @@ class BranchAnnotator {
* \param nthread number of threads to use
* \param verbose whether to produce extra messages
*/
void Annotate(const Model& model, const DMatrix* dmat,
int nthread, int verbose);
void Annotate(const Model& model, const DMatrix* dmat, int nthread, int verbose);
/*!
* \brief load branch annotation from a JSON file
* \param fi input stream
Expand Down
10 changes: 7 additions & 3 deletions include/treelite/base.h
Expand Up @@ -8,9 +8,11 @@
#define TREELITE_BASE_H_

#include <cstdint>
#include <typeinfo>
#include <string>
#include <unordered_map>
#include <stdexcept>
#include "./typeinfo.h"

namespace treelite {

Expand All @@ -29,11 +31,12 @@ enum class Operator : int8_t {
kGT, /*!< operator > */
kGE, /*!< operator >= */
};
/*! \brief conversion table from string to operator, defined in optable.cc */

/*! \brief conversion table from string to Operator, defined in tables.cc */
extern const std::unordered_map<std::string, Operator> optable;

/*!
* \brief get string representation of comparsion operator
* \brief get string representation of comparison operator
* \param op comparison operator
* \return string representation
*/
Expand All @@ -56,7 +59,8 @@ inline std::string OpName(Operator op) {
* \param rhs float on the right hand side
* \return whether [lhs] [op] [rhs] is true or not
*/
inline bool CompareWithOp(tl_float lhs, Operator op, tl_float rhs) {
template <typename ElementType, typename ThresholdType>
inline bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs) {
switch (op) {
case Operator::kEQ: return lhs == rhs;
case Operator::kLT: return lhs < rhs;
Expand Down
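The hunk above is cut off, so the following is only a self-contained sketch of what a two-type-parameter comparison can look like. The enumerators `kNone` and `kLE` and the `default` branch are assumptions; the diff shows only some of the enumerators and cases. The sketch illustrates why the element type and threshold type are kept separate: the same literal compares differently once one side is `float` and the other is `double`.

```cpp
#include <cstdint>
#include <stdexcept>

// Stand-in for the Operator enum in base.h; kNone and kLE are assumed here.
enum class Operator : std::int8_t { kNone, kEQ, kLT, kLE, kGT, kGE };

// Compare a feature value against a threshold whose type may differ.
// The usual arithmetic conversions apply, so float vs. double compares as double.
template <typename ElementType, typename ThresholdType>
inline bool CompareWithOp(ElementType lhs, Operator op, ThresholdType rhs) {
  switch (op) {
    case Operator::kEQ: return lhs == rhs;
    case Operator::kLT: return lhs < rhs;
    case Operator::kLE: return lhs <= rhs;
    case Operator::kGT: return lhs > rhs;
    case Operator::kGE: return lhs >= rhs;
    default: throw std::runtime_error("operator undefined");  // assumed behavior
  }
}
```

With this sketch, `CompareWithOp<float, float>(0.1f, Operator::kLE, 0.1f)` holds, while `CompareWithOp<float, double>(0.1f, Operator::kLE, 0.1)` does not, because `0.1f` widened to `double` is slightly larger than the `double` literal `0.1`.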
200 changes: 54 additions & 146 deletions include/treelite/c_api.h
Expand Up @@ -19,8 +19,6 @@
* opaque handles
* \{
*/
/*! \brief handle to a data matrix */
typedef void* DMatrixHandle;
/*! \brief handle to a decision tree ensemble model */
typedef void* ModelHandle;
/*! \brief handle to tree builder class */
Expand All @@ -31,100 +29,8 @@ typedef void* ModelBuilderHandle;
typedef void* AnnotationHandle;
/*! \brief handle to compiler class */
typedef void* CompilerHandle;
/*! \} */

/*!
* \defgroup dmatrix
* Data matrix interface
* \{
*/
/*!
* \brief create DMatrix from a file
* \param path file path
* \param format file format
* \param nthread number of threads to use
* \param verbose whether to produce extra messages
* \param out the created DMatrix
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixCreateFromFile(const char* path,
const char* format,
int nthread,
int verbose,
DMatrixHandle* out);
/*!
* \brief create DMatrix from a (in-memory) CSR matrix
* \param data feature values
* \param col_ind feature indices
* \param row_ptr pointer to row headers
* \param num_row number of rows
* \param num_col number of columns
* \param out the created DMatrix
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixCreateFromCSR(const float* data,
const unsigned* col_ind,
const size_t* row_ptr,
size_t num_row,
size_t num_col,
DMatrixHandle* out);
/*!
* \brief create DMatrix from a (in-memory) dense matrix
* \param data feature values
* \param num_row number of rows
* \param num_col number of columns
* \param missing_value value to represent missing value
* \param out the created DMatrix
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixCreateFromMat(const float* data,
size_t num_row,
size_t num_col,
float missing_value,
DMatrixHandle* out);
/*!
* \brief get dimensions of a DMatrix
* \param handle handle to DMatrix
* \param out_num_row used to set number of rows
* \param out_num_col used to set number of columns
* \param out_nelem used to set number of nonzero entries
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixGetDimension(DMatrixHandle handle,
size_t* out_num_row,
size_t* out_num_col,
size_t* out_nelem);

/*!
* \brief produce a human-readable preview of a DMatrix
* Will print first and last 25 non-zero entries, along with their locations
* \param handle handle to DMatrix
* \param out_preview used to save the address of the string literal
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixGetPreview(DMatrixHandle handle,
const char** out_preview);

/*!
* \brief extract three arrays (data, col_ind, row_ptr) that define a DMatrix.
* \param handle handle to DMatrix
* \param out_data used to save pointer to array containing feature values
* \param out_col_ind used to save pointer to array containing feature indices
* \param out_row_ptr used to save pointer to array containing pointers to
* row headers
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixGetArrays(DMatrixHandle handle,
const float** out_data,
const uint32_t** out_col_ind,
const size_t** out_row_ptr);

/*!
* \brief delete DMatrix from memory
* \param handle handle to DMatrix
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteDMatrixFree(DMatrixHandle handle);
/*! \brief handle to a polymorphic value type, used in the model builder API */
typedef void* ValueHandle;
/*! \} */

/*!
Expand All @@ -142,11 +48,8 @@ TREELITE_DLL int TreeliteDMatrixFree(DMatrixHandle handle);
* \param out used to save handle for the created annotation
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteAnnotateBranch(ModelHandle model,
DMatrixHandle dmat,
int nthread,
int verbose,
AnnotationHandle* out);
TREELITE_DLL int TreeliteAnnotateBranch(
ModelHandle model, DMatrixHandle dmat, int nthread, int verbose, AnnotationHandle* out);
/*!
* \brief save branch annotation to a JSON file
* \param handle annotation to save
Expand Down Expand Up @@ -229,17 +132,15 @@ TREELITE_DLL int TreeliteCompilerFree(CompilerHandle handle);
* \param out loaded model
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteLoadLightGBMModel(const char* filename,
ModelHandle* out);
TREELITE_DLL int TreeliteLoadLightGBMModel(const char* filename, ModelHandle* out);
/*!
* \brief load a model file generated by XGBoost (dmlc/xgboost). The model file
* must contain a decision tree ensemble.
* \param filename name of model file
* \param out loaded model
* \return 0 for success, -1 for failure
*/
TREELITE_DLL int TreeliteLoadXGBoostModel(const char* filename,
ModelHandle* out);
TREELITE_DLL int TreeliteLoadXGBoostModel(const char* filename, ModelHandle* out);
/*!
* \brief load an XGBoost model from a memory buffer.
* \param buf memory buffer
Expand Down Expand Up @@ -293,12 +194,33 @@ TREELITE_DLL int TreeliteFreeModel(ModelHandle handle);
* Model builder interface: build trees incrementally
* \{
*/
/*!
* \brief Create a new Value object. Some model builder API functions accept this Value type to
* accommodate values of multiple types.
* \param init_value pointer to the value to be stored
* \param type Type of the value to be stored
* \param out newly created Value object
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderCreateValue(const void* init_value, const char* type,
ValueHandle* out);
/*!
* \brief Delete a Value object from memory
* \param handle pointer to the Value object to be deleted
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderDeleteValue(ValueHandle handle);
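The diff declares the Value create/delete pair but not what sits behind a `ValueHandle`. As a rough illustration only, the sketch below shows one way a type-tagged scalar could be stored behind an opaque handle. `SketchType`, `SketchValue`, `SketchCreateValue`, `SketchDeleteValue`, and the type strings `"uint32"`, `"float32"`, and `"float64"` are all hypothetical names chosen for this example, not Treelite's implementation.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <exception>
#include <stdexcept>
#include <string>

// Hypothetical type tag; the accepted strings are an assumption of this sketch.
enum class SketchType : std::uint8_t { kUInt32, kFloat32, kFloat64 };

inline SketchType ParseType(const std::string& s) {
  if (s == "uint32") return SketchType::kUInt32;
  if (s == "float32") return SketchType::kFloat32;
  if (s == "float64") return SketchType::kFloat64;
  throw std::invalid_argument("unrecognized type: " + s);
}

inline std::size_t SizeOf(SketchType t) {
  switch (t) {
    case SketchType::kUInt32: return sizeof(std::uint32_t);
    case SketchType::kFloat32: return sizeof(float);
    case SketchType::kFloat64: return sizeof(double);
  }
  throw std::invalid_argument("unknown type tag");
}

// Map a C++ type to its tag at compile time.
template <typename T> constexpr SketchType TagOf();
template <> constexpr SketchType TagOf<std::uint32_t>() { return SketchType::kUInt32; }
template <> constexpr SketchType TagOf<float>() { return SketchType::kFloat32; }
template <> constexpr SketchType TagOf<double>() { return SketchType::kFloat64; }

// Hypothetical object behind a ValueHandle: a copied scalar plus its type tag.
class SketchValue {
 public:
  SketchValue(const void* init_value, SketchType type) : type_(type) {
    std::memcpy(storage_, init_value, SizeOf(type));  // caller keeps ownership of the input
  }
  template <typename T>
  T Get() const {
    if (TagOf<T>() != type_) throw std::runtime_error("type mismatch");
    T out;
    std::memcpy(&out, storage_, sizeof(T));
    return out;
  }

 private:
  unsigned char storage_[8] = {};
  SketchType type_;
};

typedef void* ValueHandle;

// C-style wrappers following the "0 for success; -1 for failure" convention.
inline int SketchCreateValue(const void* init_value, const char* type, ValueHandle* out) {
  if (init_value == nullptr || type == nullptr || out == nullptr) return -1;
  try {
    *out = new SketchValue(init_value, ParseType(type));
    return 0;
  } catch (const std::exception&) {
    return -1;
  }
}

inline int SketchDeleteValue(ValueHandle handle) {
  delete static_cast<SketchValue*>(handle);
  return 0;
}
```

Copying the scalar at creation time means the caller's buffer need not outlive the handle, which is one plausible reason for a separate delete function in the API.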
/*!
* \brief Create a new tree builder
* \param threshold_type Type of thresholds in numerical splits. All thresholds in a given model
* must have the same type.
* \param leaf_output_type Type of leaf outputs. All leaf outputs in a given model must have the
* same type.
* \param out newly created tree builder
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteCreateTreeBuilder(TreeBuilderHandle* out);
TREELITE_DLL int TreeliteCreateTreeBuilder(const char* threshold_type, const char* leaf_output_type,
TreeBuilderHandle* out);
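Passing the threshold and leaf output types as strings implies a runtime dispatch from a pair of type names to a concrete template instantiation. The sketch below shows one way such a dispatch could look. `SketchBuilderBase`, `SketchBuilder`, `MakeBuilder`, the type strings, and the set of accepted combinations are assumptions made for illustration; the diff does not list which combinations are valid.

```cpp
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>

// Type-erased base so callers can hold a builder without knowing its template arguments.
struct SketchBuilderBase {
  virtual ~SketchBuilderBase() = default;
  virtual std::size_t ThresholdSize() const = 0;
  virtual std::size_t LeafOutputSize() const = 0;
};

// Concrete builder, fixed to one threshold type and one leaf output type.
template <typename ThresholdType, typename LeafOutputType>
struct SketchBuilder : SketchBuilderBase {
  std::size_t ThresholdSize() const override { return sizeof(ThresholdType); }
  std::size_t LeafOutputSize() const override { return sizeof(LeafOutputType); }
};

// Map a pair of type names to a template instantiation.
// The accepted combinations below are an assumption of this sketch.
inline std::unique_ptr<SketchBuilderBase> MakeBuilder(const std::string& threshold_type,
                                                      const std::string& leaf_output_type) {
  if (threshold_type == "float32" && leaf_output_type == "float32") {
    return std::make_unique<SketchBuilder<float, float>>();
  }
  if (threshold_type == "float32" && leaf_output_type == "uint32") {
    return std::make_unique<SketchBuilder<float, std::uint32_t>>();
  }
  if (threshold_type == "float64" && leaf_output_type == "float64") {
    return std::make_unique<SketchBuilder<double, double>>();
  }
  if (threshold_type == "float64" && leaf_output_type == "uint32") {
    return std::make_unique<SketchBuilder<double, std::uint32_t>>();
  }
  throw std::invalid_argument("unsupported combination: threshold_type=" + threshold_type +
                              ", leaf_output_type=" + leaf_output_type);
}
```

Centralizing the string comparison in one factory keeps the rest of the code templated, so an unsupported pair fails once at construction rather than at every node.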
/*!
* \brief Delete a tree builder from memory
* \param handle tree builder to remove
Expand All @@ -311,24 +233,21 @@ TREELITE_DLL int TreeliteDeleteTreeBuilder(TreeBuilderHandle handle);
* \param node_key unique integer key to identify the new node
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle,
int node_key);
TREELITE_DLL int TreeliteTreeBuilderCreateNode(TreeBuilderHandle handle, int node_key);
/*!
* \brief Remove a node from a tree
* \param handle tree builder
* \param node_key unique integer key to identify the node to be removed
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle,
int node_key);
TREELITE_DLL int TreeliteTreeBuilderDeleteNode(TreeBuilderHandle handle, int node_key);
/*!
* \brief Set a node as the root of a tree
* \param handle tree builder
* \param node_key unique integer key to identify the root node
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle,
int node_key);
TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle, int node_key);
/*!
* \brief Turn an empty node into a test node with numerical split.
* The test is in the form [feature value] OP [threshold]. Depending on the
Expand All @@ -345,12 +264,8 @@ TREELITE_DLL int TreeliteTreeBuilderSetRootNode(TreeBuilderHandle handle,
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode(
TreeBuilderHandle handle,
int node_key, unsigned feature_id,
const char* opname,
float threshold, int default_left,
int left_child_key,
int right_child_key);
TreeBuilderHandle handle, int node_key, unsigned feature_id, const char* opname,
ValueHandle threshold, int default_left, int left_child_key, int right_child_key);
/*!
* \brief Turn an empty node into a test node with categorical split.
* A list defines all categories that would be classified as the left side.
Expand All @@ -368,13 +283,9 @@ TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode(
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode(
TreeBuilderHandle handle,
int node_key, unsigned feature_id,
const unsigned int* left_categories,
size_t left_categories_len,
int default_left,
int left_child_key,
int right_child_key);
TreeBuilderHandle handle, int node_key, unsigned feature_id,
const unsigned int* left_categories, size_t left_categories_len, int default_left,
int left_child_key, int right_child_key);
/*!
* \brief Turn an empty node into a leaf node
* \param handle tree builder
Expand All @@ -383,9 +294,8 @@ TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode(
* \param leaf_value leaf value (weight) of the leaf node
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle,
int node_key,
float leaf_value);
TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(
TreeBuilderHandle handle, int node_key, ValueHandle leaf_value);
/*!
* \brief Turn an empty node into a leaf vector node
* The leaf vector (collection of multiple leaf weights per leaf node) is
Expand All @@ -397,29 +307,27 @@ TREELITE_DLL int TreeliteTreeBuilderSetLeafNode(TreeBuilderHandle handle,
* \param leaf_vector_len length of leaf_vector
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteTreeBuilderSetLeafVectorNode(TreeBuilderHandle handle,
int node_key,
const float* leaf_vector,
size_t leaf_vector_len);
TREELITE_DLL int TreeliteTreeBuilderSetLeafVectorNode(
TreeBuilderHandle handle, int node_key, const ValueHandle* leaf_vector, size_t leaf_vector_len);
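Since the leaf vector now arrives as an array of `ValueHandle` and all leaf outputs in a model must share one type, an implementation presumably has to check each element before unpacking it. The sketch below illustrates that check with a hypothetical `TaggedScalar` struct and `UnpackLeafVector` helper; neither name comes from the diff, and the real Value representation may differ.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical tagged scalar standing in for whatever a ValueHandle points to.
struct TaggedScalar {
  enum class Kind { kFloat32, kFloat64 };
  Kind kind;
  double payload;  // wide enough to hold either kind exactly
};

typedef void* ValueHandle;

// Unpack an array of handles, rejecting null entries and any element
// whose kind differs from the model's declared leaf output type.
inline std::vector<double> UnpackLeafVector(const ValueHandle* leaf_vector,
                                            std::size_t leaf_vector_len,
                                            TaggedScalar::Kind expected) {
  std::vector<double> out;
  out.reserve(leaf_vector_len);
  for (std::size_t i = 0; i < leaf_vector_len; ++i) {
    const auto* v = static_cast<const TaggedScalar*>(leaf_vector[i]);
    if (v == nullptr || v->kind != expected) {
      throw std::invalid_argument("leaf vector element has unexpected type");
    }
    out.push_back(v->payload);
  }
  return out;
}
```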
/*!
* \brief Create a new model builder
* \param num_feature number of features used in model being built. We assume
* that all feature indices are between 0 and
* (num_feature - 1).
* \param num_output_group number of output groups. Set to 1 for binary
* classification and regression; >1 for multiclass
* classification
* \param random_forest_flag whether the model is a random forest. Set to 0 if
* the model is gradient boosted trees. Any nonzero
* value shall indicate that the model is a
* random forest.
* \param num_feature number of features used in model being built. We assume that all feature
* indices are between 0 and (num_feature - 1).
* \param num_output_group number of output groups. Set to 1 for binary classification and
* regression; >1 for multiclass classification
* \param random_forest_flag whether the model is a random forest. Set to 0 if the model is
* gradient boosted trees. Any nonzero value shall indicate that the
* model is a random forest.
* \param threshold_type Type of thresholds in numerical splits. All thresholds in a given model
* must have the same type.
* \param leaf_output_type Type of leaf outputs. All leaf outputs in a given model must have the
* same type.
* \param out newly created model builder
* \return 0 for success; -1 for failure
*/
TREELITE_DLL int TreeliteCreateModelBuilder(int num_feature,
int num_output_group,
int random_forest_flag,
ModelBuilderHandle* out);
TREELITE_DLL int TreeliteCreateModelBuilder(
int num_feature, int num_output_group, int random_forest_flag, const char* threshold_type,
const char* leaf_output_type, ModelBuilderHandle* out);
/*!
* \brief Set a model parameter
* \param handle model builder
Expand Down