NUTS out of memory #529
Comments
I have been following a similar discussion in jaxopt: google/jaxopt#380. The problem is likely related to the jitted function cache blowing up, which can be avoided using the newly implemented
For us, there are a few things we should do:
... = blackjax.window_adaptation(..., loop_in_python=True, clear_cache_each_step=True)
This would also mean adding a Python for-loop implementation of
blackjax also calls |
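A minimal sketch of what a Python-loop variant with per-step cache clearing might look like. This is not blackjax code: loop_in_python and clear_cache_each_step are only the option names proposed in the comment above, toy_step is a hypothetical stand-in for a sampler transition, and using jax.clear_caches() as the cache-clearing call is an assumption, since the comment's own reference is cut off.

```python
import jax
import jax.numpy as jnp


@jax.jit
def toy_step(state, key):
    # Hypothetical stand-in for one sampler transition: a random-walk update.
    return state + 0.1 * jax.random.normal(key, state.shape)


def run_python_loop(init_state, key, num_steps, clear_cache_each_step=False):
    # Plain Python loop instead of jax.lax.scan: each step is dispatched on its own.
    state = init_state
    for step_key in jax.random.split(key, num_steps):
        state = toy_step(state, step_key)
        if clear_cache_each_step:
            # Assumption: drop JAX's compilation caches after every step.
            # The next call to toy_step then has to be traced and compiled again.
            jax.clear_caches()
    return state


final_state = run_python_loop(
    jnp.zeros(3), jax.random.PRNGKey(0), num_steps=5, clear_cache_each_step=True
)
```

Clearing the cache on every step trades speed for memory, because each step pays the compilation cost again. Whether that trade is acceptable would depend on how expensive the real kernel is to compile.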
In my project, problems in jitting the NUTS update step result in an OOM error. According to the traceback message, the requested memory scales with the number of iterations (2000 in the reproducing example), even though the window_adaptation implementation uses jax.lax.scan, which should not unroll loops, so the memory footprint should not scale with the number of iterations.
Similar reports have shown up in the PyMC community using the blackjax NUTS sampler: https://discourse.pymc.io/t/out-of-memory-when-using-pm-sampling-jax-sample-blackjax-nuts/11544/2.
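The expectation about jax.lax.scan can be checked on a toy example. The sketch below assumes a hypothetical toy_kernel in place of the NUTS step and counts how many scan equations appear in the traced program for different loop lengths. It says nothing about what blackjax itself does internally.

```python
import jax
import jax.numpy as jnp


def toy_kernel(state, key):
    # Hypothetical stand-in for one transition; returns (carry, per-step output).
    new_state = state + 0.1 * jax.random.normal(key, state.shape)
    return new_state, new_state


def run_scan(key, num_steps):
    keys = jax.random.split(key, num_steps)
    # The loop body is traced once, whatever the value of num_steps.
    return jax.lax.scan(toy_kernel, jnp.zeros(3), keys)


def count_scan_eqns(num_steps):
    # Trace the program and count top-level `scan` equations in the jaxpr.
    closed = jax.make_jaxpr(lambda k: run_scan(k, num_steps))(jax.random.PRNGKey(0))
    return sum(eqn.primitive.name == "scan" for eqn in closed.jaxpr.eqns)


short_count = count_scan_eqns(10)
long_count = count_scan_eqns(2000)

# The stacked per-step outputs do grow with num_steps: shape (num_steps, 3).
final_state, samples = run_scan(jax.random.PRNGKey(0), 2000)
```

In this toy case the traced program contains a single scan equation for both loop lengths, so the program size does not depend on the iteration count. Only the stacked outputs returned by the scan grow linearly with the number of steps.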
I cannot find any obvious reason for a memory leak, what's going on here?
I tried the obvious XLA_PYTHON_CLIENT_MEM_FRACTION and XLA_PYTHON_CLIENT_PREALLOCATE solutions. Any time I increase num_steps too much, the out-of-memory error shows up.
Steps/code to reproduce the bug:
Expected result:
Error message:
(running on my laptop GPU, an NVIDIA P2000 with 4 GB)
Blackjax/JAX/jaxlib/Python version information:
Context for the issue:
No response