diff --git a/sksurv/svm/minlip.py b/sksurv/svm/minlip.py index c0676973..cae05363 100644 --- a/sksurv/svm/minlip.py +++ b/sksurv/svm/minlip.py @@ -58,7 +58,8 @@ def solve(self, P, q, G, h): # non of solved, solved inaccurate raise RuntimeError("OSQP solver failed: {}".format(results.info.status)) - return results.x[numpy.newaxis] + n_iter = results.info.iter + return results.x[numpy.newaxis], n_iter def _get_options(self): solver_opts = { @@ -134,7 +135,8 @@ def solve(self, P, q, G, h): # drop solution for t x = results["x"][1:] - return x[numpy.newaxis] + n_iter = results["info"]["iter"] + return x[numpy.newaxis], n_iter def _check_success(self, results): # pylint: disable=no-self-use exit_flag = results["info"]["exitFlag"] @@ -259,6 +261,9 @@ class MinlipSurvivalAnalysis(BaseEstimator, SurvivalAnalysisMixin): Names of features seen during ``fit``. Defined only when `X` has feature names that are all strings. + n_iter_ : int + Number of iterations run by the optimization routine to fit the model. + References ---------- .. [1] Van Belle, V., Pelckmans, K., Suykens, J. A. K., and Van Huffel, S. @@ -342,8 +347,9 @@ def _inner(): timer = timeit.Timer(_inner) self.timings_ = timer.repeat(self.timeit, number=1) - coef = solver.solve(**problem_data) + coef, n_iter = solver.solve(**problem_data) self._update_coef(coef, D) + self.n_iter_ = n_iter self.X_fit_ = x def _update_coef(self, coef, D): @@ -480,6 +486,9 @@ class HingeLossSurvivalSVM(MinlipSurvivalAnalysis): Names of features seen during ``fit``. Defined only when `X` has feature names that are all strings. + n_iter_ : int + Number of iterations run by the optimization routine to fit the model. + References ---------- .. [1] Van Belle, V., Pelckmans, K., Suykens, J. A., & Van Huffel, S. diff --git a/sksurv/svm/naive_survival_svm.py b/sksurv/svm/naive_survival_svm.py index d2efbab4..10e2a289 100644 --- a/sksurv/svm/naive_survival_svm.py +++ b/sksurv/svm/naive_survival_svm.py @@ -81,6 +81,11 @@ class NaiveSurvivalSVM(SurvivalAnalysisMixin, LinearSVC): max_iter : int, default: 1000 The maximum number of iterations to be run. + Attributes + ---------- + n_iter_ : int + Number of iterations run by the optimization routine to fit the model. + See also -------- sksurv.svm.FastSurvivalSVM diff --git a/sksurv/svm/survival_svm.py b/sksurv/svm/survival_svm.py index 4db24f1c..42b1cfbe 100644 --- a/sksurv/svm/survival_svm.py +++ b/sksurv/svm/survival_svm.py @@ -726,6 +726,10 @@ def fit(self, X, y): return self + @property + def n_iter_(self): + return self.optimizer_result_.nit + @staticmethod def _argsort_and_resolve_ties(time, random_state): """Like numpy.argsort, but resolves ties uniformly at random""" @@ -832,6 +836,9 @@ class FastSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin): Names of features seen during ``fit``. Defined only when `X` has feature names that are all strings. + n_iter_ : int + Number of iterations run by the optimization routine to fit the model. + See also -------- FastKernelSurvivalSVM @@ -971,6 +978,9 @@ class FastKernelSurvivalSVM(BaseSurvivalSVM, SurvivalAnalysisMixin): Names of features seen during ``fit``. Defined only when `X` has feature names that are all strings. + n_iter_ : int + Number of iterations run by the optimization routine to fit the model. + See also -------- FastSurvivalSVM diff --git a/tests/test_minlip.py b/tests/test_minlip.py index 543966ec..f614fe40 100644 --- a/tests/test_minlip.py +++ b/tests/test_minlip.py @@ -245,6 +245,7 @@ def test_toy_minlip_fit_osqp(self, toy_data): m.set_params(alpha=2) m.fit(x, y) + assert m.n_iter_ > 100 assert (1, x.shape[0]) == m.coef_.shape assert 1 == m.coef0 expected_coef = numpy.array([ @@ -355,6 +356,7 @@ def test_toy_minlip_fit_ecos(self, toy_data): m.set_params(alpha=2) m.fit(x, y) + assert m.n_iter_ > 10 assert (1, x.shape[0]) == m.coef_.shape assert 1 == m.coef0 expected_coef = numpy.array([ @@ -463,6 +465,8 @@ def test_breast_cancer_osqp(gbsg2): assert (1, x.shape[0]) == m.coef_.shape + assert m.n_iter_ > 1000 + p = m.predict(x) assert_cindex_almost_equal(y['cens'], y['time'], p, (0.5990741854033906, 79720, 53352, 0, 42)) @@ -527,6 +531,8 @@ def test_breast_cancer_ecos(gbsg2): assert (1, x.shape[0]) == m.coef_.shape + assert m.n_iter_ > 10 + p = m.predict(x) assert_cindex_almost_equal(y['cens'], y['time'], p, (0.5990741854033906, 79720, 53352, 0, 42)) diff --git a/tests/test_survival_svm.py b/tests/test_survival_svm.py index 47b9497c..adefa5eb 100644 --- a/tests/test_survival_svm.py +++ b/tests/test_survival_svm.py @@ -282,6 +282,7 @@ def test_default_optimizer(make_whas500): warnings.simplefilter("ignore", category=ConvergenceWarning) ssvm.fit(whas500.x, whas500.y) assert 'avltree' == ssvm.optimizer + assert 1 < ssvm.n_iter_ <= 25 @staticmethod @pytest.mark.slow() @@ -405,6 +406,7 @@ def test_default_optimizer(make_whas500): warnings.simplefilter("ignore", category=ConvergenceWarning) ssvm.fit(whas500.x, whas500.y) assert 'rbtree' == ssvm.optimizer + assert 1 < ssvm.n_iter_ <= 25 @staticmethod def test_unknown_optimizer(fake_data): @@ -776,6 +778,8 @@ def test_survival_squared_hinge_loss(whas500_without_ties): nrsvm = NaiveSurvivalSVM(loss='squared_hinge', dual=False, tol=8e-7, max_iter=1000, random_state=0) nrsvm.fit(x, y) + assert nrsvm.n_iter_ > 10 + rsvm = FastSurvivalSVM(optimizer='avltree', tol=8e-7, max_iter=1000, random_state=0) rsvm.fit(x, y)