-
Notifications
You must be signed in to change notification settings - Fork 3.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add linear leaf models to json output (fixes #4186) #4329
Changes from 2 commits
31bcb45
1480c4f
e2a8d81
2302a47
fd43e2c
550608e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -411,3 +411,18 @@ def test_list_to_1d_numpy(y, dtype): | |||
result = lgb.basic.list_to_1d_numpy(y, dtype=dtype) | ||||
assert result.size == 10 | ||||
assert result.dtype == dtype | ||||
|
||||
|
||||
def test_dump_model(): | ||||
X, y = load_breast_cancer(return_X_y=True) | ||||
train_data = lgb.Dataset(X, label=y) | ||||
params = { | ||||
"objective": "binary", | ||||
"verbose": -1 | ||||
} | ||||
bst = lgb.train(params, train_data, num_boost_round=5) | ||||
bst.dump_model(5, 0) | ||||
params['linear_tree'] = True | ||||
train_data = lgb.Dataset(X, label=y) | ||||
bst = lgb.train(params, train_data, num_boost_round=5) | ||||
bst.dump_model(5, 0) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is done already in LightGBM/python-package/lightgbm/basic.py Line 3090 in 36957ed
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot for this enhancement!
Please move this test into
test_engine.py
file becauselgb.train()
function is fromengine.py
module.Also I think it will be useful to add some asserts into this test. Something like
assert 'leaf_coeff' in dumped_model
for linear model andassert 'leaf_coeff' not in dumped_model
for ordinary one.