From 8e31870b0fe4125751a89baaa43a222f657f6643 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Fri, 27 Aug 2021 20:50:46 -0500 Subject: [PATCH 01/13] initial changes --- python-package/lightgbm/callback.py | 32 +++++++++-- python-package/lightgbm/engine.py | 5 +- tests/python_package_test/test_engine.py | 70 +++++++++++++++++++++++- 3 files changed, 99 insertions(+), 8 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 7189954ffc20..d17fad6947b9 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -1,12 +1,20 @@ # coding: utf-8 """Callbacks library.""" import collections -from operator import gt, lt +from functools import partial from typing import Any, Callable, Dict, List, Union from .basic import _ConfigAliases, _log_info, _log_warning +def _gt_threshold(curr_score, best_score, threshold): + return curr_score > best_score + threshold + + +def _lt_threshold(curr_score, best_score, threshold): + return curr_score < best_score - threshold + + class EarlyStopException(Exception): """Exception of early stopping.""" @@ -143,7 +151,7 @@ def _callback(env: CallbackEnv) -> None: return _callback -def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable: +def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, threshold: Union[float, List[float]] = 0.0) -> Callable: """Create a callback that activates early stopping. Activates early stopping. @@ -162,6 +170,8 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos Whether to use only the first metric for early stopping. verbose : bool, optional (default=True) Whether to print message with early stopping information. + threshold: float or list of float (default=0.0) + Minimum improvement in score to keep training. Returns ------- @@ -188,17 +198,27 @@ def _init(env: CallbackEnv) -> None: if verbose: _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds") + n_metrics = len(env.evaluation_result_list) + if isinstance(threshold, float): + if n_metrics > 1: + _log_warning(f'Using {threshold} as the early stopping threshold for all metrics.') + tholds = [threshold] * n_metrics + else: + if len(threshold) != n_metrics: + raise ValueError('Must provide a single early stopping threshold or as many as metrics.') + tholds = threshold + # split is needed for " " case (e.g. 
"train l1") first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1] - for eval_ret in env.evaluation_result_list: + for i, eval_ret in enumerate(env.evaluation_result_list): best_iter.append(0) best_score_list.append(None) - if eval_ret[3]: + if eval_ret[3]: # greater is better best_score.append(float('-inf')) - cmp_op.append(gt) + cmp_op.append(partial(_gt_threshold, threshold=tholds[i])) else: best_score.append(float('inf')) - cmp_op.append(lt) + cmp_op.append(partial(_lt_threshold, threshold=tholds[i])) def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None: if env.iteration == env.end_iteration - 1: diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 0d278b21fc49..6c63c274a67a 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -34,6 +34,7 @@ def train( feature_name: Union[List[str], str] = 'auto', categorical_feature: Union[List[str], List[int], str] = 'auto', early_stopping_rounds: Optional[int] = None, + early_stopping_threshold: Union[float, List[float]] = 0.0, evals_result: Optional[Dict[str, Any]] = None, verbose_eval: Union[bool, int] = True, learning_rates: Optional[Union[List[float], Callable[[int], float]]] = None, @@ -121,6 +122,8 @@ def train( To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``. The index of iteration that has the best performance will be saved in the ``best_iteration`` field if early stopping logic is enabled by setting ``early_stopping_rounds``. + early_stopping_threshold : float or list of float (default=0.0) + Minimum improvement in score to keep training. evals_result: dict or None, optional (default=None) Dictionary used to store all evaluation results of all the items in ``valid_sets``. This should be initialized outside of your call to ``train()`` and should be empty. 
@@ -239,7 +242,7 @@ def train( callbacks.add(callback.print_evaluation(verbose_eval)) if early_stopping_rounds is not None and early_stopping_rounds > 0: - callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval))) + callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval), threshold=early_stopping_threshold)) if learning_rates is not None: callbacks.add(callback.reset_parameter(learning_rate=learning_rates)) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index b44cee469a22..5832b57457bd 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -11,7 +11,7 @@ import psutil import pytest from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr -from sklearn.datasets import load_svmlight_file, make_multilabel_classification +from sklearn.datasets import load_svmlight_file, make_classification, make_multilabel_classification from sklearn.metrics import average_precision_score, log_loss, mean_absolute_error, mean_squared_error, roc_auc_score from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split @@ -642,6 +642,74 @@ def test_early_stopping(): assert 'binary_logloss' in gbm.best_score[valid_set_name] +@pytest.mark.parametrize('first_only', [True, False]) +@pytest.mark.parametrize('single_metric', [True, False]) +@pytest.mark.parametrize('greater_is_better', [True, False]) +def test_early_stopping_threshold(single_metric, first_only, greater_is_better): + metric2threshold = { + 'auc': 0.001, + 'binary_logloss': 0.01, + 'average_precision': 0.001, + 'l2': 0.001, + } + if single_metric: + if greater_is_better: + metric = ['auc'] + else: + metric = ['binary_logloss'] + else: + if first_only: + if greater_is_better: + metric = ['auc', 'binary_logloss'] + else: + metric = ['binary_logloss', 'auc'] + else: + if greater_is_better: + metric = ['auc', 'average_precision'] + else: + metric = ['binary_logloss', 'l2'] + + X, y = make_classification(n_samples=1_000, n_features=2, n_redundant=0, n_classes=2, random_state=0) + X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0) + train_ds = lgb.Dataset(X_train, y_train) + valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds) + + params = {'objective': 'binary', 'metric': metric, 'first_metric_only': first_only, 'verbose': -1} + threshold = [metric2threshold[m] for m in metric] + train_kwargs = dict( + params=params, + train_set=train_ds, + num_boost_round=100, + valid_sets=[valid_ds], + early_stopping_rounds=10, + verbose_eval=0, + ) + + # regular early stopping + evals_result = {} + bst = lgb.train(evals_result=evals_result, **train_kwargs) + scores = np.vstack([res for res in evals_result['valid_0'].values()]).T + + # positive threshold + threshold_result = {} + threshold_bst = lgb.train(early_stopping_threshold=threshold, evals_result=threshold_result, **train_kwargs) + threshold_scores = np.vstack([res for res in threshold_result['valid_0'].values()]).T + + if first_only: + threshold = threshold[0] + scores = scores[:, 0] + threshold_scores = threshold_scores[:, 0] + + assert threshold_bst.num_trees() < bst.num_trees() + np.testing.assert_equal(scores[:len(threshold_scores)], threshold_scores) + last_score = threshold_scores[-1] + best_score = threshold_scores[threshold_bst.num_trees() - 1] + if greater_is_better: + assert np.less_equal(last_score, best_score + threshold).any() + 
else: + assert np.greater_equal(last_score, best_score - threshold).any() + + def test_continue_train(): X, y = load_boston(return_X_y=True) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) From dc0e62143086a2f2c9b454d29134e970446ffb17 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Tue, 31 Aug 2021 20:16:36 -0500 Subject: [PATCH 02/13] initial version --- python-package/lightgbm/callback.py | 12 ++++++---- python-package/lightgbm/engine.py | 8 ++++--- python-package/lightgbm/sklearn.py | 30 ++++++++++++++---------- tests/python_package_test/test_engine.py | 4 +++- 4 files changed, 33 insertions(+), 21 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index d17fad6947b9..68c574404775 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -199,14 +199,16 @@ def _init(env: CallbackEnv) -> None: _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds") n_metrics = len(env.evaluation_result_list) - if isinstance(threshold, float): - if n_metrics > 1: - _log_warning(f'Using {threshold} as the early stopping threshold for all metrics.') - tholds = [threshold] * n_metrics - else: + if isinstance(threshold, list): if len(threshold) != n_metrics: raise ValueError('Must provide a single early stopping threshold or as many as metrics.') tholds = threshold + else: + if n_metrics > 1 and threshold > 0: + _log_warning(f'Using {threshold} as the early stopping threshold for all metrics.') + tholds = [threshold] * n_metrics + if not all(t >= 0 for t in tholds): + raise ValueError('Early stopping threshold must be non-negative.') # split is needed for " " case (e.g. "train l1") first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1] diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 6c63c274a67a..52bb72de9e75 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -424,8 +424,8 @@ def cv(params, train_set, num_boost_round=100, folds=None, nfold=5, stratified=True, shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, feature_name='auto', categorical_feature='auto', - early_stopping_rounds=None, fpreproc=None, - verbose_eval=None, show_stdv=True, seed=0, + early_stopping_rounds=None, early_stopping_threshold=0.0, + fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, callbacks=None, eval_train_metric=False, return_cvbooster=False): """Perform the cross-validation with given parameters. @@ -518,6 +518,8 @@ def cv(params, train_set, num_boost_round=100, Requires at least one metric. If there's more than one, will check all of them. To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``. Last entry in evaluation history is the one from the best iteration. + early_stopping_threshold : float or list of float (default=0.0) + Minimum improvement in score to keep training. fpreproc : callable or None, optional (default=None) Preprocessing function that takes (dtrain, dtest, params) and returns transformed versions of those. 
@@ -603,7 +605,7 @@ def cv(params, train_set, num_boost_round=100, cb.__dict__.setdefault('order', i - len(callbacks)) callbacks = set(callbacks) if early_stopping_rounds is not None and early_stopping_rounds > 0: - callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False)) + callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False, threshold=early_stopping_threshold)) if verbose_eval is True: callbacks.add(callback.print_evaluation(show_stdv=show_stdv)) elif isinstance(verbose_eval, int): diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 565ed8c10c9d..3a8e7e56c9bc 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -230,6 +230,8 @@ def __call__(self, preds, dataset): If there's more than one, will check all of them. But the training data is ignored anyway. To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in additional parameters ``**kwargs`` of the model constructor. + early_stopping_threshold : float or list of float (default=0.0) + Minimum improvement in score to keep training. verbose : bool or int, optional (default=True) Requires at least one evaluation data. If True, the eval metric on the eval set is printed at each boosting stage. @@ -570,8 +572,8 @@ def fit(self, X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_names=None, eval_sample_weight=None, eval_class_weight=None, eval_init_score=None, eval_group=None, - eval_metric=None, early_stopping_rounds=None, verbose=True, - feature_name='auto', categorical_feature='auto', + eval_metric=None, early_stopping_rounds=None, early_stopping_threshold=0.0, + verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None, init_model=None): """Docstring is set after definition, using a template.""" if self._objective is None: @@ -711,7 +713,7 @@ def _get_meta_data(collection, name, i): self._Booster = train(params, train_set, self.n_estimators, valid_sets=valid_sets, valid_names=eval_names, - early_stopping_rounds=early_stopping_rounds, + early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold, evals_result=evals_result, fobj=self._fobj, feval=eval_metrics_callable, verbose_eval=verbose, feature_name=feature_name, callbacks=callbacks, init_model=init_model) @@ -843,13 +845,14 @@ def fit(self, X, y, sample_weight=None, init_score=None, eval_set=None, eval_names=None, eval_sample_weight=None, eval_init_score=None, eval_metric=None, early_stopping_rounds=None, - verbose=True, feature_name='auto', categorical_feature='auto', - callbacks=None, init_model=None): + early_stopping_threshold=0.0, verbose=True, feature_name='auto', + categorical_feature='auto', callbacks=None, init_model=None): """Docstring is inherited from the LGBMModel.""" super().fit(X, y, sample_weight=sample_weight, init_score=init_score, eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight, eval_init_score=eval_init_score, eval_metric=eval_metric, - early_stopping_rounds=early_stopping_rounds, verbose=verbose, feature_name=feature_name, + early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold, + verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model) return self @@ -869,8 +872,8 @@ def fit(self, X, y, sample_weight=None, init_score=None, eval_set=None, eval_names=None, 
eval_sample_weight=None, eval_class_weight=None, eval_init_score=None, eval_metric=None, - early_stopping_rounds=None, verbose=True, - feature_name='auto', categorical_feature='auto', + early_stopping_rounds=None, early_stopping_threshold=0.0, + verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None, init_model=None): """Docstring is inherited from the LGBMModel.""" _LGBMAssertAllFinite(y) @@ -922,7 +925,8 @@ def fit(self, X, y, eval_names=eval_names, eval_sample_weight=eval_sample_weight, eval_class_weight=eval_class_weight, eval_init_score=eval_init_score, eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds, - verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature, + early_stopping_threshold=early_stopping_threshold, verbose=verbose, + feature_name=feature_name, categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model) return self @@ -997,7 +1001,8 @@ def fit(self, X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_names=None, eval_sample_weight=None, eval_init_score=None, eval_group=None, eval_metric=None, - eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose=True, + eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, + early_stopping_threshold=0.0, verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None, init_model=None): """Docstring is inherited from the LGBMModel.""" @@ -1021,8 +1026,9 @@ def fit(self, X, y, super().fit(X, y, sample_weight=sample_weight, init_score=init_score, group=group, eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight, eval_init_score=eval_init_score, eval_group=eval_group, eval_metric=eval_metric, - early_stopping_rounds=early_stopping_rounds, verbose=verbose, feature_name=feature_name, - categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model) + early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold, + verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature, + callbacks=callbacks, init_model=init_model) return self _base_doc = LGBMModel.fit.__doc__ diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 5832b57457bd..bec477c08dbc 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -646,6 +646,8 @@ def test_early_stopping(): @pytest.mark.parametrize('single_metric', [True, False]) @pytest.mark.parametrize('greater_is_better', [True, False]) def test_early_stopping_threshold(single_metric, first_only, greater_is_better): + if single_metric and not first_only: + pytest.skip("first_metric_only doesn't affect single metric.") metric2threshold = { 'auc': 0.001, 'binary_logloss': 0.01, @@ -701,7 +703,7 @@ def test_early_stopping_threshold(single_metric, first_only, greater_is_better): threshold_scores = threshold_scores[:, 0] assert threshold_bst.num_trees() < bst.num_trees() - np.testing.assert_equal(scores[:len(threshold_scores)], threshold_scores) + np.testing.assert_allclose(scores[:len(threshold_scores)], threshold_scores) last_score = threshold_scores[-1] best_score = threshold_scores[threshold_bst.num_trees() - 1] if greater_is_better: From 6ff7ad3010f03476158f9889b758c62c69a0dac2 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Tue, 31 Aug 2021 21:00:24 -0500 Subject: [PATCH 03/13] better handling of cases --- python-package/lightgbm/callback.py | 22 +++++++++++++-------- 
tests/python_package_test/test_engine.py | 25 ++++++++++++++---------- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 68c574404775..14a4f95fb7a8 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -198,17 +198,23 @@ def _init(env: CallbackEnv) -> None: if verbose: _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds") - n_metrics = len(env.evaluation_result_list) + n_metrics = len(set(m[1] for m in env.evaluation_result_list)) + n_datasets = len(env.evaluation_result_list) // n_metrics if isinstance(threshold, list): - if len(threshold) != n_metrics: - raise ValueError('Must provide a single early stopping threshold or as many as metrics.') - tholds = threshold + if not all(t >= 0 for t in threshold): + raise ValueError('Early stopping thresholds must be non-negative.') + if len(threshold) > 1: + if len(threshold) != n_metrics: + raise ValueError('Must provide a single early stopping threshold or as many as metrics.') + if first_metric_only: + _log_warning(f'Using only {threshold[0]} as early stopping threshold.') + tholds = threshold * n_datasets else: - if n_metrics > 1 and threshold > 0: + if threshold < 0: + raise ValueError('Early stopping threshold must be non-negative.') + if n_metrics > 1 and not first_metric_only: _log_warning(f'Using {threshold} as the early stopping threshold for all metrics.') - tholds = [threshold] * n_metrics - if not all(t >= 0 for t in tholds): - raise ValueError('Early stopping threshold must be non-negative.') + tholds = [threshold] * len(env.evaluation_result_list) # split is needed for " " case (e.g. "train l1") first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1] diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index bec477c08dbc..3d35232f9734 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -645,20 +645,20 @@ def test_early_stopping(): @pytest.mark.parametrize('first_only', [True, False]) @pytest.mark.parametrize('single_metric', [True, False]) @pytest.mark.parametrize('greater_is_better', [True, False]) -def test_early_stopping_threshold(single_metric, first_only, greater_is_better): +def test_early_stopping_threshold(first_only, single_metric, greater_is_better): if single_metric and not first_only: pytest.skip("first_metric_only doesn't affect single metric.") metric2threshold = { 'auc': 0.001, 'binary_logloss': 0.01, 'average_precision': 0.001, - 'l2': 0.001, + 'mape': 0.001, } if single_metric: if greater_is_better: - metric = ['auc'] + metric = 'auc' else: - metric = ['binary_logloss'] + metric = 'binary_logloss' else: if first_only: if greater_is_better: @@ -669,7 +669,7 @@ def test_early_stopping_threshold(single_metric, first_only, greater_is_better): if greater_is_better: metric = ['auc', 'average_precision'] else: - metric = ['binary_logloss', 'l2'] + metric = ['binary_logloss', 'mape'] X, y = make_classification(n_samples=1_000, n_features=2, n_redundant=0, n_classes=2, random_state=0) X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0) @@ -677,12 +677,18 @@ def test_early_stopping_threshold(single_metric, first_only, greater_is_better): valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds) params = {'objective': 'binary', 'metric': metric, 'first_metric_only': first_only, 'verbose': -1} - threshold = 
[metric2threshold[m] for m in metric] + if isinstance(metric, str): + threshold = metric2threshold[metric] + elif first_only: + threshold = metric2threshold[metric[0]] + else: + threshold = [metric2threshold[m] for m in metric] train_kwargs = dict( params=params, train_set=train_ds, num_boost_round=100, - valid_sets=[valid_ds], + valid_sets=[train_ds, valid_ds], + valid_names=['training', 'valid'], early_stopping_rounds=10, verbose_eval=0, ) @@ -690,15 +696,14 @@ def test_early_stopping_threshold(single_metric, first_only, greater_is_better): # regular early stopping evals_result = {} bst = lgb.train(evals_result=evals_result, **train_kwargs) - scores = np.vstack([res for res in evals_result['valid_0'].values()]).T + scores = np.vstack([res for res in evals_result['valid'].values()]).T # positive threshold threshold_result = {} threshold_bst = lgb.train(early_stopping_threshold=threshold, evals_result=threshold_result, **train_kwargs) - threshold_scores = np.vstack([res for res in threshold_result['valid_0'].values()]).T + threshold_scores = np.vstack([res for res in threshold_result['valid'].values()]).T if first_only: - threshold = threshold[0] scores = scores[:, 0] threshold_scores = threshold_scores[:, 0] From 4b3aa351e568a140f6456c15ae3e6b50d57cdb3e Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Tue, 31 Aug 2021 21:32:25 -0500 Subject: [PATCH 04/13] warn only with positive threshold --- python-package/lightgbm/callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 14a4f95fb7a8..1221c007450d 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -212,7 +212,7 @@ def _init(env: CallbackEnv) -> None: else: if threshold < 0: raise ValueError('Early stopping threshold must be non-negative.') - if n_metrics > 1 and not first_metric_only: + if threshold > 0 and n_metrics > 1 and not first_metric_only: _log_warning(f'Using {threshold} as the early stopping threshold for all metrics.') tholds = [threshold] * len(env.evaluation_result_list) From 830961b43c60084d95da5c8a52b5aef05d784fa3 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Sat, 4 Sep 2021 21:42:00 -0500 Subject: [PATCH 05/13] remove early_stopping_threshold from high-level functions --- python-package/lightgbm/callback.py | 4 ++-- python-package/lightgbm/engine.py | 13 ++++--------- python-package/lightgbm/sklearn.py | 24 ++++++++++-------------- 3 files changed, 16 insertions(+), 25 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 1221c007450d..fc03d0c09cfc 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -7,11 +7,11 @@ from .basic import _ConfigAliases, _log_info, _log_warning -def _gt_threshold(curr_score, best_score, threshold): +def _gt_threshold(curr_score: float, best_score: float, threshold: float) -> bool: return curr_score > best_score + threshold -def _lt_threshold(curr_score, best_score, threshold): +def _lt_threshold(curr_score: float, best_score: float, threshold: float) -> bool: return curr_score < best_score - threshold diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 52bb72de9e75..0d278b21fc49 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -34,7 +34,6 @@ def train( feature_name: Union[List[str], str] = 'auto', categorical_feature: Union[List[str], List[int], str] = 'auto', 
early_stopping_rounds: Optional[int] = None, - early_stopping_threshold: Union[float, List[float]] = 0.0, evals_result: Optional[Dict[str, Any]] = None, verbose_eval: Union[bool, int] = True, learning_rates: Optional[Union[List[float], Callable[[int], float]]] = None, @@ -122,8 +121,6 @@ def train( To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``. The index of iteration that has the best performance will be saved in the ``best_iteration`` field if early stopping logic is enabled by setting ``early_stopping_rounds``. - early_stopping_threshold : float or list of float (default=0.0) - Minimum improvement in score to keep training. evals_result: dict or None, optional (default=None) Dictionary used to store all evaluation results of all the items in ``valid_sets``. This should be initialized outside of your call to ``train()`` and should be empty. @@ -242,7 +239,7 @@ def train( callbacks.add(callback.print_evaluation(verbose_eval)) if early_stopping_rounds is not None and early_stopping_rounds > 0: - callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval), threshold=early_stopping_threshold)) + callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval))) if learning_rates is not None: callbacks.add(callback.reset_parameter(learning_rate=learning_rates)) @@ -424,8 +421,8 @@ def cv(params, train_set, num_boost_round=100, folds=None, nfold=5, stratified=True, shuffle=True, metrics=None, fobj=None, feval=None, init_model=None, feature_name='auto', categorical_feature='auto', - early_stopping_rounds=None, early_stopping_threshold=0.0, - fpreproc=None, verbose_eval=None, show_stdv=True, seed=0, + early_stopping_rounds=None, fpreproc=None, + verbose_eval=None, show_stdv=True, seed=0, callbacks=None, eval_train_metric=False, return_cvbooster=False): """Perform the cross-validation with given parameters. @@ -518,8 +515,6 @@ def cv(params, train_set, num_boost_round=100, Requires at least one metric. If there's more than one, will check all of them. To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``. Last entry in evaluation history is the one from the best iteration. - early_stopping_threshold : float or list of float (default=0.0) - Minimum improvement in score to keep training. fpreproc : callable or None, optional (default=None) Preprocessing function that takes (dtrain, dtest, params) and returns transformed versions of those. 
@@ -605,7 +600,7 @@ def cv(params, train_set, num_boost_round=100, cb.__dict__.setdefault('order', i - len(callbacks)) callbacks = set(callbacks) if early_stopping_rounds is not None and early_stopping_rounds > 0: - callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False, threshold=early_stopping_threshold)) + callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False)) if verbose_eval is True: callbacks.add(callback.print_evaluation(show_stdv=show_stdv)) elif isinstance(verbose_eval, int): diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 3a8e7e56c9bc..e47f4d000659 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -713,7 +713,7 @@ def _get_meta_data(collection, name, i): self._Booster = train(params, train_set, self.n_estimators, valid_sets=valid_sets, valid_names=eval_names, - early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold, + early_stopping_rounds=early_stopping_rounds, evals_result=evals_result, fobj=self._fobj, feval=eval_metrics_callable, verbose_eval=verbose, feature_name=feature_name, callbacks=callbacks, init_model=init_model) @@ -845,14 +845,13 @@ def fit(self, X, y, sample_weight=None, init_score=None, eval_set=None, eval_names=None, eval_sample_weight=None, eval_init_score=None, eval_metric=None, early_stopping_rounds=None, - early_stopping_threshold=0.0, verbose=True, feature_name='auto', - categorical_feature='auto', callbacks=None, init_model=None): + verbose=True, feature_name='auto', categorical_feature='auto', + callbacks=None, init_model=None): """Docstring is inherited from the LGBMModel.""" super().fit(X, y, sample_weight=sample_weight, init_score=init_score, eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight, eval_init_score=eval_init_score, eval_metric=eval_metric, - early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold, - verbose=verbose, feature_name=feature_name, + early_stopping_rounds=early_stopping_rounds, verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model) return self @@ -872,8 +871,8 @@ def fit(self, X, y, sample_weight=None, init_score=None, eval_set=None, eval_names=None, eval_sample_weight=None, eval_class_weight=None, eval_init_score=None, eval_metric=None, - early_stopping_rounds=None, early_stopping_threshold=0.0, - verbose=True, feature_name='auto', categorical_feature='auto', + early_stopping_rounds=None, verbose=True, + feature_name='auto', categorical_feature='auto', callbacks=None, init_model=None): """Docstring is inherited from the LGBMModel.""" _LGBMAssertAllFinite(y) @@ -924,8 +923,7 @@ def fit(self, X, y, super().fit(X, _y, sample_weight=sample_weight, init_score=init_score, eval_set=valid_sets, eval_names=eval_names, eval_sample_weight=eval_sample_weight, eval_class_weight=eval_class_weight, eval_init_score=eval_init_score, - eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds, - early_stopping_threshold=early_stopping_threshold, verbose=verbose, + eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds, verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model) return self @@ -1001,8 +999,7 @@ def fit(self, X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_names=None, 
eval_sample_weight=None, eval_init_score=None, eval_group=None, eval_metric=None, - eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, - early_stopping_threshold=0.0, verbose=True, + eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None, init_model=None): """Docstring is inherited from the LGBMModel.""" @@ -1026,9 +1023,8 @@ def fit(self, X, y, super().fit(X, y, sample_weight=sample_weight, init_score=init_score, group=group, eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight, eval_init_score=eval_init_score, eval_group=eval_group, eval_metric=eval_metric, - early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold, - verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature, - callbacks=callbacks, init_model=init_model) + early_stopping_rounds=early_stopping_rounds, verbose=verbose, feature_name=feature_name, + categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model) return self _base_doc = LGBMModel.fit.__doc__ From 71d379f15bf3289dac9fdb08968f0e67649bdd11 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Sat, 4 Sep 2021 21:47:50 -0500 Subject: [PATCH 06/13] remove remaining early_stopping_threshold --- python-package/lightgbm/sklearn.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 5e68d1033691..c59582464f88 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -230,8 +230,6 @@ def __call__(self, preds, dataset): If there's more than one, will check all of them. But the training data is ignored anyway. To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in additional parameters ``**kwargs`` of the model constructor. - early_stopping_threshold : float or list of float (default=0.0) - Minimum improvement in score to keep training. verbose : bool or int, optional (default=True) Requires at least one evaluation data. If True, the eval metric on the eval set is printed at each boosting stage. 
@@ -572,8 +570,8 @@ def fit(self, X, y, sample_weight=None, init_score=None, group=None, eval_set=None, eval_names=None, eval_sample_weight=None, eval_class_weight=None, eval_init_score=None, eval_group=None, - eval_metric=None, early_stopping_rounds=None, early_stopping_threshold=0.0, - verbose=True, feature_name='auto', categorical_feature='auto', + eval_metric=None, early_stopping_rounds=None, verbose=True, + feature_name='auto', categorical_feature='auto', callbacks=None, init_model=None): """Docstring is set after definition, using a template.""" if self._objective is None: @@ -929,8 +927,8 @@ def fit(self, X, y, super().fit(X, _y, sample_weight=sample_weight, init_score=init_score, eval_set=valid_sets, eval_names=eval_names, eval_sample_weight=eval_sample_weight, eval_class_weight=eval_class_weight, eval_init_score=eval_init_score, - eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds, verbose=verbose, - feature_name=feature_name, categorical_feature=categorical_feature, + eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds, + verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model) return self From 8db30b66464e8793b766d690315dca23d8b3119a Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Sat, 4 Sep 2021 22:06:24 -0500 Subject: [PATCH 07/13] update test to use callback --- tests/python_package_test/test_engine.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 3d35232f9734..5a66afbee8de 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -11,7 +11,7 @@ import psutil import pytest from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr -from sklearn.datasets import load_svmlight_file, make_classification, make_multilabel_classification +from sklearn.datasets import load_svmlight_file, make_multilabel_classification from sklearn.metrics import average_precision_score, log_loss, mean_absolute_error, mean_squared_error, roc_auc_score from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split @@ -652,7 +652,7 @@ def test_early_stopping_threshold(first_only, single_metric, greater_is_better): 'auc': 0.001, 'binary_logloss': 0.01, 'average_precision': 0.001, - 'mape': 0.001, + 'mape': 0.01, } if single_metric: if greater_is_better: @@ -671,12 +671,12 @@ def test_early_stopping_threshold(first_only, single_metric, greater_is_better): else: metric = ['binary_logloss', 'mape'] - X, y = make_classification(n_samples=1_000, n_features=2, n_redundant=0, n_classes=2, random_state=0) + X, y = load_breast_cancer(return_X_y=True) X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0) train_ds = lgb.Dataset(X_train, y_train) valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds) - params = {'objective': 'binary', 'metric': metric, 'first_metric_only': first_only, 'verbose': -1} + params = {'objective': 'binary', 'metric': metric, 'verbose': -1} if isinstance(metric, str): threshold = metric2threshold[metric] elif first_only: @@ -689,18 +689,18 @@ def test_early_stopping_threshold(first_only, single_metric, greater_is_better): num_boost_round=100, valid_sets=[train_ds, valid_ds], valid_names=['training', 'valid'], - early_stopping_rounds=10, - verbose_eval=0, ) # regular early stopping + train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, 
first_only, verbose=0)] evals_result = {} bst = lgb.train(evals_result=evals_result, **train_kwargs) scores = np.vstack([res for res in evals_result['valid'].values()]).T # positive threshold + train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0, threshold=threshold)] threshold_result = {} - threshold_bst = lgb.train(early_stopping_threshold=threshold, evals_result=threshold_result, **train_kwargs) + threshold_bst = lgb.train(evals_result=threshold_result, **train_kwargs) threshold_scores = np.vstack([res for res in threshold_result['valid'].values()]).T if first_only: From 9ec9891ce9ca86d5eb84f8d8650d6f2cd3117ea7 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Sat, 4 Sep 2021 22:34:29 -0500 Subject: [PATCH 08/13] better handling of cases --- python-package/lightgbm/callback.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 1bbb4f0b8773..0e0b8ed5ac68 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -205,18 +205,24 @@ def _init(env: CallbackEnv) -> None: if isinstance(threshold, list): if not all(t >= 0 for t in threshold): raise ValueError('Early stopping thresholds must be non-negative.') - if len(threshold) > 1: + if len(threshold) == 0: + _log_warning('Disabling threshold for early stopping.') + tholds = [0.0] * n_datasets * n_metrics + elif len(threshold) == 1: + _log_warning(f'Using {threshold[0]} as threshold for all metrics.') + tholds = threshold * n_datasets * n_metrics + else: if len(threshold) != n_metrics: raise ValueError('Must provide a single early stopping threshold or as many as metrics.') if first_metric_only: _log_warning(f'Using only {threshold[0]} as early stopping threshold.') - tholds = threshold * n_datasets + tholds = threshold * n_datasets else: if threshold < 0: raise ValueError('Early stopping threshold must be non-negative.') if threshold > 0 and n_metrics > 1 and not first_metric_only: - _log_warning(f'Using {threshold} as the early stopping threshold for all metrics.') - tholds = [threshold] * len(env.evaluation_result_list) + _log_warning(f'Using {threshold} as threshold for all metrics.') + tholds = [threshold] * n_datasets * n_metrics # split is needed for " " case (e.g. 
"train l1") first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1] From b6d3432e152a107246e0240142a11d4665227497 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Sat, 25 Sep 2021 10:46:07 -0500 Subject: [PATCH 09/13] rename threshold to min_delta enhance parameter description update tests --- python-package/lightgbm/callback.py | 58 ++++++++++++------------ tests/python_package_test/test_engine.py | 34 +++++++------- 2 files changed, 47 insertions(+), 45 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 0e0b8ed5ac68..458860f1aaff 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -7,12 +7,12 @@ from .basic import _ConfigAliases, _log_info, _log_warning -def _gt_threshold(curr_score: float, best_score: float, threshold: float) -> bool: - return curr_score > best_score + threshold +def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool: + return curr_score > best_score + delta -def _lt_threshold(curr_score: float, best_score: float, threshold: float) -> bool: - return curr_score < best_score - threshold +def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool: + return curr_score < best_score - delta class EarlyStopException(Exception): @@ -153,11 +153,11 @@ def _callback(env: CallbackEnv) -> None: return _callback -def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, threshold: Union[float, List[float]] = 0.0) -> Callable: +def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable: """Create a callback that activates early stopping. Activates early stopping. - The model will train until the validation score stops improving. + The model will train until the validation score doesn't improve by at least ``min_delta``. Validation score needs to improve at least every ``early_stopping_rounds`` round(s) to continue training. Requires at least one validation data and one metric. @@ -172,8 +172,10 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos Whether to use only the first metric for early stopping. verbose : bool, optional (default=True) Whether to print message with early stopping information. - threshold: float or list of float (default=0.0) + min_delta: float or list of float (default=0.0) Minimum improvement in score to keep training. + If float, this single value is used for all metrics. + If list, its length should match the total number of metrics. 
Returns ------- @@ -202,39 +204,39 @@ def _init(env: CallbackEnv) -> None: n_metrics = len(set(m[1] for m in env.evaluation_result_list)) n_datasets = len(env.evaluation_result_list) // n_metrics - if isinstance(threshold, list): - if not all(t >= 0 for t in threshold): - raise ValueError('Early stopping thresholds must be non-negative.') - if len(threshold) == 0: - _log_warning('Disabling threshold for early stopping.') - tholds = [0.0] * n_datasets * n_metrics - elif len(threshold) == 1: - _log_warning(f'Using {threshold[0]} as threshold for all metrics.') - tholds = threshold * n_datasets * n_metrics + if isinstance(min_delta, list): + if not all(t >= 0 for t in min_delta): + raise ValueError('Values for early stopping min_delta must be non-negative.') + if len(min_delta) == 0: + _log_info('Disabling min_delta for early stopping.') + deltas = [0.0] * n_datasets * n_metrics + elif len(min_delta) == 1: + _log_info(f'Using {min_delta[0]} as min_delta for all metrics.') + deltas = min_delta * n_datasets * n_metrics else: - if len(threshold) != n_metrics: - raise ValueError('Must provide a single early stopping threshold or as many as metrics.') + if len(min_delta) != n_metrics: + raise ValueError('Must provide a single value for min_delta or as many as metrics.') if first_metric_only: - _log_warning(f'Using only {threshold[0]} as early stopping threshold.') - tholds = threshold * n_datasets + _log_info(f'Using only {min_delta[0]} as early stopping min_delta.') + deltas = min_delta * n_datasets else: - if threshold < 0: - raise ValueError('Early stopping threshold must be non-negative.') - if threshold > 0 and n_metrics > 1 and not first_metric_only: - _log_warning(f'Using {threshold} as threshold for all metrics.') - tholds = [threshold] * n_datasets * n_metrics + if min_delta < 0: + raise ValueError('Early stopping min_delta must be non-negative.') + if min_delta > 0 and n_metrics > 1 and not first_metric_only: + _log_info(f'Using {min_delta} as min_delta for all metrics.') + deltas = [min_delta] * n_datasets * n_metrics # split is needed for " " case (e.g. 
"train l1") first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1] - for i, eval_ret in enumerate(env.evaluation_result_list): + for eval_ret, delta in zip(env.evaluation_result_list, deltas): best_iter.append(0) best_score_list.append(None) if eval_ret[3]: # greater is better best_score.append(float('-inf')) - cmp_op.append(partial(_gt_threshold, threshold=tholds[i])) + cmp_op.append(partial(_gt_delta, delta=delta)) else: best_score.append(float('inf')) - cmp_op.append(partial(_lt_threshold, threshold=tholds[i])) + cmp_op.append(partial(_lt_delta, delta=delta)) def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None: if env.iteration == env.end_iteration - 1: diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 5a66afbee8de..72ae73e3fecd 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -645,10 +645,10 @@ def test_early_stopping(): @pytest.mark.parametrize('first_only', [True, False]) @pytest.mark.parametrize('single_metric', [True, False]) @pytest.mark.parametrize('greater_is_better', [True, False]) -def test_early_stopping_threshold(first_only, single_metric, greater_is_better): +def test_early_stopping_min_delta(first_only, single_metric, greater_is_better): if single_metric and not first_only: pytest.skip("first_metric_only doesn't affect single metric.") - metric2threshold = { + metric2min_delta = { 'auc': 0.001, 'binary_logloss': 0.01, 'average_precision': 0.001, @@ -678,11 +678,11 @@ def test_early_stopping_threshold(first_only, single_metric, greater_is_better): params = {'objective': 'binary', 'metric': metric, 'verbose': -1} if isinstance(metric, str): - threshold = metric2threshold[metric] + min_delta = metric2min_delta[metric] elif first_only: - threshold = metric2threshold[metric[0]] + min_delta = metric2min_delta[metric[0]] else: - threshold = [metric2threshold[m] for m in metric] + min_delta = [metric2min_delta[m] for m in metric] train_kwargs = dict( params=params, train_set=train_ds, @@ -697,24 +697,24 @@ def test_early_stopping_threshold(first_only, single_metric, greater_is_better): bst = lgb.train(evals_result=evals_result, **train_kwargs) scores = np.vstack([res for res in evals_result['valid'].values()]).T - # positive threshold - train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0, threshold=threshold)] - threshold_result = {} - threshold_bst = lgb.train(evals_result=threshold_result, **train_kwargs) - threshold_scores = np.vstack([res for res in threshold_result['valid'].values()]).T + # positive min_delta + train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0, min_delta=min_delta)] + delta_result = {} + delta_bst = lgb.train(evals_result=delta_result, **train_kwargs) + delta_scores = np.vstack([res for res in delta_result['valid'].values()]).T if first_only: scores = scores[:, 0] - threshold_scores = threshold_scores[:, 0] + delta_scores = delta_scores[:, 0] - assert threshold_bst.num_trees() < bst.num_trees() - np.testing.assert_allclose(scores[:len(threshold_scores)], threshold_scores) - last_score = threshold_scores[-1] - best_score = threshold_scores[threshold_bst.num_trees() - 1] + assert delta_bst.num_trees() < bst.num_trees() + np.testing.assert_allclose(scores[:len(delta_scores)], delta_scores) + last_score = delta_scores[-1] + best_score = delta_scores[delta_bst.num_trees() - 1] if greater_is_better: - assert np.less_equal(last_score, best_score + 
threshold).any() + assert np.less_equal(last_score, best_score + min_delta).any() else: - assert np.greater_equal(last_score, best_score - threshold).any() + assert np.greater_equal(last_score, best_score - min_delta).any() def test_continue_train(): From 8b6e013d0de93473b6855e698c7117c1e3b15b8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Thu, 14 Oct 2021 10:47:05 -0500 Subject: [PATCH 10/13] Apply suggestions from code review Co-authored-by: Nikita Titov --- python-package/lightgbm/callback.py | 14 +++++++++----- tests/python_package_test/test_engine.py | 4 ++-- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 07a5398aed04..6eaee24b2d7e 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -211,7 +211,7 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos Whether to log message with early stopping information. By default, standard output resource is used. Use ``register_logger()`` function to register a custom logger. - min_delta: float or list of float (default=0.0) + min_delta : float or list of float, optional (default=0.0) Minimum improvement in score to keep training. If float, this single value is used for all metrics. If list, its length should match the total number of metrics. @@ -247,22 +247,26 @@ def _init(env: CallbackEnv) -> None: if not all(t >= 0 for t in min_delta): raise ValueError('Values for early stopping min_delta must be non-negative.') if len(min_delta) == 0: - _log_info('Disabling min_delta for early stopping.') + if verbose: + _log_info('Disabling min_delta for early stopping.') deltas = [0.0] * n_datasets * n_metrics elif len(min_delta) == 1: - _log_info(f'Using {min_delta[0]} as min_delta for all metrics.') + if verbose: + _log_info(f'Using {min_delta[0]} as min_delta for all metrics.') deltas = min_delta * n_datasets * n_metrics else: if len(min_delta) != n_metrics: raise ValueError('Must provide a single value for min_delta or as many as metrics.') if first_metric_only: - _log_info(f'Using only {min_delta[0]} as early stopping min_delta.') + if verbose: + _log_info(f'Using only {min_delta[0]} as early stopping min_delta.') deltas = min_delta * n_datasets else: if min_delta < 0: raise ValueError('Early stopping min_delta must be non-negative.') if min_delta > 0 and n_metrics > 1 and not first_metric_only: - _log_info(f'Using {min_delta} as min_delta for all metrics.') + if verbose: + _log_info(f'Using {min_delta} as min_delta for all metrics.') deltas = [min_delta] * n_datasets * n_metrics # split is needed for " " case (e.g. 
"train l1") diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 72ae73e3fecd..f16320f81a5d 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -695,13 +695,13 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better): train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0)] evals_result = {} bst = lgb.train(evals_result=evals_result, **train_kwargs) - scores = np.vstack([res for res in evals_result['valid'].values()]).T + scores = np.vstack(list(evals_result['valid'].values())).T # positive min_delta train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0, min_delta=min_delta)] delta_result = {} delta_bst = lgb.train(evals_result=delta_result, **train_kwargs) - delta_scores = np.vstack([res for res in delta_result['valid'].values()]).T + delta_scores = np.vstack(list(delta_result['valid'].values())).T if first_only: scores = scores[:, 0] From b73f3f2398565567b8c7f2b35704113b31cdb3e8 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Thu, 14 Oct 2021 10:54:22 -0500 Subject: [PATCH 11/13] reduce num_boost_round in tests --- tests/python_package_test/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index f16320f81a5d..925b3f81c531 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -686,7 +686,7 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better): train_kwargs = dict( params=params, train_set=train_ds, - num_boost_round=100, + num_boost_round=50, valid_sets=[train_ds, valid_ds], valid_names=['training', 'valid'], ) From 7fa8f5f59d3d6a938dc5fd698596256d1f040ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Thu, 14 Oct 2021 15:10:15 -0500 Subject: [PATCH 12/13] Apply suggestions from code review Co-authored-by: Nikita Titov --- python-package/lightgbm/callback.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 6eaee24b2d7e..5b4e3f194c31 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -257,16 +257,14 @@ def _init(env: CallbackEnv) -> None: else: if len(min_delta) != n_metrics: raise ValueError('Must provide a single value for min_delta or as many as metrics.') - if first_metric_only: - if verbose: - _log_info(f'Using only {min_delta[0]} as early stopping min_delta.') + if first_metric_only and verbose: + _log_info(f'Using only {min_delta[0]} as early stopping min_delta.') deltas = min_delta * n_datasets else: if min_delta < 0: raise ValueError('Early stopping min_delta must be non-negative.') - if min_delta > 0 and n_metrics > 1 and not first_metric_only: - if verbose: - _log_info(f'Using {min_delta} as min_delta for all metrics.') + if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose: + _log_info(f'Using {min_delta} as min_delta for all metrics.') deltas = [min_delta] * n_datasets * n_metrics # split is needed for " " case (e.g. "train l1") From 972dba56b2d6e07c643497d503dfec48a8136cad Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Fri, 22 Oct 2021 09:54:05 -0500 Subject: [PATCH 13/13] trigger ci