Skip to content

Commit

Permalink
Add linear leaf models to json output (fixes #4186) (#4329)
Browse files Browse the repository at this point in the history
* Add linear leaf models to json output

* Add closing bracket

* Move test into test_engine.py and add asserts

* Update tests/python_package_test/test_engine.py

Co-authored-by: Nikita Titov <[email protected]>

* Update tests/python_package_test/test_engine.py

Co-authored-by: Nikita Titov <[email protected]>

* Update tests/python_package_test/test_engine.py

Co-authored-by: Nikita Titov <[email protected]>

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
btrotta and StrikerRUS committed Jun 3, 2021
1 parent 3dd4a3f commit 1b5bec0
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
3 changes: 3 additions & 0 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ class Tree {
/*! \brief Serialize this object to json*/
std::string ToJSON() const;

/*! \brief Serialize the linear model of the leaf with the given index to a JSON fragment (key/value pairs without enclosing braces)*/
std::string LinearModelToJSON(int index) const;

/*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool predict_leaf_index) const;

Expand Down
38 changes: 35 additions & 3 deletions src/io/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,39 @@ std::string Tree::ToJSON() const {
str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
if (num_leaves_ == 1) {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
if (is_linear_) {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n";
str_buf << LinearModelToJSON(0) << "}" << "\n";
} else {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
}
} else {
str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
}
return str_buf.str();
}

std::string Tree::LinearModelToJSON(int index) const {
std::stringstream str_buf;
Common::C_stringstream(str_buf);
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
str_buf << "\"leaf_const\":" << leaf_const_[index] << "," << "\n";
int num_features = static_cast<int>(leaf_features_[index].size());
if (num_features > 0) {
str_buf << "\"leaf_features\":[";
for (int i = 0; i < num_features - 1; ++i) {
str_buf << leaf_features_[index][i] << ", ";
}
str_buf << leaf_features_[index][num_features - 1] << "]" << ", " << "\n";
str_buf << "\"leaf_coeff\":[";
for (int i = 0; i < num_features - 1; ++i) {
str_buf << leaf_coeff_[index][i] << ", ";
}
str_buf << leaf_coeff_[index][num_features - 1] << "]" << "\n";
} else {
str_buf << "\"leaf_features\":[],\n";
str_buf << "\"leaf_coeff\":[]\n";
}
return str_buf.str();
}

Expand Down Expand Up @@ -479,10 +507,14 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "\"leaf_index\":" << index << "," << '\n';
str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
if (is_linear_) {
str_buf << "\"leaf_count\":" << leaf_count_[index] << "," << '\n';
str_buf << LinearModelToJSON(index);
} else {
str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
}
str_buf << "}";
}

return str_buf.str();
}

Expand Down
25 changes: 25 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2793,3 +2793,28 @@ def test_reset_params_works_with_metric_num_class_and_boosting():
expected_params = dict(dataset_params, **booster_params)
assert bst.params == expected_params
assert new_bst.params == expected_params


def test_dump_model():
    """Check which leaf fields appear in the dumped model with and without linear trees.

    Trains a binary classifier on the breast cancer data twice, first with the
    given params and then with ``linear_tree`` switched on, and inspects the
    string form of ``dump_model`` each time. The linear-model keys must be
    absent from the first dump and present in the second; ``leaf_value`` and
    ``leaf_count`` must be present in both.
    """
    linear_keys = ("leaf_features", "leaf_coeff", "leaf_const")
    common_keys = ("leaf_value", "leaf_count")

    X, y = load_breast_cancer(return_X_y=True)
    params = {
        "objective": "binary",
        "verbose": -1
    }

    # First model: no linear_tree setting, so no linear-model keys are expected.
    booster = lgb.train(params, lgb.Dataset(X, label=y), num_boost_round=5)
    dump_text = str(booster.dump_model(5, 0))
    for key in linear_keys:
        assert key not in dump_text
    for key in common_keys:
        assert key in dump_text

    # Second model: linear_tree enabled, trained on a fresh Dataset.
    params['linear_tree'] = True
    booster = lgb.train(params, lgb.Dataset(X, label=y), num_boost_round=5)
    dump_text = str(booster.dump_model(5, 0))
    for key in linear_keys + common_keys:
        assert key in dump_text

0 comments on commit 1b5bec0

Please sign in to comment.