Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce concordance_index_censored's memory use #362

Merged
merged 6 commits on May 12, 2023
33 changes: 19 additions & 14 deletions sksurv/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def _check_estimate_2d(estimate, test_time, time_points, estimator):
def _get_comparable(event_indicator, event_time, order):
sebp marked this conversation as resolved.
Show resolved Hide resolved
n_samples = len(event_time)
tied_time = 0
comparable = {}
i = 0
while i < n_samples - 1:
time_i = event_time[order[i]]
Expand All @@ -103,32 +102,34 @@ def _get_comparable(event_indicator, event_time, order):
censored_at_same_time = ~event_at_same_time
for j in range(i, end):
if event_indicator[order[j]]:
mask = np.zeros(n_samples, dtype=bool)
mask[end:] = True
# an event is comparable to censored samples at same time point
mask[i:end] = censored_at_same_time
comparable[j] = mask
mask_info = (i, end, censored_at_same_time)
tied_time += censored_at_same_time.sum()
yield (j, mask_info, tied_time)
sebp marked this conversation as resolved.
Show resolved Hide resolved
i = end

return comparable, tied_time


def _estimate_concordance_index(event_indicator, event_time, estimate, weights, tied_tol=1e-8):
n_samples = len(event_time)
order = np.argsort(event_time)

comparable, tied_time = _get_comparable(event_indicator, event_time, order)

if len(comparable) == 0:
raise NoComparablePairException(
"Data has no comparable pairs, cannot estimate concordance index.")
tied_time = None

concordant = 0
discordant = 0
tied_risk = 0
numerator = 0.0
denominator = 0.0
for ind, mask in comparable.items():
for (ind, mask_info, tied_time) in _get_comparable(event_indicator, event_time, order):
# mask info in three parts
i = mask_info[0]
end = mask_info[1]
censored_at_same_time = mask_info[2]
# construct (potentially large) mask from (smaller) mask info
mask = np.zeros(n_samples, dtype=bool)
mask[end:] = True
# an event is comparable to censored samples at same time point
mask[i:end] = censored_at_same_time

est_i = estimate[order[ind]]
event_i = event_indicator[order[ind]]
w_i = weights[order[ind]]
Expand All @@ -150,6 +151,10 @@ def _estimate_concordance_index(event_indicator, event_time, estimate, weights,
concordant += n_con
discordant += est.size - n_con - n_ties

if tied_time is None:
raise NoComparablePairException(
"Data has no comparable pairs, cannot estimate concordance index.")

cindex = numerator / denominator
return cindex, concordant, discordant, tied_risk, tied_time

Expand Down