Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add all objectives currently supported by LightGBM #215

Merged
merged 4 commits into from
Oct 26, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions runtime/python/treelite_runtime/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@
'float64': np.float64
}

_NUMPY_TYPE_TABLE_INV = {
np.uint32: 'unit32',
np.float32: 'float32',
np.float64: 'float64'
}


def type_info_to_ctypes_type(type_info):
"""Obtain ctypes type corresponding to a given TypeInfo"""
Expand All @@ -39,7 +33,14 @@ def type_info_to_numpy_type(type_info):

def numpy_type_to_type_info(type_info):
"""Obtain TypeInfo corresponding to a given NumPy type"""
return _NUMPY_TYPE_TABLE_INV[type_info]
if type_info == np.uint32:
return 'unit32'
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
elif type_info == np.float32:
return 'float32'
elif type_info == np.float64:
return 'float64'
else:
raise ValueError('Unrecognized NumPy type: {type_info}')


class TreeliteRuntimeError(Exception):
Expand Down
1 change: 1 addition & 0 deletions src/compiler/native/header_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ R"TREELITETEMPLATE(

const char* const header_template =
R"TREELITETEMPLATE(
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
Expand Down
10 changes: 10 additions & 0 deletions src/compiler/native/pred_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type
"threshold_type"_a = native::TypeInfoToCTypeString(model.GetThresholdType()));
}

inline std::string signed_square(const Model& model) {
const TypeInfo threshold_type = model.GetThresholdType();
return fmt::format(
R"TREELITETEMPLATE(static inline {threshold_type} pred_transform({threshold_type} margin) {{
return {copysign}(margin * margin, margin);
}})TREELITETEMPLATE",
"threshold_type"_a = native::TypeInfoToCTypeString(threshold_type),
"copysign"_a = native::CCopySignForTypeInfo(threshold_type));
}

inline std::string sigmoid(const Model& model) {
const float alpha = model.param.sigmoid_alpha;
const TypeInfo threshold_type = model.GetThresholdType();
Expand Down
22 changes: 22 additions & 0 deletions src/compiler/native/typeinfo_ctypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,28 @@ inline std::string CExpForTypeInfo(TypeInfo type) {
}
}

/*!
* \brief Look up the correct variant of copysign() in C that should be used with a given type
* \param info a type info
* \return string representation
*/
inline std::string CCopySignForTypeInfo(TypeInfo type) {
switch (type) {
case TypeInfo::kInvalid:
case TypeInfo::kUInt32:
throw std::runtime_error(std::string("Invalid type: ") + TypeInfoToString(type));
return "";
case TypeInfo::kFloat32:
return "copysignf";
case TypeInfo::kFloat64:
return "copysign";
default:
throw std::runtime_error(std::string("Unrecognized type: ")
+ std::to_string(static_cast<int>(type)));
return "";
}
}

/*!
* \brief Look up the correct variant of log1p() in C that should be used with a given type
* \param info a type info
Expand Down
13 changes: 9 additions & 4 deletions src/compiler/pred_transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ FUNC_NAME(const std::string& backend, const Model& model) { \
}

TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(identity)
TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(signed_square)
TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(sigmoid)
TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(exponential)
TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(logarithm_one_plus_exp)
Expand All @@ -43,23 +44,27 @@ TREELITE_PRED_TRANSFORM_REGISTRY_DEFAULT_TEMPLATE(multiclass_ova)
const std::unordered_map<std::string, PredTransformFuncGenerator>
pred_transform_db = {
PRED_TRANSFORM_FUNC(identity),
PRED_TRANSFORM_FUNC(signed_square),
PRED_TRANSFORM_FUNC(sigmoid),
PRED_TRANSFORM_FUNC(exponential),
PRED_TRANSFORM_FUNC(logarithm_one_plus_exp)
};
/*! [pred_transform_db]
- identity
do not transform. The output will be a vector of length
Do not transform. The output will be a vector of length
[number of data points] that contains the margin score for every data point.
- signed_square
Apply the function f(x) = sign(x) * (x**2) element-wise to the margin vector. The
output will be a vector of length [number of data points].
- sigmoid
apply the sigmoid function element-wise to the margin vector. The output
Apply the sigmoid function element-wise to the margin vector. The output
will be a vector of length [number of data points] that contains the
probability of each data point belonging to the positive class.
- exponential
apply the exponential function (exp) element-wise to the margin vector. The
Apply the exponential function (exp) element-wise to the margin vector. The
output will be a vector of length [number of data points].
- logarithm_one_plus_exp
apply the function f(x) = log(1 + exp(x)) element-wise to the margin vector.
Apply the function f(x) = log(1 + exp(x)) element-wise to the margin vector.
The output will be a vector of length [number of data points].
[pred_transform_db] */

Expand Down
19 changes: 18 additions & 1 deletion src/frontend/lightgbm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,9 +514,26 @@ inline std::unique_ptr<treelite::Model> ParseStream(dmlc::Stream* fi) {
} else if (obj_name_ == "xentlambda" || obj_name_ == "cross_entropy_lambda") {
std::strncpy(model->param.pred_transform, "logarithm_one_plus_exp",
sizeof(model->param.pred_transform));
} else {
} else if (obj_name_ == "poisson" || obj_name_ == "gamma" || obj_name_ == "tweedie") {
std::strncpy(model->param.pred_transform, "exponential",
sizeof(model->param.pred_transform));
} else if (obj_name_ == "regression" || obj_name_ == "regression_l1" || obj_name_ == "huber"
|| obj_name_ == "fair" || obj_name_ == "quantile" || obj_name_ == "mape") {
// Regression family
bool sqrt = (std::find(obj_param_.cbegin(), obj_param_.cend(), "sqrt") != obj_param_.cend());
if (sqrt) {
std::strncpy(model->param.pred_transform, "signed_square",
sizeof(model->param.pred_transform));
} else {
std::strncpy(model->param.pred_transform, "identity",
sizeof(model->param.pred_transform));
}
} else if (obj_name_ == "lambdarank" || obj_name_ == "rank_xendcg" || obj_name_ == "custom") {
// Ranking family, or a custom user-defined objective
std::strncpy(model->param.pred_transform, "identity",
sizeof(model->param.pred_transform));
} else {
LOG(FATAL) << "Unrecognized objective: " << obj_name_;
}

// traverse trees
Expand Down
33 changes: 33 additions & 0 deletions tests/python/test_lightgbm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,39 @@
pytest.skip('LightGBM not installed; skipping', allow_module_level=True)


@pytest.mark.skipif(not has_sklearn(), reason='Needs scikit-learn')
@pytest.mark.parametrize('toolchain', os_compatible_toolchains())
@pytest.mark.parametrize('objective', ['regression', 'regression_l1', 'huber'])
@pytest.mark.parametrize('reg_sqrt', [True, False])
def test_lightgbm_regression(tmpdir, objective, reg_sqrt, toolchain):
# pylint: disable=too-many-locals
"""Test a regressor"""
model_path = os.path.join(tmpdir, 'boston_lightgbm.txt')

from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split

X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
dtrain = lightgbm.Dataset(X_train, y_train, free_raw_data=False)
dtest = lightgbm.Dataset(X_test, y_test, reference=dtrain, free_raw_data=False)
param = {'task': 'train', 'boosting_type': 'gbdt', 'objective': objective, 'reg_sqrt': reg_sqrt,
'metric': 'rmse', 'num_leaves': 31, 'learning_rate': 0.05}
bst = lightgbm.train(param, dtrain, num_boost_round=10, valid_sets=[dtrain, dtest],
valid_names=['train', 'test'])
bst.save_model(model_path)

model = treelite.Model.load(model_path, model_format='lightgbm')
libpath = os.path.join(tmpdir, f'boston_{objective}' + _libext())
model.export_lib(toolchain=toolchain, libpath=libpath, params={'quantize': 1}, verbose=True)
predictor = treelite_runtime.Predictor(libpath=libpath, verbose=True)

dmat = treelite_runtime.DMatrix(X_test, dtype='float64')
out_pred = predictor.predict(dmat)
expected_pred = bst.predict(X_test)
np.testing.assert_almost_equal(out_pred, expected_pred, decimal=5)


@pytest.mark.skipif(not has_sklearn(), reason='Needs scikit-learn')
@pytest.mark.parametrize('toolchain', os_compatible_toolchains())
@pytest.mark.parametrize('objective', ['multiclass', 'multiclassova'])
Expand Down