
reduce SurvivalTree.predict's memory use #369

Merged
merged 24 commits into from
Jun 17, 2023
Conversation

cpoerschke
Contributor

@cpoerschke cpoerschke commented Jun 8, 2023

Checklist

  • pytest passes
  • tests are included
  • code is well formatted
  • documentation renders correctly

What does this implement/fix? Explain your changes

issue encountered: Passing a relatively large number of data samples to SurvivalTree.predict results in a "kernel died" error in a notebook environment.

analysis

The pred = self.tree_.predict(X) call in SurvivalTree.predict_cumulative_hazard_function returns an array of shape (n_samples, n_event_times, 2), which is reduced to shape (n_samples, n_event_times) before being returned; SurvivalTree.predict then applies .sum(1) to it, producing a result of shape (n_samples,). The intermediate n_samples × n_event_times arrays (where n_event_times = len(self.event_times_)) use a lot of memory for larger inputs, especially when the number of event times grows in proportion to the number of samples.
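As a rough illustration of why these intermediates hurt, here is a back-of-envelope estimate of the array sizes involved, assuming float64 entries and the worst case where every sample has a distinct event time (the numbers are illustrative, not taken from the PR):

```python
# Back-of-envelope estimate of the intermediate array sizes described above.
n_samples = 12_345       # as in the demo snippet below
n_event_times = 12_345   # worst case: every sample has a unique event time

bytes_per_float = 8
full = n_samples * n_event_times * 2 * bytes_per_float   # (n_samples, n_event_times, 2)
reduced = n_samples * n_event_times * bytes_per_float    # (n_samples, n_event_times)
per_row = 1 * n_event_times * 2 * bytes_per_float        # (1, n_event_times, 2)

print(f"full intermediate : {full / 1e9:.2f} GB")
print(f"reduced           : {reduced / 1e9:.2f} GB")
print(f"one row at a time : {per_row / 1e6:.2f} MB")
```

At this scale the full intermediate alone is on the order of gigabytes, which is consistent with a notebook kernel running out of memory.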

proposed solution

  • Do the self.tree_.predict call one row at a time, so that only a (1, n_event_times, 2) shaped array is needed per iteration.
  • Have SurvivalTree.predict_cumulative_hazard_function do the summing per row, removing the need for an (n_samples, n_event_times) shaped array.
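A minimal numpy sketch of the row-at-a-time idea (toy data; in the actual code each row would come from self.tree_.predict rather than from a pre-built matrix):

```python
import numpy as np

rng = np.random.default_rng(0)

n_samples, n_event_times = 1000, 500
chf = rng.random((n_samples, n_event_times))  # stand-in for per-sample CHF values

# Before: materialise the full (n_samples, n_event_times) array, then sum.
risk_full = chf.sum(axis=1)

# After: visit one row at a time, keeping only one row's worth of data alive.
risk_rowwise = np.empty(n_samples)
for i in range(n_samples):
    row = chf[i]  # in the PR this would be produced per-row by self.tree_.predict
    risk_rowwise[i] = row.sum()

assert np.allclose(risk_full, risk_rowwise)
```

The results are identical; only the peak size of the live intermediate changes.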

testing

Peak memory usage statistics were inspected before and after the change by profiling with memray run ./pr6-demo.py.

code snippet:

#!/usr/bin/env python

import numpy as np
import sksurv.tree
import sksurv.util

n = 12345  # number of samples; every sample gets a distinct event time

# Only print full arrays for tiny n.
verbose = (n <= 10)

# Two complementary synthetic features.
feature1 = np.arange(0, n, 1)
feature2 = n - feature1

# Distinct times 1..n; roughly the first half are events, the rest censored.
times = np.arange(0, n) + 1
events = times < times[len(times) // 2]

X = np.vstack((feature1, feature2)).T
y = sksurv.util.Surv.from_arrays(time=times, event=events)

if verbose:
    print(f"X={X}\ny={y}")
else:
    print(f"X.shape={X.shape}\ny.shape={y.shape}")

st = sksurv.tree.SurvivalTree(max_leaf_nodes=100)
print(st.fit(X, y))

risk_scores = st.predict(X)

if verbose:
    print(f"risk_scores={risk_scores}")
else:
    print(f"risk_scores.shape={risk_scores.shape}")
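One possible memray workflow for profiling the snippet above (the capture file name is illustrative; run, stats, and flamegraph are standard memray subcommands):

```shell
# Record allocations while running the demo script.
memray run -o pr6-demo.bin ./pr6-demo.py

# Summarise allocation statistics, including peak memory, from the capture.
memray stats pr6-demo.bin

# Optionally render an interactive flamegraph for closer inspection.
memray flamegraph pr6-demo.bin
```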

@sebp
Owner

sebp commented Jun 8, 2023

Thanks for your PR.

I agree that predict is quite heavy on memory usage. Part of it has to do with sklearn's Tree-related code, where predictions are returned via the split criterion's node_value method. It is currently implemented to return the full survival function and cumulative hazard function for each sample, disregarding whether predict, predict_survival_function, or predict_cumulative_hazard_function has been called. I could imagine adding a low-memory option that disables computing survival and cumulative hazard function (CHF) and just returns the event counts (sum over CHF, i.e. what predict returns).
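To make the "sum over CHF" remark concrete, here is a toy numpy example (hypothetical CHF values; in sksurv these would come from the fitted tree's leaves) showing the single scalar per sample that predict ultimately returns:

```python
import numpy as np

# Toy cumulative hazard function evaluated at the model's event times.
event_times = np.array([1.0, 2.0, 3.0, 4.0])
chf = np.array([
    [0.1, 0.3, 0.6, 1.0],   # sample 0
    [0.2, 0.5, 0.9, 1.4],   # sample 1
])

# What predict returns: the CHF summed over all event times, per sample.
risk_scores = chf.sum(axis=1)
print(risk_scores)   # prints [2. 3.]
```

A low-memory mode along the lines sketched above would compute only these scalars, never materialising the full per-sample survival and CHF arrays.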

@cpoerschke
Contributor Author

... Part of it has to do with ... I could imagine adding a low-memory option that disables computing survival and cumulative hazard function (CHF) and just returns the event counts (sum over CHF, i.e. what predict returns).

Thanks for the context and quick feedback!

I've added a low_memory=False option in the latest commit, though I guess its current use does not disable the computations as such. And of course for any new option there should be test coverage too.

@cpoerschke cpoerschke marked this pull request as ready for review June 8, 2023 16:34
@codecov

codecov bot commented Jun 8, 2023

Codecov Report

Patch coverage: 100.00% and no project coverage change.

Comparison is base (bd2e240) 97.94% compared to head (9e1b344) 97.95%.

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #369   +/-   ##
=======================================
  Coverage   97.94%   97.95%           
=======================================
  Files          37       37           
  Lines        3361     3376   +15     
  Branches      509      511    +2     
=======================================
+ Hits         3292     3307   +15     
  Misses         33       33           
  Partials       36       36           
Impacted Files Coverage Δ
sksurv/ensemble/forest.py 100.00% <100.00%> (ø)
sksurv/tree/tree.py 95.71% <100.00%> (+0.43%) ⬆️


@cpoerschke
Contributor Author

... Part of it has to do with sklearn's Tree-related code, where predictions are returned via the split criterion's node_value method. It is currently implemented to return the full survival function and cumulative hazard function for each sample, disregarding whether predict, predict_survival_function, or predict_cumulative_hazard_function has been called. I could imagine adding a low-memory option that disables computing survival and cumulative hazard function (CHF) and just returns the event counts (sum over CHF, i.e. what predict returns).

So I haven't worked with .pxd and .pyx code before, but from a little bit of code reading just now ... is the idea conceptually that, in low-memory mode:

@cpoerschke cpoerschke marked this pull request as draft June 9, 2023 11:18

@cpoerschke cpoerschke left a comment


Added work-in-progress notes inline.

Comment on lines 788 to 790
# Duplicate values in whas500 lead to assert errors because of
# tie resolution during tree fitting.
# Using a synthetic dataset resolves this issue.
sebp (Owner) replied:

Actually, ties should not cause any problems.


@cpoerschke cpoerschke left a comment


Thanks @sebp for the feedback! I think I've addressed all points and tests pass locally. Please let me know if there are any additional suggestions. Thank you.

@cpoerschke cpoerschke marked this pull request as ready for review June 13, 2023 18:15
@cpoerschke cpoerschke requested a review from sebp June 13, 2023 18:16
@sebp sebp self-requested a review June 17, 2023 21:05
@sebp sebp merged commit 53d6261 into sebp:master Jun 17, 2023