Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reduce SurvivalTree.predict's memory use #369

Merged
merged 24 commits into from
Jun 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
4c9505f
reduce SurvivalTree.predict's memory use
cpoerschke Jun 8, 2023
a227853
add low_memory=False option to SurvivalTree constructor
cpoerschke Jun 8, 2023
2fa63cb
add test_predict_low_memory in test_tree.py
cpoerschke Jun 8, 2023
a1a1213
development increment: j_delta=2 (tests continue to pass)
cpoerschke Jun 9, 2023
8c62e09
development increment: j_delta=3 (tests fail for some reason)
cpoerschke Jun 9, 2023
67969a7
low memory mode changes
cpoerschke Jun 12, 2023
f1639ff
Merge remote-tracking branch 'origin/master' into pr-6
cpoerschke Jun 12, 2023
b85f924
annotate TODO w.r.t. summing only for event times
cpoerschke Jun 12, 2023
b5d1ac0
address CI feedback
cpoerschke Jun 12, 2023
179d3ef
action review feedback (part 1 of 2)
cpoerschke Jun 13, 2023
9065fae
action review feedback (part 2 of 2)
cpoerschke Jun 13, 2023
3af23c5
int[::1] --> const bint[::1] for LogrankCriterion's is_event_time
cpoerschke Jun 13, 2023
fe5aed5
lint: line-too-long
cpoerschke Jun 13, 2023
4c21964
address CI feedback (part 1 of 2)
cpoerschke Jun 13, 2023
6ef7353
address CI feedback (part 2 of 2)
cpoerschke Jun 13, 2023
bdf7162
action CI feedback
cpoerschke Jun 13, 2023
706352b
Assign self.is_event_time to local variable
sebp Jun 17, 2023
2b81bb3
Add low_memory option to forest classes
sebp Jun 17, 2023
9acd236
Add test case for low-memory mode for forests
sebp Jun 17, 2023
a41f022
Remove test_predict_low_memory
sebp Jun 17, 2023
4440a76
Use type cnp.npy_bool instead of bint
sebp Jun 17, 2023
35b0811
Remove type conversion
sebp Jun 17, 2023
6cbb8fa
Fix code format
sebp Jun 17, 2023
9e1b344
Fix API doc of RandomSurvivalForest
sebp Jun 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions sksurv/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ class RandomSurvivalForest(_BaseSurvivalForest):
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
`max_samples` should be in the interval `(0.0, 1.0]`.

low_memory : boolean, default: False
If set, ``predict`` computations use reduced memory but ``predict_cumulative_hazard_function``
and ``predict_survival_function`` are not implemented.

Attributes
----------
estimators_ : list of SurvivalTree instances
Expand Down Expand Up @@ -455,6 +459,7 @@ def __init__(
verbose=0,
warm_start=False,
max_samples=None,
low_memory=False,
):
super().__init__(
estimator=SurvivalTree(),
Expand All @@ -467,6 +472,7 @@ def __init__(
"max_features",
"max_leaf_nodes",
"random_state",
"low_memory",
),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -483,6 +489,7 @@ def __init__(
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.low_memory = low_memory

def predict_cumulative_hazard_function(self, X, return_array=False):
"""Predict cumulative hazard function.
Expand Down Expand Up @@ -714,6 +721,10 @@ class ExtraSurvivalTrees(_BaseSurvivalForest):
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
`max_samples` should be in the interval `(0.0, 1.0]`.

low_memory : boolean, default: False
If set, ``predict`` computations use reduced memory but ``predict_cumulative_hazard_function``
and ``predict_survival_function`` are not implemented.

Attributes
----------
estimators_ : list of SurvivalTree instances
Expand Down Expand Up @@ -762,6 +773,7 @@ def __init__(
verbose=0,
warm_start=False,
max_samples=None,
low_memory=False,
):
super().__init__(
estimator=SurvivalTree(splitter="random"),
Expand All @@ -774,6 +786,7 @@ def __init__(
"max_features",
"max_leaf_nodes",
"random_state",
"low_memory",
),
bootstrap=bootstrap,
oob_score=oob_score,
Expand All @@ -790,6 +803,7 @@ def __init__(
self.min_weight_fraction_leaf = min_weight_fraction_leaf
self.max_features = max_features
self.max_leaf_nodes = max_leaf_nodes
self.low_memory = low_memory

def predict_cumulative_hazard_function(self, X, return_array=False):
"""Predict cumulative hazard function.
Expand Down
51 changes: 33 additions & 18 deletions sksurv/tree/_criterion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ cdef class LogrankCriterion(Criterion):
cdef:
# unique time points sorted in ascending order
const DOUBLE_t[::1] unique_times
const cnp.npy_bool[::1] is_event_time
SIZE_t n_unique_times
size_t nbytes
RisksetCounter riskset_total
Expand All @@ -139,7 +140,7 @@ cdef class LogrankCriterion(Criterion):
SIZE_t * samples_time_idx
SIZE_t n_samples_left

def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples, const DOUBLE_t[::1] unique_times):
def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples, const DOUBLE_t[::1] unique_times, const cnp.npy_bool[::1] is_event_time):
# Default values
self.samples = NULL
self.start = 0
Expand All @@ -149,6 +150,7 @@ cdef class LogrankCriterion(Criterion):
self.n_outputs = n_outputs
self.n_samples = n_samples
self.unique_times = unique_times
self.is_event_time = is_event_time
self.n_unique_times = unique_times.shape[0]
self.nbytes = self.n_unique_times * sizeof(cnp.npy_int64)
self.n_node_samples = 0
Expand All @@ -168,7 +170,7 @@ cdef class LogrankCriterion(Criterion):
free(self.samples_time_idx)

def __reduce__(self):
return (type(self), (self.n_outputs, self.n_samples, self.unique_times), self.__getstate__())
return (type(self), (self.n_outputs, self.n_samples, self.unique_times, self.is_event_time), self.__getstate__())

cdef int init(self, const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight,
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
Expand Down Expand Up @@ -323,24 +325,37 @@ cdef class LogrankCriterion(Criterion):
"""Compute the node value of samples[start:end] into dest."""
# Estimate cumulative hazard function
cdef:
const cnp.npy_bool[::1] is_event_time = self.is_event_time
SIZE_t i
SIZE_t j
DOUBLE_t ratio
DOUBLE_t n_events
DOUBLE_t n_at_risk

self.riskset_total.at(0, &n_at_risk, &n_events)
ratio = n_events / n_at_risk
dest[0] = ratio # Nelson-Aalen estimator
dest[1] = 1.0 - ratio # Kaplan-Meier estimator

j = 2
for i in range(1, self.n_unique_times):
self.riskset_total.at(i, &n_at_risk, &n_events)
dest[j] = dest[j - 2]
dest[j + 1] = dest[j - 1]
if n_at_risk != 0:
ratio = n_events / n_at_risk
dest[j] += ratio
dest[j + 1] *= 1.0 - ratio
j += 2
DOUBLE_t dest_j0

# low memory mode
if self.n_outputs == 1:
dest[0] = dest_j0 = 0
for i in range(0, self.n_unique_times):
self.riskset_total.at(i, &n_at_risk, &n_events)
if n_at_risk != 0:
ratio = n_events / n_at_risk
dest_j0 += ratio
if is_event_time[i]:
dest[0] += dest_j0
else:
self.riskset_total.at(0, &n_at_risk, &n_events)
ratio = n_events / n_at_risk
dest[0] = ratio # Nelson-Aalen estimator
dest[1] = 1.0 - ratio # Kaplan-Meier estimator

j = 2
for i in range(1, self.n_unique_times):
self.riskset_total.at(i, &n_at_risk, &n_events)
dest[j] = dest[j - 2]
dest[j + 1] = dest[j - 1]
if n_at_risk != 0:
ratio = n_events / n_at_risk
dest[j] += ratio
dest[j + 1] *= 1.0 - ratio
j += 2
37 changes: 33 additions & 4 deletions sksurv/tree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
Best nodes are defined as relative reduction in impurity.
If None then unlimited number of leaf nodes.

low_memory : boolean, default: False
If set, ``predict`` computations use reduced memory but ``predict_cumulative_hazard_function``
and ``predict_survival_function`` are not implemented.

Attributes
----------
unique_times_ : array of shape = (n_unique_times,)
Expand Down Expand Up @@ -162,6 +166,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
],
"random_state": ["random_state"],
"max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
"low_memory": ["boolean"],
}

def __init__(
Expand All @@ -175,6 +180,7 @@ def __init__(
max_features=None,
random_state=None,
max_leaf_nodes=None,
low_memory=False,
):
self.splitter = splitter
self.max_depth = max_depth
Expand All @@ -184,6 +190,7 @@ def __init__(
self.max_features = max_features
self.random_state = random_state
self.max_leaf_nodes = max_leaf_nodes
self.low_memory = low_memory

def fit(self, X, y, sample_weight=None, check_input=True):
"""Build a survival tree from the training set (X, y).
Expand Down Expand Up @@ -225,13 +232,18 @@ def fit(self, X, y, sample_weight=None, check_input=True):
n_samples, self.n_features_in_ = X.shape
params = self._check_params(n_samples)

self.n_outputs_ = self.unique_times_.shape[0]
# one "class" for CHF, one for survival function
self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2
if self.low_memory:
sebp marked this conversation as resolved.
Show resolved Hide resolved
self.n_outputs_ = 1
# one "class" only, for the sum over the CHF
self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp)
else:
self.n_outputs_ = self.unique_times_.shape[0]
# one "class" for CHF, one for survival function
self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2
sebp marked this conversation as resolved.
Show resolved Hide resolved

# Build tree
self.criterion = "logrank"
criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_)
criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_, self.is_event_time_)

SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS

Expand Down Expand Up @@ -326,6 +338,14 @@ def _check_max_features(self):

self.max_features_ = max_features

def _check_low_memory(self, function):
"""Check if `function` is supported in low memory mode and throw if it is not."""
if self.low_memory:
raise NotImplementedError(
f"{function} is not implemented in low memory mode."
+ " run fit with low_memory=False to disable low memory mode."
)

def _validate_X_predict(self, X, check_input, accept_sparse="csr"):
"""Validate X whenever one tries to predict"""
if check_input:
Expand Down Expand Up @@ -364,6 +384,13 @@ def predict(self, X, check_input=True):
risk_scores : ndarray, shape = (n_samples,)
Predicted risk scores.
"""

if self.low_memory:
check_is_fitted(self, "tree_")
X = self._validate_X_predict(X, check_input, accept_sparse="csr")
pred = self.tree_.predict(X)
return pred[..., 0]

chf = self.predict_cumulative_hazard_function(X, check_input, return_array=True)
return chf[:, self.is_event_time_].sum(1)

Expand Down Expand Up @@ -424,6 +451,7 @@ def predict_cumulative_hazard_function(self, X, check_input=True, return_array=F
>>> plt.ylim(0, 1)
>>> plt.show()
"""
self._check_low_memory("predict_cumulative_hazard_function")
check_is_fitted(self, "tree_")
X = self._validate_X_predict(X, check_input, accept_sparse="csr")

Expand Down Expand Up @@ -491,6 +519,7 @@ def predict_survival_function(self, X, check_input=True, return_array=False):
>>> plt.ylim(0, 1)
>>> plt.show()
"""
self._check_low_memory("predict_survival_function")
check_is_fitted(self, "tree_")
X = self._validate_X_predict(X, check_input, accept_sparse="csr")

Expand Down
45 changes: 45 additions & 0 deletions tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sksurv.ensemble import ExtraSurvivalTrees, RandomSurvivalForest
from sksurv.preprocessing import OneHotEncoder
from sksurv.testing import assert_cindex_almost_equal
from sksurv.tree import SurvivalTree

FORESTS = [
RandomSurvivalForest,
Expand Down Expand Up @@ -345,3 +346,47 @@ def test_predict_sparse(make_whas500, forest_cls):
assert_array_equal(y_pred, y_pred_csr)
assert_array_equal(y_cum_h_csr, y_cum_h)
assert_array_equal(y_surv, y_surv_csr)


@pytest.mark.parametrize(
"est_cls,params",
[
(SurvivalTree, {"min_samples_leaf": 10, "random_state": 42}),
(RandomSurvivalForest, {"n_estimators": 10, "min_samples_leaf": 10, "random_state": 42}),
(ExtraSurvivalTrees, {"n_estimators": 10, "min_samples_leaf": 10, "random_state": 42}),
],
)
def test_predict_low_memory(make_whas500, est_cls, params):
whas500 = make_whas500(to_numeric=True)
X, y = whas500.x, whas500.y

X_train, X_test, y_train, _ = train_test_split(X, y, random_state=params["random_state"])

est_high = est_cls(**params)
est_high.set_params(low_memory=False)
est_high.fit(X_train, y_train)
pred_high = est_high.predict(X_test)

est_low = est_cls(**params)
est_low.set_params(low_memory=True)
est_low.fit(X_train, y_train)
pred_low = est_low.predict(X_test)

assert pred_high.shape[0] == X_test.shape[0]
assert pred_low.shape[0] == X_test.shape[0]

assert_array_almost_equal(pred_high, pred_low)

msg = (
"predict_cumulative_hazard_function is not implemented in low memory mode."
" run fit with low_memory=False to disable low memory mode."
)
with pytest.raises(NotImplementedError, match=msg):
est_low.predict_cumulative_hazard_function(X_test)

msg = (
"predict_survival_function is not implemented in low memory mode."
" run fit with low_memory=False to disable low memory mode."
)
with pytest.raises(NotImplementedError, match=msg):
est_low.predict_survival_function(X_test)