diff --git a/sksurv/ensemble/forest.py b/sksurv/ensemble/forest.py
index 9195b6c2..5deba39a 100644
--- a/sksurv/ensemble/forest.py
+++ b/sksurv/ensemble/forest.py
@@ -376,6 +376,10 @@ class RandomSurvivalForest(_BaseSurvivalForest):
         - If float, then draw `max_samples * X.shape[0]` samples. Thus,
           `max_samples` should be in the interval `(0.0, 1.0]`.
 
+    low_memory : boolean, default: False
+        If set, ``predict`` computations use reduced memory, but ``predict_cumulative_hazard_function``
+        and ``predict_survival_function`` are not implemented.
+
     Attributes
     ----------
     estimators_ : list of SurvivalTree instances
@@ -455,6 +459,7 @@ def __init__(
         verbose=0,
         warm_start=False,
         max_samples=None,
+        low_memory=False,
     ):
         super().__init__(
             estimator=SurvivalTree(),
@@ -467,6 +472,7 @@
                 "max_features",
                 "max_leaf_nodes",
                 "random_state",
+                "low_memory",
             ),
             bootstrap=bootstrap,
             oob_score=oob_score,
@@ -483,6 +489,7 @@
         self.min_weight_fraction_leaf = min_weight_fraction_leaf
         self.max_features = max_features
         self.max_leaf_nodes = max_leaf_nodes
+        self.low_memory = low_memory
 
     def predict_cumulative_hazard_function(self, X, return_array=False):
         """Predict cumulative hazard function.
@@ -714,6 +721,10 @@ class ExtraSurvivalTrees(_BaseSurvivalForest):
         - If float, then draw `max_samples * X.shape[0]` samples. Thus,
           `max_samples` should be in the interval `(0.0, 1.0]`.
 
+    low_memory : boolean, default: False
+        If set, ``predict`` computations use reduced memory, but ``predict_cumulative_hazard_function``
+        and ``predict_survival_function`` are not implemented.
+
     Attributes
     ----------
     estimators_ : list of SurvivalTree instances
@@ -762,6 +773,7 @@ def __init__(
         verbose=0,
         warm_start=False,
         max_samples=None,
+        low_memory=False,
     ):
         super().__init__(
             estimator=SurvivalTree(splitter="random"),
@@ -774,6 +786,7 @@
                 "max_features",
                 "max_leaf_nodes",
                 "random_state",
+                "low_memory",
             ),
             bootstrap=bootstrap,
             oob_score=oob_score,
@@ -790,6 +803,7 @@
         self.min_weight_fraction_leaf = min_weight_fraction_leaf
         self.max_features = max_features
         self.max_leaf_nodes = max_leaf_nodes
+        self.low_memory = low_memory
 
     def predict_cumulative_hazard_function(self, X, return_array=False):
         """Predict cumulative hazard function.
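
For reference, a minimal usage sketch of the new flag on the ensemble estimators (the data loading uses scikit-survival's bundled whas500 dataset; the hyperparameter values are illustrative, not prescribed by this patch):

    from sksurv.datasets import load_whas500
    from sksurv.ensemble import RandomSurvivalForest
    from sksurv.preprocessing import OneHotEncoder

    # load the example data and one-hot encode categorical features
    X, y = load_whas500()
    X = OneHotEncoder().fit_transform(X)

    # with low_memory=True, predict() returns risk scores as usual,
    # but the per-time-point function predictors are unavailable
    rsf = RandomSurvivalForest(
        n_estimators=10, min_samples_leaf=10, random_state=42, low_memory=True
    )
    rsf.fit(X, y)
    risk_scores = rsf.predict(X)
    # rsf.predict_survival_function(X) would raise NotImplementedError;
    # refitting with low_memory=False restores it
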
diff --git a/sksurv/tree/_criterion.pyx b/sksurv/tree/_criterion.pyx
index f756fc44..3c0ff8cc 100644
--- a/sksurv/tree/_criterion.pyx
+++ b/sksurv/tree/_criterion.pyx
@@ -131,6 +131,7 @@ cdef class LogrankCriterion(Criterion):
     cdef:
         # unique time points sorted in ascending order
         const DOUBLE_t[::1] unique_times
+        const cnp.npy_bool[::1] is_event_time
         SIZE_t n_unique_times
         size_t nbytes
         RisksetCounter riskset_total
@@ -139,7 +140,7 @@ cdef class LogrankCriterion(Criterion):
         SIZE_t * samples_time_idx
         SIZE_t n_samples_left
 
-    def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples, const DOUBLE_t[::1] unique_times):
+    def __cinit__(self, SIZE_t n_outputs, SIZE_t n_samples, const DOUBLE_t[::1] unique_times, const cnp.npy_bool[::1] is_event_time):
         # Default values
         self.samples = NULL
         self.start = 0
@@ -149,6 +150,7 @@ cdef class LogrankCriterion(Criterion):
         self.n_outputs = n_outputs
         self.n_samples = n_samples
         self.unique_times = unique_times
+        self.is_event_time = is_event_time
         self.n_unique_times = unique_times.shape[0]
         self.nbytes = self.n_unique_times * sizeof(cnp.npy_int64)
         self.n_node_samples = 0
@@ -168,7 +170,7 @@ cdef class LogrankCriterion(Criterion):
         free(self.samples_time_idx)
 
     def __reduce__(self):
-        return (type(self), (self.n_outputs, self.n_samples, self.unique_times), self.__getstate__())
+        return (type(self), (self.n_outputs, self.n_samples, self.unique_times, self.is_event_time), self.__getstate__())
 
     cdef int init(self, const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight,
                   double weighted_n_samples, SIZE_t* samples, SIZE_t start,
@@ -323,24 +325,37 @@ cdef class LogrankCriterion(Criterion):
         """Compute the node value of samples[start:end] into dest."""
         # Estimate cumulative hazard function
         cdef:
+            const cnp.npy_bool[::1] is_event_time = self.is_event_time
             SIZE_t i
             SIZE_t j
             DOUBLE_t ratio
             DOUBLE_t n_events
             DOUBLE_t n_at_risk
-
-        self.riskset_total.at(0, &n_at_risk, &n_events)
-        ratio = n_events / n_at_risk
-        dest[0] = ratio  # Nelson-Aalen estimator
-        dest[1] = 1.0 - ratio  # Kaplan-Meier estimator
-
-        j = 2
-        for i in range(1, self.n_unique_times):
-            self.riskset_total.at(i, &n_at_risk, &n_events)
-            dest[j] = dest[j - 2]
-            dest[j + 1] = dest[j - 1]
-            if n_at_risk != 0:
-                ratio = n_events / n_at_risk
-                dest[j] += ratio
-                dest[j + 1] *= 1.0 - ratio
-            j += 2
+            DOUBLE_t dest_j0
+
+        # low memory mode
+        if self.n_outputs == 1:
+            dest[0] = dest_j0 = 0
+            for i in range(0, self.n_unique_times):
+                self.riskset_total.at(i, &n_at_risk, &n_events)
+                if n_at_risk != 0:
+                    ratio = n_events / n_at_risk
+                    dest_j0 += ratio
+                if is_event_time[i]:
+                    dest[0] += dest_j0
+        else:
+            self.riskset_total.at(0, &n_at_risk, &n_events)
+            ratio = n_events / n_at_risk
+            dest[0] = ratio  # Nelson-Aalen estimator
+            dest[1] = 1.0 - ratio  # Kaplan-Meier estimator
+
+            j = 2
+            for i in range(1, self.n_unique_times):
+                self.riskset_total.at(i, &n_at_risk, &n_events)
+                dest[j] = dest[j - 2]
+                dest[j + 1] = dest[j - 1]
+                if n_at_risk != 0:
+                    ratio = n_events / n_at_risk
+                    dest[j] += ratio
+                    dest[j + 1] *= 1.0 - ratio
+                j += 2
diff --git a/sksurv/tree/tree.py b/sksurv/tree/tree.py
index 45db8677..0f797e4d 100644
--- a/sksurv/tree/tree.py
+++ b/sksurv/tree/tree.py
@@ -106,6 +106,10 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
         Best nodes are defined as relative reduction in impurity.
         If None then unlimited number of leaf nodes.
 
+    low_memory : boolean, default: False
+        If set, ``predict`` computations use reduced memory, but ``predict_cumulative_hazard_function``
+        and ``predict_survival_function`` are not implemented.
+
     Attributes
     ----------
     unique_times_ : array of shape = (n_unique_times,)
@@ -162,6 +166,7 @@ class SurvivalTree(BaseEstimator, SurvivalAnalysisMixin):
         ],
         "random_state": ["random_state"],
         "max_leaf_nodes": [Interval(Integral, 2, None, closed="left"), None],
+        "low_memory": ["boolean"],
     }
 
     def __init__(
@@ -175,6 +180,7 @@
         max_features=None,
         random_state=None,
         max_leaf_nodes=None,
+        low_memory=False,
     ):
         self.splitter = splitter
         self.max_depth = max_depth
@@ -184,6 +190,7 @@
         self.max_features = max_features
         self.random_state = random_state
         self.max_leaf_nodes = max_leaf_nodes
+        self.low_memory = low_memory
 
     def fit(self, X, y, sample_weight=None, check_input=True):
         """Build a survival tree from the training set (X, y).
@@ -225,13 +232,18 @@
         n_samples, self.n_features_in_ = X.shape
         params = self._check_params(n_samples)
 
-        self.n_outputs_ = self.unique_times_.shape[0]
-        # one "class" for CHF, one for survival function
-        self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2
+        if self.low_memory:
+            self.n_outputs_ = 1
+            # one "class" only, for the sum over the CHF
+            self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp)
+        else:
+            self.n_outputs_ = self.unique_times_.shape[0]
+            # one "class" for CHF, one for survival function
+            self.n_classes_ = np.ones(self.n_outputs_, dtype=np.intp) * 2
 
         # Build tree
         self.criterion = "logrank"
-        criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_)
+        criterion = LogrankCriterion(self.n_outputs_, n_samples, self.unique_times_, self.is_event_time_)
 
         SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
 
@@ -326,6 +338,14 @@
 
         self.max_features_ = max_features
 
+    def _check_low_memory(self, function):
+        """Check if `function` is supported in low memory mode and raise if it is not."""
+        if self.low_memory:
+            raise NotImplementedError(
+                f"{function} is not implemented in low memory mode."
+                + " Run fit with low_memory=False to disable low memory mode."
+            )
+
     def _validate_X_predict(self, X, check_input, accept_sparse="csr"):
         """Validate X whenever one tries to predict"""
         if check_input:
@@ -364,6 +384,13 @@
         risk_scores : ndarray, shape = (n_samples,)
             Predicted risk scores.
""" + + if self.low_memory: + check_is_fitted(self, "tree_") + X = self._validate_X_predict(X, check_input, accept_sparse="csr") + pred = self.tree_.predict(X) + return pred[..., 0] + chf = self.predict_cumulative_hazard_function(X, check_input, return_array=True) return chf[:, self.is_event_time_].sum(1) @@ -424,6 +451,7 @@ def predict_cumulative_hazard_function(self, X, check_input=True, return_array=F >>> plt.ylim(0, 1) >>> plt.show() """ + self._check_low_memory("predict_cumulative_hazard_function") check_is_fitted(self, "tree_") X = self._validate_X_predict(X, check_input, accept_sparse="csr") @@ -491,6 +519,7 @@ def predict_survival_function(self, X, check_input=True, return_array=False): >>> plt.ylim(0, 1) >>> plt.show() """ + self._check_low_memory("predict_survival_function") check_is_fitted(self, "tree_") X = self._validate_X_predict(X, check_input, accept_sparse="csr") diff --git a/tests/test_forest.py b/tests/test_forest.py index 5e248d5f..53a28f8a 100644 --- a/tests/test_forest.py +++ b/tests/test_forest.py @@ -9,6 +9,7 @@ from sksurv.ensemble import ExtraSurvivalTrees, RandomSurvivalForest from sksurv.preprocessing import OneHotEncoder from sksurv.testing import assert_cindex_almost_equal +from sksurv.tree import SurvivalTree FORESTS = [ RandomSurvivalForest, @@ -345,3 +346,47 @@ def test_predict_sparse(make_whas500, forest_cls): assert_array_equal(y_pred, y_pred_csr) assert_array_equal(y_cum_h_csr, y_cum_h) assert_array_equal(y_surv, y_surv_csr) + + +@pytest.mark.parametrize( + "est_cls,params", + [ + (SurvivalTree, {"min_samples_leaf": 10, "random_state": 42}), + (RandomSurvivalForest, {"n_estimators": 10, "min_samples_leaf": 10, "random_state": 42}), + (ExtraSurvivalTrees, {"n_estimators": 10, "min_samples_leaf": 10, "random_state": 42}), + ], +) +def test_predict_low_memory(make_whas500, est_cls, params): + whas500 = make_whas500(to_numeric=True) + X, y = whas500.x, whas500.y + + X_train, X_test, y_train, _ = train_test_split(X, y, random_state=params["random_state"]) + + est_high = est_cls(**params) + est_high.set_params(low_memory=False) + est_high.fit(X_train, y_train) + pred_high = est_high.predict(X_test) + + est_low = est_cls(**params) + est_low.set_params(low_memory=True) + est_low.fit(X_train, y_train) + pred_low = est_low.predict(X_test) + + assert pred_high.shape[0] == X_test.shape[0] + assert pred_low.shape[0] == X_test.shape[0] + + assert_array_almost_equal(pred_high, pred_low) + + msg = ( + "predict_cumulative_hazard_function is not implemented in low memory mode." + " run fit with low_memory=False to disable low memory mode." + ) + with pytest.raises(NotImplementedError, match=msg): + est_low.predict_cumulative_hazard_function(X_test) + + msg = ( + "predict_survival_function is not implemented in low memory mode." + " run fit with low_memory=False to disable low memory mode." + ) + with pytest.raises(NotImplementedError, match=msg): + est_low.predict_survival_function(X_test)