reduce SurvivalTree.predict's memory use #369
Thanks for your PR. I agree that …
Thanks for the context and quick feedback! I've added a …
Codecov Report
Patch coverage:
@@           Coverage Diff           @@
##           master    #369   +/-   ##
=======================================
  Coverage   97.94%   97.95%
=======================================
  Files          37       37
  Lines        3361     3376     +15
  Branches      509      511      +2
=======================================
+ Hits         3292     3307     +15
  Misses         33       33
  Partials       36       36
So I haven't worked with .pxd and .pyx code before, but from a little code reading just now, is the idea conceptually that in low-memory mode:
Added work-in-progress notes inline.
Resolved conflicts: sksurv/tree/_criterion.pyx, sksurv/tree/tree.py
tests/test_tree.py
# Duplicate values in whas500 lead to assertion errors because of
# tie resolution during tree fitting.
# Using a synthetic dataset resolves this issue.
Actually, ties should not cause any problems.
Thanks @sebp for the feedback! I think I've addressed all points, and tests pass locally. Please let me know if there are any additional suggestions. Thank you.
Checklist
What does this implement/fix? Explain your changes
Issue encountered
Passing a relatively large number of data samples to SurvivalTree.predict results in "kernel died" in a notebook environment.

Analysis
The pred = self.tree_.predict(X) call in SurvivalTree.predict_cumulative_hazard_function returns an array of shape (n_samples, len(self.event_times_), 2), which is then returned as an array of shape (n_samples, len(self.event_times_)). SurvivalTree.predict then calls .sum(1) on it, resulting in an overall return value of shape (n_samples,). These intermediate n_samples × len(self.event_times_) arrays use a lot of memory for large inputs, especially when the number of event times grows in proportion to the number of samples.

Proposed solution
Make the self.tree_.predict call on a one-row-at-a-time basis, so that each iteration only requires an array of shape (1, len(self.event_times_), 2). Have SurvivalTree.predict_cumulative_hazard_function do the summing for each row, removing the need for an array of shape (n_samples, len(self.event_times_)).

Testing
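Complementing memray-based profiling, a quick stdlib-only check is possible. The NumPy sketch below is a minimal illustration under stated assumptions, not the sksurv code: toy_predict is a hypothetical stand-in for self.tree_.predict, channel 0 of its output is assumed to hold the cumulative hazard, and predict_full, predict_chunked, full_array_bytes and peak_bytes are hypothetical helpers mirroring the before/after call patterns. It checks that both paths return the same values, estimates the size of the full array (for an illustrative 50,000 samples and 5,000 event times, a float64 array of shape (n_samples, n_times, 2) takes 50,000 × 5,000 × 2 × 8 bytes = 4 GB), and compares peak traced allocations with tracemalloc.

```python
import tracemalloc

import numpy as np

# Hypothetical stand-in for self.tree_.predict: a two-leaf "tree" whose leaves
# each store an (n_times, 2) array. Channel 0 holding the cumulative hazard
# is an assumption made for this sketch.
rng = np.random.default_rng(0)
N_TIMES = 50
LEAF_VALUES = np.cumsum(rng.random((2, N_TIMES, 2)), axis=1)


def toy_predict(X):
    leaf = (X[:, 0] > 0.5).astype(np.intp)  # route each row to leaf 0 or 1
    return LEAF_VALUES[leaf]  # new array of shape (len(X), N_TIMES, 2)


def predict_full(predict_fn, X):
    """Original pattern: materialize every row's curve, then sum over time."""
    pred = predict_fn(X)  # (n_samples, n_times, 2) allocated at once
    return pred[:, :, 0].sum(axis=1)


def predict_chunked(predict_fn, X, chunk_size=1):
    """Row-at-a-time pattern (chunk_size=1); larger chunks are a variation."""
    n_samples = X.shape[0]
    out = np.empty(n_samples, dtype=np.float64)
    for start in range(0, n_samples, chunk_size):
        stop = min(start + chunk_size, n_samples)
        pred = predict_fn(X[start:stop])  # only (stop - start, n_times, 2)
        out[start:stop] = pred[:, :, 0].sum(axis=1)
    return out


def full_array_bytes(n_samples, n_times, itemsize=8):
    """Size in bytes of one array of shape (n_samples, n_times, 2)."""
    return n_samples * n_times * 2 * itemsize


def peak_bytes(fn, *args, **kwargs):
    """Peak traced allocation while fn runs (NumPy reports to tracemalloc)."""
    tracemalloc.start()
    try:
        fn(*args, **kwargs)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return peak


X = rng.random((2000, 3))
assert np.allclose(predict_full(toy_predict, X), predict_chunked(toy_predict, X))
print(full_array_bytes(50_000, 5_000))  # 4000000000 bytes, about 4 GB
print(peak_bytes(predict_full, toy_predict, X))
print(peak_bytes(predict_chunked, toy_predict, X))
```

With chunk_size=1 this mirrors the one-row-at-a-time approach; a larger chunk_size (a variation not described in the PR) would trade some memory for fewer Python-level calls to the predict function.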
Peak memory usage was inspected before and after the change by profiling with memray run ./pr6-demo.py.
Code snippet: