[python-package] support customizing Dataset creation in Booster.refit() (fixes #3038) (#4894)

* feat: refit additional kwargs for dataset and predict

* test: kwargs for refit method

* fix: __init__ got multiple values for argument

* fix: pycodestyle E302 error

* refactor: dataset_params to avoid breaking change

* refactor: expose all Dataset params in refit

* feat: dataset_params updates new_params

* fix: remove unnecessary params to test

* test: parameters input are the same

* docs: address StrikeRUS changes

* test: refit test changes in train dataset

* test: set init_score and decay_rate to zero
TremaMiguel authored Jan 22, 2022
1 parent f85dfa2 commit e6a2f71
Showing 2 changed files with 93 additions and 2 deletions.
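A minimal usage sketch of the new keyword arguments introduced by this commit; the data, parameter values, and round count below are illustrative, not part of the change itself:

    import numpy as np
    import lightgbm as lgb
    from sklearn.datasets import load_breast_cancer

    # Train a small binary classifier, then refit it on the same features,
    # passing per-row weights and extra construction params for the inner Dataset.
    X, y = load_breast_cancer(return_X_y=True)
    booster = lgb.train({'objective': 'binary', 'verbose': -1},
                        lgb.Dataset(X, y), num_boost_round=10)

    refitted = booster.refit(
        data=X,
        label=y,
        weight=np.random.rand(y.shape[0]),  # forwarded to the inner Dataset
        dataset_params={'max_bin': 260},    # merged into the Dataset's params
        decay_rate=0.5,                     # blend of old and new leaf outputs
    )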
61 changes: 59 additions & 2 deletions python-package/lightgbm/basic.py
@@ -3511,7 +3511,21 @@ def predict(self, data, start_iteration=0, num_iteration=None,
                                  raw_score, pred_leaf, pred_contrib,
                                  data_has_header, is_reshape)
 
-    def refit(self, data, label, decay_rate=0.9, **kwargs):
+    def refit(
+        self,
+        data,
+        label,
+        decay_rate=0.9,
+        reference=None,
+        weight=None,
+        group=None,
+        init_score=None,
+        feature_name='auto',
+        categorical_feature='auto',
+        dataset_params=None,
+        free_raw_data=True,
+        **kwargs
+    ):
         """Refit the existing Booster by new data.
 
         Parameters
@@ -3524,6 +3538,35 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
         decay_rate : float, optional (default=0.9)
             Decay rate of refit,
             will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
+        reference : Dataset or None, optional (default=None)
+            Reference for ``data``.
+        weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
+            Weight for each ``data`` instance. Weight should be non-negative values because the Hessian
+            value multiplied by weight is supposed to be non-negative.
+        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+            Group/query size for ``data``.
+            Only used in the learning-to-rank task.
+            sum(group) = n_samples.
+            For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
+            where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
+        init_score : list, list of lists (for multi-class task), numpy array, pandas Series, pandas DataFrame (for multi-class task), or None, optional (default=None)
+            Init score for ``data``.
+        feature_name : list of str, or 'auto', optional (default="auto")
+            Feature names for ``data``.
+            If 'auto' and data is pandas DataFrame, data columns names are used.
+        categorical_feature : list of str or int, or 'auto', optional (default="auto")
+            Categorical features for ``data``.
+            If list of int, interpreted as indices.
+            If list of str, interpreted as feature names (need to specify ``feature_name`` as well).
+            If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
+            All values in categorical features should be less than int32 max value (2147483647).
+            Large values could be memory consuming. Consider using consecutive integers starting from zero.
+            All negative values in categorical features will be treated as missing values.
+            The output cannot be monotonically constrained with respect to a categorical feature.
+        dataset_params : dict or None, optional (default=None)
+            Other parameters for Dataset ``data``.
+        free_raw_data : bool, optional (default=True)
+            If True, raw data is freed after constructing inner Dataset for ``data``.
         **kwargs
             Other parameters for refit.
             These parameters will be passed to ``predict`` method.
@@ -3535,6 +3578,8 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
         """
         if self.__set_objective_to_none:
             raise LightGBMError('Cannot refit due to null objective function.')
+        if dataset_params is None:
+            dataset_params = {}
         predictor = self._to_predictor(deepcopy(kwargs))
         leaf_preds = predictor.predict(data, -1, pred_leaf=True)
         nrow, ncol = leaf_preds.shape
@@ -3548,7 +3593,19 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
             default_value=None
         )
         new_params["linear_tree"] = bool(out_is_linear.value)
-        train_set = Dataset(data, label, params=new_params)
+        new_params.update(dataset_params)
+        train_set = Dataset(
+            data=data,
+            label=label,
+            reference=reference,
+            weight=weight,
+            group=group,
+            init_score=init_score,
+            feature_name=feature_name,
+            categorical_feature=categorical_feature,
+            params=new_params,
+            free_raw_data=free_raw_data,
+        )
         new_params['refit_decay_rate'] = decay_rate
         new_booster = Booster(new_params, train_set)
         # Copy models
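Since the new code applies ``new_params.update(dataset_params)`` before constructing the Dataset, keys in ``dataset_params`` take precedence over the parameters inherited from the Booster. A standalone sketch of that dict-merge precedence (values illustrative):

    new_params = {'max_bin': 255, 'linear_tree': False}  # derived from the Booster
    dataset_params = {'max_bin': 260}                    # user-supplied override

    new_params.update(dataset_params)
    assert new_params == {'max_bin': 260, 'linear_tree': False}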
34 changes: 34 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -1722,6 +1722,40 @@ def test_refit():
     assert err_pred > new_err_pred
 
 
+def test_refit_dataset_params():
+    # check refit accepts dataset_params
+    X, y = load_breast_cancer(return_X_y=True)
+    lgb_train = lgb.Dataset(X, y, init_score=np.zeros(y.size))
+    train_params = {
+        'objective': 'binary',
+        'verbose': -1,
+        'seed': 123
+    }
+    gbm = lgb.train(train_params, lgb_train, num_boost_round=10)
+    non_weight_err_pred = log_loss(y, gbm.predict(X))
+    refit_weight = np.random.rand(y.shape[0])
+    dataset_params = {
+        'max_bin': 260,
+        'min_data_in_bin': 5,
+        'data_random_seed': 123,
+    }
+    new_gbm = gbm.refit(
+        data=X,
+        label=y,
+        weight=refit_weight,
+        dataset_params=dataset_params,
+        decay_rate=0.0,
+    )
+    weight_err_pred = log_loss(y, new_gbm.predict(X))
+    train_set_params = new_gbm.train_set.get_params()
+    stored_weights = new_gbm.train_set.get_weight()
+    assert weight_err_pred != non_weight_err_pred
+    assert train_set_params["max_bin"] == 260
+    assert train_set_params["min_data_in_bin"] == 5
+    assert train_set_params["data_random_seed"] == 123
+    np.testing.assert_allclose(stored_weights, refit_weight)
+
+
 def test_mape_rf():
     X, y = load_boston(return_X_y=True)
     params = {
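The test uses ``decay_rate=0.0``, so the refitted leaf values come entirely from the new, weighted data. A standalone numpy sketch of the blending formula quoted in the docstring, with illustrative leaf values:

    import numpy as np

    decay_rate = 0.9
    old_leaf_output = np.array([0.20, -0.35, 0.10])  # leaves in the existing model
    new_leaf_output = np.array([0.05, -0.10, 0.40])  # leaves fit on the new data

    # leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output
    leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output
    print(leaf_output)  # [ 0.185 -0.325  0.13 ]; with decay_rate=0.0 it equals new_leaf_output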
