From 619b4df474826dfaa9f0389092df6bbe3e08c4e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=B6lsterl?= Date: Sun, 1 Oct 2023 10:34:31 +0200 Subject: [PATCH] Throw exception when event time is negative Fixes #406 --- sksurv/nonparametric.py | 8 +++++++- sksurv/svm/survival_svm.py | 3 --- sksurv/util.py | 14 +++++++++++++- tests/test_metrics.py | 2 +- tests/test_survival_function.py | 11 +++++++++++ 5 files changed, 32 insertions(+), 6 deletions(-) diff --git a/sksurv/nonparametric.py b/sksurv/nonparametric.py index 82ae7364..d1ff3240 100644 --- a/sksurv/nonparametric.py +++ b/sksurv/nonparametric.py @@ -284,7 +284,13 @@ def kaplan_meier_estimator( Survival Function Based on Transformations", Scandinavian Journal of Statistics. 1990;17(1):35–41. """ - event, time_enter, time_exit = check_y_survival(event, time_enter, time_exit, allow_all_censored=True) + event, time_enter, time_exit = check_y_survival( + event, + time_enter, + time_exit, + allow_all_censored=True, + allow_time_zero=reverse or time_enter is not None, + ) check_consistent_length(event, time_enter, time_exit) if conf_type is not None and reverse: diff --git a/sksurv/svm/survival_svm.py b/sksurv/svm/survival_svm.py index 0103f7a4..7d8b7ff2 100644 --- a/sksurv/svm/survival_svm.py +++ b/sksurv/svm/survival_svm.py @@ -758,9 +758,6 @@ def fit(self, X, y): if self.optimizer in {"simple", "PRSVM"}: raise ValueError(f"optimizer {self.optimizer!r} does not implement regression objective") - if (time <= 0).any(): - raise ValueError("observed time contains values smaller or equal to zero") - # log-transform time time = np.log(time) assert np.isfinite(time).all() diff --git a/sksurv/util.py b/sksurv/util.py index ba764aff..25305f6f 100644 --- a/sksurv/util.py +++ b/sksurv/util.py @@ -96,7 +96,7 @@ def from_dataframe(event, time, data): ) -def check_y_survival(y_or_event, *args, allow_all_censored=False): +def check_y_survival(y_or_event, *args, allow_all_censored=False, allow_time_zero=False): """Check that array correctly represents an outcome for survival analysis. Parameters @@ -114,6 +114,9 @@ def check_y_survival(y_or_event, *args, allow_all_censored=False): allow_all_censored : bool, optional, default: False Whether to allow all events to be censored. + allow_time_zero : bool, optional, default: False + Whether to allow event times to be zero. + Returns ------- event : array, shape=[n_samples,], dtype=bool @@ -156,6 +159,15 @@ def check_y_survival(y_or_event, *args, allow_all_censored=False): if not np.issubdtype(yt.dtype, np.number): raise ValueError(f"time must be numeric, but found {yt.dtype} for argument {i + 2}") + if allow_time_zero: + cond = yt < 0 + msg = "observed time contains values smaller zero" + else: + cond = yt <= 0 + msg = "observed time contains values smaller or equal to zero" + if np.any(cond): + raise ValueError(msg) + return_val.append(yt) return tuple(return_val) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index ec54a857..85169173 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1075,7 +1075,7 @@ def test_brier_coxph(): def test_brier_score_int_dtype(): - times = np.arange(30, dtype=int) + times = np.arange(1, 31, dtype=int) rnd = np.random.RandomState(1) times = rnd.choice(times, 20) diff --git a/tests/test_survival_function.py b/tests/test_survival_function.py index 5449700a..2d08d71c 100644 --- a/tests/test_survival_function.py +++ b/tests/test_survival_function.py @@ -4,6 +4,7 @@ from sksurv.linear_model import CoxnetSurvivalAnalysis from sksurv.testing import all_survival_estimators +from sksurv.util import Surv def all_survival_function_estimators(): @@ -30,3 +31,13 @@ def test_survival_functions(estimator, make_whas500): arr = np.row_stack([fn(times) for fn in fns_cls]) assert_array_almost_equal(arr, fns_arr) + + +@pytest.mark.parametrize("estimator", all_survival_function_estimators()) +@pytest.mark.parametrize("y_time", [-1e-8, -1, np.finfo(float).min]) +def test_fit_negative_survial_time_raises(estimator, y_time): + X = np.random.randn(7, 3) + y = Surv.from_arrays(event=np.ones(7, dtype=bool), time=[1, 9, 3, y_time, 1, 8, 1e10]) + + with pytest.raises(ValueError, match="observed time contains values smaller or equal to zero"): + estimator.fit(X, y)