Skip to content

Commit

Permalink
Add all objectives currently supported by LightGBM (#215)
Browse files Browse the repository at this point in the history
* Do not use NumPy types as a dict key; use explicit comparison

* Add all objectives currently supported by LightGBM

* Add unit test for LightGBM regressor (with sqrt option)

* Update runtime/python/treelite_runtime/util.py

Co-authored-by: William Hicks <[email protected]>

Co-authored-by: William Hicks <[email protected]>
  • Loading branch information
hcho3 and wphicks authored Oct 26, 2020
1 parent 9b42c39 commit c2dc4dd
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 12 deletions.
15 changes: 8 additions & 7 deletions runtime/python/treelite_runtime/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@
'float64': np.float64
}

_NUMPY_TYPE_TABLE_INV = {
np.uint32: 'unit32',
np.float32: 'float32',
np.float64: 'float64'
}


def type_info_to_ctypes_type(type_info):
"""Obtain ctypes type corresponding to a given TypeInfo"""
Expand All @@ -39,7 +33,14 @@ def type_info_to_numpy_type(type_info):

def numpy_type_to_type_info(type_info):
"""Obtain TypeInfo corresponding to a given NumPy type"""
return _NUMPY_TYPE_TABLE_INV[type_info]
if type_info == np.uint32:
return 'uint32'
elif type_info == np.float32:
return 'float32'
elif type_info == np.float64:
return 'float64'
else:
raise ValueError('Unrecognized NumPy type: {type_info}')


class TreeliteRuntimeError(Exception):
Expand Down
1 change: 1 addition & 0 deletions src/compiler/native/header_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ R"TREELITETEMPLATE(

const char* const header_template =
R"TREELITETEMPLATE(
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
Expand Down
10 changes: 10 additions & 0 deletions src/compiler/native/pred_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type
"threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
}

inline std::string signed_square(const Model& model) {
const TypeInfo threshold_type = model.GetThresholdType();
return fmt::format(
R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
return {copysign}(margin * margin, margin);
}})TREELITETEMPLATE",
"threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
"copysign"_a = native::CCopySignForTypeInfo(threshold_type));
}

inline std::string sigmoid(const Model& model) {
const float alpha = model.param.sigmoid_alpha;
const TypeInfo threshold_type = model.GetThresholdType();
Expand Down
22 changes: 22 additions & 0 deletions src/compiler/native/typeinfo_ctypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,28 @@ inline std::string CExpForTypeInfo(TypeInfo type) {
}
}

/*!
* \brief Look up the correct variant of copysign() in C that should be used with a given type
* \param info a type info
* \return string representation
*/
inline std::string CCopySignForTypeInfo(TypeInfo type) {
switch (type) {
case TypeInfo::kInvalid:
case TypeInfo::kUInt32:
throw std::runtime_error(std::string("Invalid type: ") + TypeInfoToString(type));
return "";
case TypeInfo::kFloat32:
return "copysignf";
case TypeInfo::kFloat64:
return "copysign";
default:
throw std::runtime_error(std::string("Unrecognized type: ")
+ std::to_string(static_cast<int>(type)));
return "";
}
}

/*!
* \brief Look up the correct variant of log1p() in C that should be used with a given type
* \param info a type info
Expand Down
13 changes: 9 additions & 4 deletions src/compiler/pred_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ FUNC_NAME(const std::string& backend, const Model& model) { \
}

TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(signed_square)
TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
Expand All @@ -43,23 +44,27 @@ TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
const std::unordered_map<std::string, PredTransformFuncGenerator>
pred_transform_db = {
PRED_TRANSFORM_FUNC(identity),
PRED_TRANSFORM_FUNC(signed_square),
PRED_TRANSFORM_FUNC(sigmoid),
PRED_TRANSFORM_FUNC(exponential),
PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
};
/*! [pred_transform_db]
- identity
do not transform. The output will be a vector of length
Do not transform. The output will be a vector of length
[number of data points] that contains the margin score for every data point.
- signed_square
Apply the function f(x) = sign(x) * (x**2) element-wise to the margin vector. The
output will be a vector of length [number of data points].
- sigmoid
apply the sigmoid function element-wise to the margin vector. The output
Apply the sigmoid function element-wise to the margin vector. The output
will be a vector of length [number of data points] that contains the
probability of each data point belonging to the positive class.
- exponential
apply the exponential function (exp) element-wise to the margin vector. The
Apply the exponential function (exp) element-wise to the margin vector. The
output will be a vector of length [number of data points].
- logarithm_one_plus_exp
apply the function f(x) = log(1 + exp(x)) element-wise to the margin vector.
Apply the function f(x) = log(1 + exp(x)) element-wise to the margin vector.
The output will be a vector of length [number of data points].
[pred_transform_db] */

Expand Down
19 changes: 18 additions & 1 deletion src/frontend/lightgbm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,26 @@ inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
} else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
std::strncpy(model->param.pred_transform, "logarithm_one_plus_exp",
sizeof(model->param.pred_transform));
} else {
} else if (obj_name_ == "poisson" || obj_name_ == "gamma" || obj_name_ == "tweedie") {
std::strncpy(model->param.pred_transform, "exponential",
sizeof(model->param.pred_transform));
} else if (obj_name_ == "regression" || obj_name_ == "regression_l1" || obj_name_ == "huber"
|| obj_name_ == "fair" || obj_name_ == "quantile" || obj_name_ == "mape") {
// Regression family
bool sqrt = (std::find(obj_param_.cbegin(), obj_param_.cend(), "sqrt") != obj_param_.cend());
if (sqrt) {
std::strncpy(model->param.pred_transform, "signed_square",
sizeof(model->param.pred_transform));
} else {
std::strncpy(model->param.pred_transform, "identity",
sizeof(model->param.pred_transform));
}
} else if (obj_name_ == "lambdarank" || obj_name_ == "rank_xendcg" || obj_name_ == "custom") {
// Ranking family, or a custom user-defined objective
std::strncpy(model->param.pred_transform, "identity",
sizeof(model->param.pred_transform));
} else {
LOG(FATAL) << "Unrecognized objective: " << obj_name_;
}

// traverse trees
Expand Down
33 changes: 33 additions & 0 deletions tests/python/test_lightgbm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,39 @@
pytest.skip('LightGBM not installed; skipping', allow_module_level=True)


@pytest.mark.skipif(not has_sklearn(), reason='Needs scikit-learn')
@pytest.mark.parametrize('toolchain', os_compatible_toolchains())
@pytest.mark.parametrize('objective', ['regression', 'regression_l1', 'huber'])
@pytest.mark.parametrize('reg_sqrt', [True, False])
def test_lightgbm_regression(tmpdir, objective, reg_sqrt, toolchain):
# pylint: disable=too-many-locals
"""Test a regressor"""
model_path = os.path.join(tmpdir, 'boston_lightgbm.txt')

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split

X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
dtrain = lightgbm.Dataset(X_train, y_train, free_raw_data=False)
dtest = lightgbm.Dataset(X_test, y_test, reference=dtrain, free_raw_data=False)
param = {'task': 'train', 'boosting_type': 'gbdt', 'objective': objective, 'reg_sqrt': reg_sqrt,
'metric': 'rmse', 'num_leaves': 31, 'learning_rate': 0.05}
bst = lightgbm.train(param, dtrain, num_boost_round=10, valid_sets=[dtrain, dtest],
valid_names=['train', 'test'])
bst.save_model(model_path)

model = treelite.Model.load(model_path, model_format='lightgbm')
libpath = os.path.join(tmpdir, f'boston_{objective}' + _libext())
model.export_lib(toolchain=toolchain, libpath=libpath, params={'quantize': 1}, verbose=True)
predictor = treelite_runtime.Predictor(libpath=libpath, verbose=True)

dmat = treelite_runtime.DMatrix(X_test, dtype='float64')
out_pred = predictor.predict(dmat)
expected_pred = bst.predict(X_test)
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)


@pytest.mark.skipif(not has_sklearn(), reason='Needs scikit-learn')
@pytest.mark.parametrize('toolchain', os_compatible_toolchains())
@pytest.mark.parametrize('objective', ['multiclass', 'multiclassova'])
Expand Down

0 comments on commit c2dc4dd

Please sign in to comment.