Skip to content

Commit

Permalink
ENH: Add n_iter_ attribute to Survival SVM models
Browse files Browse the repository at this point in the history
Fixes #277
  • Loading branch information
sebp committed Aug 13, 2022
1 parent eb351d3 commit 74c3483
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 3 deletions.
15 changes: 12 additions & 3 deletions sksurv/svm/minlip.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def solve(self, P, q, G, h):
# non of solved, solved inaccurate
raise RuntimeError("OSQP solver failed: {}".format(results.info.status))

return results.x[numpy.newaxis]
n_iter = results.info.iter
return results.x[numpy.newaxis], n_iter

def _get_options(self):
solver_opts = {
Expand Down Expand Up @@ -134,7 +135,8 @@ def solve(self, P, q, G, h):

# drop solution for t
x = results["x"][1:]
return x[numpy.newaxis]
n_iter = results["info"]["iter"]
return x[numpy.newaxis], n_iter

def _check_success(self, results): # pylint: disable=no-self-use
exit_flag = results["info"]["exitFlag"]
Expand Down Expand Up @@ -259,6 +261,9 @@ class MinlipSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin):
Names of features seen during ``fit``. Defined only when `X`
has feature names that are all strings.
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
References
----------
.. [1] Van Belle, V., Pelckmans, K., Suykens, J. A. K., and Van Huffel, S.
Expand Down Expand Up @@ -342,8 +347,9 @@ def _inner():
timer = timeit.Timer(_inner)
self.timings_ = timer.repeat(self.timeit, number=1)

coef = solver.solve(**problem_data)
coef, n_iter = solver.solve(**problem_data)
self._update_coef(coef, D)
self.n_iter_ = n_iter
self.X_fit_ = x

def _update_coef(self, coef, D):
Expand Down Expand Up @@ -480,6 +486,9 @@ class HingeLossSurvivalSVM(MinlipSurvivalAnalysis):
Names of features seen during ``fit``. Defined only when `X`
has feature names that are all strings.
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
References
----------
.. [1] Van Belle, V., Pelckmans, K., Suykens, J. A., & Van Huffel, S.
Expand Down
5 changes: 5 additions & 0 deletions sksurv/svm/naive_survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ class NaiveSurvivalSVM(SurvivalAnalysisMixin, LinearSVC):
max_iter : int, default: 1000
The maximum number of iterations to be run.
Attributes
----------
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
See also
--------
sksurv.svm.FastSurvivalSVM
Expand Down
10 changes: 10 additions & 0 deletions sksurv/svm/survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,10 @@ def fit(self, X, y):

return self

@property
def n_iter_(self):
return self.optimizer_result_.nit

@staticmethod
def _argsort_and_resolve_ties(time, random_state):
"""Like numpy.argsort, but resolves ties uniformly at random"""
Expand Down Expand Up @@ -832,6 +836,9 @@ class FastSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin):
Names of features seen during ``fit``. Defined only when `X`
has feature names that are all strings.
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
See also
--------
FastKernelSurvivalSVM
Expand Down Expand Up @@ -971,6 +978,9 @@ class FastKernelSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin):
Names of features seen during ``fit``. Defined only when `X`
has feature names that are all strings.
n_iter_ : int
Number of iterations run by the optimization routine to fit the model.
See also
--------
FastSurvivalSVM
Expand Down
6 changes: 6 additions & 0 deletions tests/test_minlip.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def test_toy_minlip_fit_osqp(self, toy_data):
m.set_params(alpha=2)
m.fit(x, y)

assert m.n_iter_ > 100
assert (1, x.shape[0]) == m.coef_.shape
assert 1 == m.coef0
expected_coef = numpy.array([
Expand Down Expand Up @@ -355,6 +356,7 @@ def test_toy_minlip_fit_ecos(self, toy_data):
m.set_params(alpha=2)
m.fit(x, y)

assert m.n_iter_ > 10
assert (1, x.shape[0]) == m.coef_.shape
assert 1 == m.coef0
expected_coef = numpy.array([
Expand Down Expand Up @@ -463,6 +465,8 @@ def test_breast_cancer_osqp(gbsg2):

assert (1, x.shape[0]) == m.coef_.shape

assert m.n_iter_ > 1000

p = m.predict(x)
assert_cindex_almost_equal(y['cens'], y['time'], p,
(0.5990741854033906, 79720, 53352, 0, 42))
Expand Down Expand Up @@ -527,6 +531,8 @@ def test_breast_cancer_ecos(gbsg2):

assert (1, x.shape[0]) == m.coef_.shape

assert m.n_iter_ > 10

p = m.predict(x)
assert_cindex_almost_equal(y['cens'], y['time'], p,
(0.5990741854033906, 79720, 53352, 0, 42))
Expand Down
4 changes: 4 additions & 0 deletions tests/test_survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def test_default_optimizer(make_whas500):
warnings.simplefilter("ignore", category=ConvergenceWarning)
ssvm.fit(whas500.x, whas500.y)
assert 'avltree' == ssvm.optimizer
assert 1 < ssvm.n_iter_ <= 25

@staticmethod
@pytest.mark.slow()
Expand Down Expand Up @@ -405,6 +406,7 @@ def test_default_optimizer(make_whas500):
warnings.simplefilter("ignore", category=ConvergenceWarning)
ssvm.fit(whas500.x, whas500.y)
assert 'rbtree' == ssvm.optimizer
assert 1 < ssvm.n_iter_ <= 25

@staticmethod
def test_unknown_optimizer(fake_data):
Expand Down Expand Up @@ -776,6 +778,8 @@ def test_survival_squared_hinge_loss(whas500_without_ties):
nrsvm = NaiveSurvivalSVM(loss='squared_hinge', dual=False, tol=8e-7, max_iter=1000, random_state=0)
nrsvm.fit(x, y)

assert nrsvm.n_iter_ > 10

rsvm = FastSurvivalSVM(optimizer='avltree', tol=8e-7, max_iter=1000, random_state=0)
rsvm.fit(x, y)

Expand Down

0 comments on commit 74c3483

Please sign in to comment.