Merge pull request #354 from sebp/gb-early-stopping
Add support for early-stopping to GradientBoostingSurvivalAnalysis
sebp authored Apr 14, 2023
2 parents 52695a8 + f33600b commit 9c4c287
Showing 2 changed files with 194 additions and 52 deletions.
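The commit adds three early-stopping parameters to GradientBoostingSurvivalAnalysis: ``validation_fraction``, ``n_iter_no_change``, and ``tol`` (documented in the diff below). A minimal usage sketch, not part of the commit, assuming a scikit-survival version that already includes this change; the synthetic right-censored data below is purely illustrative:

import numpy as np
from sksurv.ensemble import GradientBoostingSurvivalAnalysis

rng = np.random.RandomState(0)
n_samples, n_features = 500, 10
X = rng.standard_normal((n_samples, n_features))

# Synthetic right-censored outcome as a structured array with a boolean
# event indicator and a continuous time, the format expected by scikit-survival.
risk = X[:, 0] - 0.5 * X[:, 1]
event_time = rng.exponential(scale=np.exp(-risk))
censor_time = rng.exponential(scale=np.exp(-risk).mean(), size=n_samples)
y = np.empty(n_samples, dtype=[("event", bool), ("time", np.float64)])
y["event"] = event_time <= censor_time
y["time"] = np.minimum(event_time, censor_time)

# With n_iter_no_change set, a validation_fraction of the training data is held
# out (stratified by event status) and boosting stops once the validation loss
# has not improved by at least `tol` over the last `n_iter_no_change` stages.
model = GradientBoostingSurvivalAnalysis(
    n_estimators=1000,
    max_depth=2,
    n_iter_no_change=3,
    validation_fraction=0.2,
    tol=1e-4,
    random_state=0,
)
model.fit(X, y)
print(model.n_estimators_)  # number of stages actually built, usually far below 1000

The new test case at the end of the diff exercises the feature in the same way.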
222 changes: 171 additions & 51 deletions sksurv/ensemble/boosting.py
@@ -18,6 +18,7 @@
from sklearn.ensemble._base import BaseEnsemble
from sklearn.ensemble._gb import BaseGradientBoosting, VerboseReporter
from sklearn.ensemble._gradient_boosting import _random_sample_mask
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree._tree import DTYPE
from sklearn.utils import check_consistent_length, check_random_state, column_or_1d
@@ -554,6 +555,14 @@ class GradientBoostingSurvivalAnalysis(BaseGradientBoosting, SurvivalAnalysisMix
results in better performance.
Values must be in the range `[1, inf)`.
subsample : float, optional, default: 1.0
The fraction of samples to be used for fitting the individual base
learners. If smaller than 1.0 this results in Stochastic Gradient
Boosting. `subsample` interacts with the parameter `n_estimators`.
Choosing `subsample < 1.0` leads to a reduction of variance
and an increase in bias.
Values must be in the range `(0.0, 1.0]`.
criterion : string, optional, default: 'friedman_mse'
The function to measure the quality of a split. Supported criteria
are "friedman_mse" for the mean squared error with improvement
@@ -644,13 +653,26 @@ class GradientBoostingSurvivalAnalysis(BaseGradientBoosting, SurvivalAnalysisMix
Values must be in the range `[2, inf)`.
If `None`, then unlimited number of leaf nodes.
subsample : float, optional, default: 1.0
The fraction of samples to be used for fitting the individual base
learners. If smaller than 1.0 this results in Stochastic Gradient
Boosting. `subsample` interacts with the parameter `n_estimators`.
Choosing `subsample < 1.0` leads to a reduction of variance
and an increase in bias.
Values must be in the range `(0.0, 1.0]`.
validation_fraction : float, default: 0.1
The proportion of training data to set aside as validation set for
early stopping. Values must be in the range `(0.0, 1.0)`.
Only used if ``n_iter_no_change`` is set to an integer.
n_iter_no_change : int, default: None
``n_iter_no_change`` is used to decide if early stopping will be used
to terminate training when the validation score is not improving. By
default it is set to None to disable early stopping. If set to a
number, a ``validation_fraction`` portion of the training data is set
aside as a validation set, and training terminates when the validation
score has not improved in any of the previous ``n_iter_no_change``
iterations. The split is stratified on the event indicator.
Values must be in the range `[1, inf)`.
tol : float, default: 1e-4
Tolerance for the early stopping. When the loss is not improving
by at least tol for ``n_iter_no_change`` iterations (if set to a
number), the training stops.
Values must be in the range `[0.0, inf)`.
dropout_rate : float, optional, default: 0.0
If larger than zero, the residuals at each iteration are only computed
@@ -728,32 +750,50 @@ class GradientBoostingSurvivalAnalysis(BaseGradientBoosting, SurvivalAnalysisMix
"dropout_rate": [Interval(numbers.Real, 0.0, 1.0, closed="left")]
}

def __init__(self, *, loss="coxph", learning_rate=0.1, n_estimators=100,
criterion='friedman_mse',
min_samples_split=2,
min_samples_leaf=1, min_weight_fraction_leaf=0.,
max_depth=3,
min_impurity_decrease=0., random_state=None,
max_features=None, max_leaf_nodes=None,
subsample=1.0, dropout_rate=0.0,
verbose=0,
ccp_alpha=0.0):
super().__init__(loss=loss,
learning_rate=learning_rate,
n_estimators=n_estimators,
subsample=subsample,
criterion=criterion,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_depth=max_depth,
min_impurity_decrease=min_impurity_decrease,
init=None,
random_state=random_state,
max_features=max_features,
max_leaf_nodes=max_leaf_nodes,
verbose=verbose,
ccp_alpha=ccp_alpha)
def __init__(
self,
*,
loss="coxph",
learning_rate=0.1,
n_estimators=100,
subsample=1.0,
criterion="friedman_mse",
min_samples_split=2,
min_samples_leaf=1,
min_weight_fraction_leaf=0.,
max_depth=3,
min_impurity_decrease=0.,
random_state=None,
max_features=None,
max_leaf_nodes=None,
validation_fraction=0.1,
n_iter_no_change=None,
tol=1e-4,
dropout_rate=0.0,
verbose=0,
ccp_alpha=0.0,
):
super().__init__(
loss=loss,
learning_rate=learning_rate,
n_estimators=n_estimators,
criterion=criterion,
min_samples_split=min_samples_split,
min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf,
max_depth=max_depth,
init=None,
subsample=subsample,
max_features=max_features,
random_state=random_state,
verbose=verbose,
max_leaf_nodes=max_leaf_nodes,
min_impurity_decrease=min_impurity_decrease,
validation_fraction=validation_fraction,
n_iter_no_change=n_iter_no_change,
tol=tol,
ccp_alpha=ccp_alpha,
)
self.dropout_rate = dropout_rate

def _warn_mae_for_criterion(self):
@@ -794,8 +834,19 @@ def _check_max_features(self):

return max_features

def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask,
random_state, scale, X_csc=None, X_csr=None):
def _fit_stage(
self,
i,
X,
y,
raw_predictions,
sample_weight,
sample_mask,
random_state,
scale,
X_csc=None,
X_csr=None,
):
"""Fit another stage of ``n_classes_`` trees to the boosting model. """

assert sample_mask.dtype == bool
@@ -864,8 +915,19 @@ def _fit_stage(self, i, X, y, raw_predictions, sample_weight, sample_mask,

return raw_predictions

def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state,
begin_at_stage=0, monitor=None):
def _fit_stages( # noqa: C901
self,
X,
y,
raw_predictions,
sample_weight,
random_state,
X_val,
y_val,
sample_weight_val,
begin_at_stage=0,
monitor=None,
):
"""Iteratively fits the stages.
For each stage it computes the progress (OOB, train score)
@@ -891,6 +953,12 @@ def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state,
else:
scale = None

if self.n_iter_no_change is not None:
loss_history = np.full(self.n_iter_no_change, np.inf)
# We create a generator to get the predictions for X_val after
# the addition of each successive stage
y_val_pred_iter = self._staged_raw_predict(X_val, check_input=False)

# perform boosting iterations
i = begin_at_stage
for i in range(begin_at_stage, self.n_estimators):
@@ -901,24 +969,38 @@ def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state,
random_state)
# OOB score before adding this stage
y_oob_sample = y[~sample_mask]
old_oob_score = loss_(y_oob_sample,
raw_predictions[~sample_mask],
sample_weight[~sample_mask])
old_oob_score = loss_(
y_oob_sample,
raw_predictions[~sample_mask],
sample_weight[~sample_mask],
)

# fit next stage of trees
raw_predictions = self._fit_stage(
i, X, y, raw_predictions, sample_weight, sample_mask,
random_state, scale, X_csc, X_csr)
i,
X,
y,
raw_predictions,
sample_weight,
sample_mask,
random_state,
scale,
X_csc,
X_csr,
)

# track deviance (= loss)
if do_oob:
self.train_score_[i] = loss_(y[sample_mask],
raw_predictions[sample_mask],
sample_weight[sample_mask])
self.oob_improvement_[i] = (
old_oob_score - loss_(y_oob_sample,
raw_predictions[~sample_mask],
sample_weight[~sample_mask]))
self.train_score_[i] = loss_(
y[sample_mask],
raw_predictions[sample_mask],
sample_weight[sample_mask],
)
self.oob_improvement_[i] = old_oob_score - loss_(
y_oob_sample,
raw_predictions[~sample_mask],
sample_weight[~sample_mask],
)
else:
# no need to fancy index w/ no subsampling
self.train_score_[i] = loss_(y, raw_predictions, sample_weight)
@@ -931,6 +1013,20 @@ def _fit_stages(self, X, y, raw_predictions, sample_weight, random_state,
if early_stopping:
break

# We also provide an early stopping based on the score from
# validation set (X_val, y_val), if n_iter_no_change is set
if self.n_iter_no_change is not None:
# By calling next(y_val_pred_iter), we get the predictions
# for X_val after the addition of the current stage
validation_loss = loss_(y_val, next(y_val_pred_iter), sample_weight_val)

# Require validation_score to be better (less) than at least
# one of the last n_iter_no_change evaluations
if np.any(validation_loss + self.tol < loss_history):
loss_history[i % len(loss_history)] = validation_loss
else:
break

if self.dropout_rate > 0.:
self.scale_ = scale
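
The validation-loss check above keeps a rolling history of the last ``n_iter_no_change`` losses and stops as soon as the current loss no longer beats any entry by more than ``tol``. A standalone sketch of that stopping criterion (illustrative only, not part of the commit):

import numpy as np

def stages_before_stop(validation_losses, n_iter_no_change=3, tol=1e-4):
    """Return the index of the stage that triggers early stopping."""
    loss_history = np.full(n_iter_no_change, np.inf)
    for i, loss in enumerate(validation_losses):
        # Keep the stage only if it improves on at least one of the last
        # `n_iter_no_change` recorded losses by more than `tol`.
        if np.any(loss + tol < loss_history):
            loss_history[i % n_iter_no_change] = loss
        else:
            return i
    return len(validation_losses)

# The loss stops improving after the fourth stage, so boosting halts three
# stages later, once every history entry equals the plateau value.
print(stages_before_stop([1.0, 0.8, 0.7, 0.65, 0.65, 0.65, 0.65, 0.65]))  # 6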

@@ -985,6 +1081,20 @@ def fit(self, X, y, sample_weight=None, monitor=None):
if isinstance(self._loss, (CensoredSquaredLoss, IPCWLeastSquaresError)):
time = np.log(time)

if self.n_iter_no_change is not None:
X, X_val, event, event_val, time, time_val, sample_weight, sample_weight_val = train_test_split(
X,
event,
time,
sample_weight,
random_state=self.random_state,
test_size=self.validation_fraction,
stratify=event,
)
y_val = np.fromiter(zip(event_val, time_val), dtype=[('event', bool), ('time', np.float64)])
else:
X_val = y_val = sample_weight_val = None

self._init_state()
if sample_weight_is_none:
self.init_.fit(X, (event, time))
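
As shown in ``fit`` above, the validation set is carved out with a split stratified on the event indicator, so the censoring rate stays comparable between the two folds, and ``y_val`` is rebuilt as a structured array. A small illustration of the stratification and of rebuilding the outcome array (not part of the commit):

import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.RandomState(0)
event = rng.binomial(1, 0.3, size=500).astype(bool)  # ~30% observed events
time = rng.exponential(size=500)

time_train, time_val, event_train, event_val = train_test_split(
    time, event, test_size=0.1, stratify=event, random_state=0,
)
# Event rates match closely because the split is stratified on `event`.
print(event_train.mean(), event_val.mean())

# Rebuild the structured outcome for the validation fold, mirroring the diff.
y_val = np.fromiter(zip(event_val, time_val), dtype=[("event", bool), ("time", np.float64)])
print(y_val.shape)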
@@ -999,8 +1109,18 @@ def fit(self, X, y, sample_weight=None, monitor=None):

# fit the boosting stages
y = np.fromiter(zip(event, time), dtype=[('event', bool), ('time', np.float64)])
n_stages = self._fit_stages(X, y, raw_predictions, sample_weight, self._rng,
begin_at_stage, monitor)
n_stages = self._fit_stages(
X,
y,
raw_predictions,
sample_weight,
self._rng,
X_val,
y_val,
sample_weight_val,
begin_at_stage,
monitor,
)
# change shape of arrays after fit (early-stopping or additional tests)
if n_stages != self.estimators_.shape[0]:
self.estimators_ = self.estimators_[:n_stages]
24 changes: 23 additions & 1 deletion tests/test_boosting.py
@@ -228,18 +228,40 @@ def test_ccp_alpha(self):
assert tree.node_count > subtree.node_count
assert tree.max_depth > subtree.max_depth

def test_early_stopping(self):
X, y = self.data

model = GradientBoostingSurvivalAnalysis(
n_estimators=1000, max_depth=2, n_iter_no_change=3, validation_fraction=0.2, random_state=0,
)
model.fit(X, y)

assert model.n_estimators_ == 36

@staticmethod
def test_negative_ccp_alpha(make_whas500):
whas500_data = make_whas500(with_std=False, to_numeric=True)

clf = GradientBoostingSurvivalAnalysis()
msg = "The 'ccp_alpha' parameter of GradientBoostingSurvivalAnalysis must be a float in the range " \
r"\[0\.0, inf\). Got -1\.0 instead\."
r"\[0\.0, inf\)\. Got -1\.0 instead\."

clf.set_params(ccp_alpha=-1.0)
with pytest.raises(ValueError, match=msg):
clf.fit(whas500_data.x, whas500_data.y)

@staticmethod
def test_negative_n_iter_no_change(make_whas500):
whas500_data = make_whas500(with_std=False, to_numeric=True)

clf = GradientBoostingSurvivalAnalysis()
msg = "The 'n_iter_no_change' parameter of GradientBoostingSurvivalAnalysis must be an int in the range " \
r"\[1, inf\) or None\. Got -1 instead\."

clf.set_params(n_iter_no_change=-1)
with pytest.raises(ValueError, match=msg):
clf.fit(whas500_data.x, whas500_data.y)

def test_fit_verbose(self):
self.assert_fit_and_predict(expected_cindex=None, n_estimators=10, verbose=1)

