Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define best_iteration only if early stopping is used. #9403

Merged
merged 3 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 41 additions & 34 deletions demo/guide-python/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
'''
"""
Demo for using and defining callback functions
==============================================

.. versionadded:: 1.3.0
'''
"""
import argparse
import os
import tempfile
Expand All @@ -17,10 +17,11 @@


class Plotting(xgb.callback.TrainingCallback):
'''Plot evaluation result during training. Only for demonstration purpose as it's quite
"""Plot evaluation result during training. Only for demonstration purpose as it's quite
slow to draw.

'''
"""

def __init__(self, rounds):
self.fig = plt.figure()
self.ax = self.fig.add_subplot(111)
Expand All @@ -31,16 +32,16 @@ def __init__(self, rounds):
plt.ion()

def _get_key(self, data, metric):
return f'{data}-{metric}'
return f"{data}-{metric}"

def after_iteration(self, model, epoch, evals_log):
'''Update the plot.'''
"""Update the plot."""
if not self.lines:
for data, metric in evals_log.items():
for metric_name, log in metric.items():
key = self._get_key(data, metric_name)
expanded = log + [0] * (self.rounds - len(log))
self.lines[key], = self.ax.plot(self.x, expanded, label=key)
(self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
self.ax.legend()
else:
# https://pythonspot.com/matplotlib-update-plot/
Expand All @@ -55,8 +56,8 @@ def after_iteration(self, model, epoch, evals_log):


def custom_callback():
'''Demo for defining a custom callback function that plots evaluation result during
training.'''
"""Demo for defining a custom callback function that plots evaluation result during
training."""
X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

Expand All @@ -69,15 +70,16 @@ def custom_callback():
# Pass it to the `callbacks` parameter as a list.
xgb.train(
{
'objective': 'binary:logistic',
'eval_metric': ['error', 'rmse'],
'tree_method': 'hist',
"objective": "binary:logistic",
"eval_metric": ["error", "rmse"],
"tree_method": "hist",
"device": "cuda",
},
D_train,
evals=[(D_train, 'Train'), (D_valid, 'Valid')],
evals=[(D_train, "Train"), (D_valid, "Valid")],
num_boost_round=num_boost_round,
callbacks=[plotting])
callbacks=[plotting],
)


def check_point_callback():
Expand All @@ -90,42 +92,47 @@ def check(as_pickle):
if i == 0:
continue
if as_pickle:
path = os.path.join(tmpdir, 'model_' + str(i) + '.pkl')
path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
else:
path = os.path.join(tmpdir, 'model_' + str(i) + '.json')
assert(os.path.exists(path))
path = os.path.join(tmpdir, "model_" + str(i) + ".json")
assert os.path.exists(path)

X, y = load_breast_cancer(return_X_y=True)
m = xgb.DMatrix(X, y)
# Check point to a temporary directory for demo
with tempfile.TemporaryDirectory() as tmpdir:
# Use callback class from xgboost.callback
# Feel free to subclass/customize it to suit your need.
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
iterations=rounds,
name='model')
xgb.train({'objective': 'binary:logistic'}, m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point])
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point],
)
check(False)

# This version of checkpoint saves everything including parameters and
# model. See: doc/tutorials/saving_model.rst
check_point = xgb.callback.TrainingCheckPoint(directory=tmpdir,
iterations=rounds,
as_pickle=True,
name='model')
xgb.train({'objective': 'binary:logistic'}, m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point])
check_point = xgb.callback.TrainingCheckPoint(
directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
)
xgb.train(
{"objective": "binary:logistic"},
m,
num_boost_round=10,
verbose_eval=False,
callbacks=[check_point],
)
check(True)


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--plot', default=1, type=int)
parser.add_argument("--plot", default=1, type=int)
args = parser.parse_args()

check_point_callback()
Expand Down
4 changes: 4 additions & 0 deletions doc/python/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ The sliced model is a copy of selected trees, that means the model itself is imm
during slicing. This feature is the basis of `save_best` option in early stopping
callback. See :ref:`sphx_glr_python_examples_individual_trees.py` for a worked example on
how to combine prediction with sliced trees.

.. note::

   The returned model slice doesn't contain attributes like :py:attr:`~xgboost.Booster.best_iteration` and :py:attr:`~xgboost.Booster.best_score`.
112 changes: 64 additions & 48 deletions python-package/xgboost/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,17 @@ def __init__(
is_cv: bool = False,
) -> None:
self.callbacks = set(callbacks)
if metric is not None:
msg = (
"metric must be callable object for monitoring. For "
+ "builtin metrics, passing them in training parameter"
+ " will invoke monitor automatically."
)
assert callable(metric), msg
for cb in callbacks:
if not isinstance(cb, TrainingCallback):
raise TypeError("callback must be an instance of `TrainingCallback`.")

msg = (
"metric must be callable object for monitoring. For builtin metrics"
", passing them in training parameter invokes monitor automatically."
)
if metric is not None and not callable(metric):
raise TypeError(msg)

self.metric = metric
self.history: TrainingCallback.EvalsLog = collections.OrderedDict()
self._output_margin = output_margin
Expand Down Expand Up @@ -170,16 +174,6 @@ def after_training(self, model: _Model) -> _Model:
else:
assert isinstance(model, Booster), msg

if not self.is_cv:
if model.attr("best_score") is not None:
model.best_score = float(cast(str, model.attr("best_score")))
model.best_iteration = int(cast(str, model.attr("best_iteration")))
else:
# Due to compatibility with version older than 1.4, these attributes are
# added to Python object even if early stopping is not used.
model.best_iteration = model.num_boosted_rounds() - 1
model.set_attr(best_iteration=str(model.best_iteration))

return model

def before_iteration(
Expand Down Expand Up @@ -267,9 +261,14 @@ class LearningRateScheduler(TrainingCallback):
def __init__(
self, learning_rates: Union[Callable[[int], float], Sequence[float]]
) -> None:
assert callable(learning_rates) or isinstance(
if not callable(learning_rates) and not isinstance(
learning_rates, collections.abc.Sequence
)
):
raise TypeError(
"Invalid learning rates, expecting callable or sequence, got: "
f"{type(learning_rates)}"
)

if callable(learning_rates):
self.learning_rates = learning_rates
else:
Expand Down Expand Up @@ -302,24 +301,28 @@ class EarlyStopping(TrainingCallback):
save_best :
Whether training should return the best model or the last model.
min_delta :
Minimum absolute change in score to be qualified as an improvement.

.. versionadded:: 1.5.0

.. code-block:: python
Minimum absolute change in score to be qualified as an improvement.

es = xgboost.callback.EarlyStopping(
rounds=2,
min_delta=1e-3,
save_best=True,
maximize=False,
data_name="validation_0",
metric_name="mlogloss",
)
clf = xgboost.XGBClassifier(tree_method="gpu_hist", callbacks=[es])
Examples
--------

.. code-block:: python

X, y = load_digits(return_X_y=True)
clf.fit(X, y, eval_set=[(X, y)])
es = xgboost.callback.EarlyStopping(
rounds=2,
min_delta=1e-3,
save_best=True,
maximize=False,
data_name="validation_0",
metric_name="mlogloss",
)
clf = xgboost.XGBClassifier(tree_method="hist", device="cuda", callbacks=[es])

X, y = load_digits(return_X_y=True)
clf.fit(X, y, eval_set=[(X, y)])
"""

# pylint: disable=too-many-arguments
Expand Down Expand Up @@ -363,7 +366,7 @@ def maximize(new: _Score, best: _Score) -> bool:
return numpy.greater(get_s(new) - self._min_delta, get_s(best))

def minimize(new: _Score, best: _Score) -> bool:
"""New score should be smaller than the old one."""
"""New score should be lesser than the old one."""
return numpy.greater(get_s(best) - self._min_delta, get_s(new))

if self.maximize is None:
Expand Down Expand Up @@ -419,38 +422,53 @@ def after_iteration(
) -> bool:
epoch += self.starting_round # training continuation
msg = "Must have at least 1 validation dataset for early stopping."
assert len(evals_log.keys()) >= 1, msg
data_name = ""
if len(evals_log.keys()) < 1:
raise ValueError(msg)

# Get data name
if self.data:
for d, _ in evals_log.items():
if d == self.data:
data_name = d
if not data_name:
raise ValueError("No dataset named:", self.data)
data_name = self.data
else:
# Use the last one as default.
data_name = list(evals_log.keys())[-1]
assert isinstance(data_name, str) and data_name
if data_name not in evals_log:
raise ValueError(f"No dataset named: {data_name}")

if not isinstance(data_name, str):
raise TypeError(
f"The name of the dataset should be a string. Got: {type(data_name)}"
)
data_log = evals_log[data_name]

# Filter out scores that can not be used for early stopping.
# Get metric name
if self.metric_name:
metric_name = self.metric_name
else:
# Use last metric by default.
assert isinstance(data_log, collections.OrderedDict)
metric_name = list(data_log.keys())[-1]
if metric_name not in data_log:
raise ValueError(f"No metric named: {metric_name}")

# The latest score
score = data_log[metric_name][-1]
return self._update_rounds(score, data_name, metric_name, model, epoch)

def after_training(self, model: _Model) -> _Model:
if not self.save_best:
return model

try:
if self.save_best:
model = model[: int(model.attr("best_iteration")) + 1]
best_iteration = model.best_iteration
best_score = model.best_score
assert best_iteration is not None and best_score is not None
model = model[: best_iteration + 1]
model.best_iteration = best_iteration
model.best_score = best_score
except XGBoostError as e:
raise XGBoostError(
"`save_best` is not applicable to current booster"
"`save_best` is not applicable to the current booster"
) from e

return model


Expand All @@ -462,8 +480,6 @@ class EvaluationMonitor(TrainingCallback):
Parameters
----------

metric :
Extra user defined metric.
rank :
Which worker should be used for printing the result.
period :
Expand Down
35 changes: 30 additions & 5 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ def attributes(self) -> Dict[str, Optional[str]]:
attr_names = from_cstr_to_pystr(sarr, length)
return {n: self.attr(n) for n in attr_names}

def set_attr(self, **kwargs: Optional[str]) -> None:
def set_attr(self, **kwargs: Optional[Any]) -> None:
"""Set the attribute of the Booster.

Parameters
Expand Down Expand Up @@ -2574,10 +2574,35 @@ def load_model(self, fname: ModelIn) -> None:
else:
raise TypeError("Unknown file type: ", fname)

if self.attr("best_iteration") is not None:
self.best_iteration = int(cast(int, self.attr("best_iteration")))
if self.attr("best_score") is not None:
self.best_score = float(cast(float, self.attr("best_score")))
@property
def best_iteration(self) -> int:
    """The best iteration during training."""
    recorded = self.attr("best_iteration")
    if recorded is None:
        # Only early stopping records this attribute on the booster.
        raise AttributeError(
            "`best_iteration` is only defined when early stopping is used."
        )
    return int(recorded)

@best_iteration.setter
def best_iteration(self, iteration: int) -> None:
    # Store the value through the booster attribute mechanism.
    self.set_attr(best_iteration=iteration)

@property
def best_score(self) -> float:
    """The best evaluation score during training.

    Only defined when early stopping is used; accessing it otherwise
    raises :py:exc:`AttributeError`.
    """
    best = self.attr("best_score")
    if best is not None:
        return float(best)

    raise AttributeError(
        "`best_score` is only defined when early stopping is used."
    )

@best_score.setter
def best_score(self, score: float) -> None:
    # NOTE: annotation corrected from `int` to `float` — the getter
    # returns a float, so the stored score is a float as well.
    self.set_attr(best_score=score)

def num_boosted_rounds(self) -> int:
"""Get number of boosted rounds. For gblinear this is reset to 0 after
Expand Down
Loading