Skip to content

Commit

Permalink
Merge pull request #349 from sebp/fix/317
Browse files Browse the repository at this point in the history
FIX: Downcast time points passed to brier_score
  • Loading branch information
sebp committed Apr 2, 2023
2 parents d82e9be + 93d0c24 commit 139fd84
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sksurv/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _check_inputs(event_indicator, event_time, estimate):


def _check_times(test_time, times):
times = check_array(np.atleast_1d(times), ensure_2d=False, dtype=test_time.dtype, input_name="times")
times = check_array(np.atleast_1d(times), ensure_2d=False, input_name="times")
times = np.unique(times)

if times.max() >= test_time.max() or times.min() < test_time.min():
Expand Down
22 changes: 22 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,28 @@ def test_brier_coxph():
assert round(abs(score[0] - 0.208817407492645), 5) == 0


def test_brier_score_int_dtype():
times = np.arange(30, dtype=int)
rnd = np.random.RandomState(1)
times = rnd.choice(times, 20)

y_int = np.empty(20, dtype=[("event", bool), ("time", int)])
y_int["event"] = np.ones(20, dtype=bool)
y_int["event"][:10] = False
y_int["time"] = times

pred = rnd.randn(20, 10)
tp = np.linspace(1.0, 2.0, 10)
_, bs_int = brier_score(y_int, y_int, pred, times=tp)

y_float = np.empty(20, dtype=[("event", bool), ("time", float)])
y_float["event"][:] = y_int["event"]
y_float["time"][:] = y_int["time"]
_, bs_float = brier_score(y_float, y_float, pred, times=tp)

assert_array_almost_equal(bs_float, bs_int)


def test_ibs_nottingham_1(nottingham_prognostic_index):
times = np.linspace(365, 1825, 5) # t=1..5 years
preds, y = nottingham_prognostic_index(times)
Expand Down

0 comments on commit 139fd84

Please sign in to comment.