diff --git a/sksurv/metrics.py b/sksurv/metrics.py index a8fdf128..6ddc507a 100644 --- a/sksurv/metrics.py +++ b/sksurv/metrics.py @@ -63,7 +63,7 @@ def _check_inputs(event_indicator, event_time, estimate): def _check_times(test_time, times): - times = check_array(np.atleast_1d(times), ensure_2d=False, dtype=test_time.dtype, input_name="times") + times = check_array(np.atleast_1d(times), ensure_2d=False, input_name="times") times = np.unique(times) if times.max() >= test_time.max() or times.min() < test_time.min(): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index e330e4e0..c1ab828b 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -963,6 +963,28 @@ def test_brier_coxph(): assert round(abs(score[0] - 0.208817407492645), 5) == 0 +def test_brier_score_int_dtype(): + times = np.arange(30, dtype=int) + rnd = np.random.RandomState(1) + times = rnd.choice(times, 20) + + y_int = np.empty(20, dtype=[("event", bool), ("time", int)]) + y_int["event"] = np.ones(20, dtype=bool) + y_int["event"][:10] = False + y_int["time"] = times + + pred = rnd.randn(20, 10) + tp = np.linspace(1.0, 2.0, 10) + _, bs_int = brier_score(y_int, y_int, pred, times=tp) + + y_float = np.empty(20, dtype=[("event", bool), ("time", float)]) + y_float["event"][:] = y_int["event"] + y_float["time"][:] = y_int["time"] + _, bs_float = brier_score(y_float, y_float, pred, times=tp) + + assert_array_almost_equal(bs_float, bs_int) + + def test_ibs_nottingham_1(nottingham_prognostic_index): times = np.linspace(365, 1825, 5) # t=1..5 years preds, y = nottingham_prognostic_index(times)