Skip to content

Commit

Permalink
Throw exception when event time is negative
Browse files Browse the repository at this point in the history
Fixes #406
  • Loading branch information
sebp committed Oct 1, 2023
1 parent 8a22157 commit 88092b2
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 9 deletions.
5 changes: 1 addition & 4 deletions sksurv/svm/survival_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ def fit(self, X, y):
self
"""
X = self._validate_for_fit(X)
event, time = check_array_survival(X, y)
event, time = check_array_survival(X, y, allow_time_zero=False)

self._validate_params()

Expand All @@ -758,9 +758,6 @@ def fit(self, X, y):
if self.optimizer in {"simple", "PRSVM"}:
raise ValueError(f"optimizer {self.optimizer!r} does not implement regression objective")

if (time <= 0).any():
raise ValueError("observed time contains values smaller or equal to zero")

# log-transform time
time = np.log(time)
assert np.isfinite(time).all()
Expand Down
20 changes: 16 additions & 4 deletions sksurv/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def from_dataframe(event, time, data):
)


def check_y_survival(y_or_event, *args, allow_all_censored=False):
def check_y_survival(y_or_event, *args, allow_all_censored=False, allow_time_zero=True):
"""Check that array correctly represents an outcome for survival analysis.
Parameters
Expand All @@ -114,6 +114,9 @@ def check_y_survival(y_or_event, *args, allow_all_censored=False):
allow_all_censored : bool, optional, default: False
Whether to allow all events to be censored.
allow_time_zero : bool, optional, default: True
Whether to allow event times to be zero.
Returns
-------
event : array, shape=[n_samples,], dtype=bool
Expand Down Expand Up @@ -156,12 +159,21 @@ def check_y_survival(y_or_event, *args, allow_all_censored=False):
if not np.issubdtype(yt.dtype, np.number):
raise ValueError(f"time must be numeric, but found {yt.dtype} for argument {i + 2}")

if allow_time_zero:
cond = yt < 0
msg = "observed time contains values smaller zero"
else:
cond = yt <= 0
msg = "observed time contains values smaller or equal to zero"
if np.any(cond):
raise ValueError(msg)

return_val.append(yt)

return tuple(return_val)


def check_array_survival(X, y):
def check_array_survival(X, y, **kwargs):
"""Check that all arrays have consistent first dimensions.
Parameters
Expand All @@ -175,7 +187,7 @@ def check_array_survival(X, y):
second field.
kwargs : dict
Additional arguments passed to :func:`sklearn.utils.check_array`.
Additional arguments passed to :func:`check_y_survival`.
Returns
-------
Expand All @@ -185,7 +197,7 @@ def check_array_survival(X, y):
time : array, shape=[n_samples,], dtype=float
Time of event or censoring.
"""
event, time = check_y_survival(y)
event, time = check_y_survival(y, **kwargs)
check_consistent_length(X, event, time)
return event, time

Expand Down
2 changes: 1 addition & 1 deletion tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def test_brier_coxph():


def test_brier_score_int_dtype():
times = np.arange(30, dtype=int)
times = np.arange(1, 31, dtype=int)
rnd = np.random.RandomState(1)
times = rnd.choice(times, 20)

Expand Down
11 changes: 11 additions & 0 deletions tests/test_survival_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from sksurv.linear_model import CoxnetSurvivalAnalysis
from sksurv.testing import all_survival_estimators
from sksurv.util import Surv


def all_survival_function_estimators():
Expand All @@ -30,3 +31,13 @@ def test_survival_functions(estimator, make_whas500):
arr = np.row_stack([fn(times) for fn in fns_cls])

assert_array_almost_equal(arr, fns_arr)


@pytest.mark.parametrize("estimator", all_survival_function_estimators())
@pytest.mark.parametrize("y_time", [-1e-8, -1, np.finfo(float).min])
def test_fit_negative_survial_time_raises(estimator, y_time):
X = np.random.randn(7, 3)
y = Surv.from_arrays(event=np.ones(7, dtype=bool), time=[1, 9, 3, y_time, 1, 8, 1e10])

with pytest.raises(ValueError, match="observed time contains values smaller zero"):
estimator.fit(X, y)

0 comments on commit 88092b2

Please sign in to comment.