NUTS out of memory #529

Open
timdhondt1 opened this issue Apr 21, 2023 · 1 comment

Comments


timdhondt1 commented Apr 21, 2023

In my project, jitting the NUTS update step results in an out-of-memory error. According to the traceback, the requested memory scales with the number of iterations (2000 in the reproducing example), even though the window_adaptation implementation uses jax.lax.scan, which should not unroll the loop, so the memory footprint should not scale with the number of iterations.
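
For illustration, this is the lax.scan pattern I would expect window_adaptation to follow (a minimal sketch with made-up shapes, not blackjax's actual kernel): the step function is compiled once, the carry keeps a fixed size, and only the per-step outputs that scan stacks should grow with the number of steps.

import jax
import jax.numpy as jnp

def step(carry, _):
    # carry is a fixed-size state; it does not grow with the number of steps
    new_carry = carry + 1.0
    return new_carry, new_carry.sum()  # the second value is stacked across steps

init = jnp.zeros((100000,))
final_carry, per_step = jax.lax.scan(step, init, xs=None, length=2000)
print(per_step.shape)  # (2000,): only the collected outputs scale with length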

Similar reports have shown up in the PyMC community when using the blackjax NUTS sampler: https://discourse.pymc.io/t/out-of-memory-when-using-pm-sampling-jax-sample-blackjax-nuts/11544/2.

I cannot find any obvious reason for a memory leak; what is going on here?
I tried the usual XLA_PYTHON_CLIENT_MEM_FRACTION and XLA_PYTHON_CLIENT_PREALLOCATE workarounds. Any time I increase num_steps too much, the out-of-memory error shows up.
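
For completeness, this is roughly how I set those environment variables (they have to be set before JAX is imported to affect the GPU allocator); the exact values are just what I tried:

import os

# must be set before jax is imported to take effect
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # do not preallocate the GPU memory pool
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"   # cap JAX at ~80% of GPU memory

import jax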

Steps/code to reproduce the bug:

import blackjax
import jax
import jax.numpy as jnp
import time

def logdensity(x):
    key = jax.random.PRNGKey(int(time.time() * 1e7))  # quick hacky prng key, doesn't matter
    return x.sum() - jax.random.normal(key, shape=(100000,)).sum()

rng = jax.random.PRNGKey(10)
adapt = blackjax.window_adaptation(blackjax.nuts, logdensity, initial_step_size=1e-4, num_steps=2000)
(updated_state, sampler_params), _ = adapt.run(rng, jax.random.normal(rng, shape=(100000,)))

Expected result:

Memory footprint does not scale with the number of iterations.

Error message:

(running on my laptop GPU, an NVIDIA P2000 with 4 GB of memory)

Traceback (most recent call last):
  File "blackjax_reproduce.py", line 12, in <module>
    (updated_state, sampler_params), _ = adapt.run(rng, jax.random.normal(rng, shape=(100000,)), num_steps=2000)
  File "/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py", line 842, in run
    last_state, info = jax.lax.scan(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 800000000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    2.32MiB
              constant allocation:       124B
        maybe_live_out allocation:    9.69GiB
     preallocated temp allocation:   17.18MiB
                 total allocation:    9.71GiB
Peak buffers:
	Buffer 1:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 2:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 3:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 4:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 5:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 6:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 7:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 8:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 9:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 10:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 11:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 12:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 13:
		Size: 762.94MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/dynamic_update_slice" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/kernels.py" source_line=842
		XLA Label: fusion
		Shape: f32[2000,100000]
		==========================

	Buffer 14:
		Size: 3.81MiB
		XLA Label: copy
		Shape: f32[10,100000]
		==========================

	Buffer 15:
		Size: 3.81MiB
		Operator: op_name="jit(scan)/jit(main)/while/body/jit(one_step)/broadcast_in_dim[shape=(10, 100000) broadcast_dimensions=()]" source_file="/home/tim/.local/lib/python3.8/site-packages/blackjax/mcmc/termination.py" source_line=37
		XLA Label: broadcast
		Shape: f32[10,100000]
		==========================

Blackjax/JAX/jaxlib/Python version information:

python: 3.8.10
jax: 0.4.1
blackjax: 0.9.6


rlouf added this to the 1.0 milestone May 5, 2023
@junpenglao
Member

I have been following a similar discussion/issue in jaxopt: google/jaxopt#380. The problem is likely related to the jitted-function caches blowing up, which can be avoided with the newly implemented jax.clear_caches() (see an example in google/jaxopt#380 (comment)).
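
As a quick illustration, here is a minimal sketch of calling jax.clear_caches() between chunks of work (it needs a recent jax release that exposes the function; run_chunk below is just a stand-in for sampling work that compiles something internally):

import jax

def run_chunk(key):
    # stand-in for a piece of work that triggers compilation internally
    return jax.jit(lambda x: x * 2.0)(jax.random.normal(key, (1000,)))

key = jax.random.PRNGKey(0)
for i in range(10):
    key, subkey = jax.random.split(key)
    _ = run_chunk(subkey)
    jax.clear_caches()  # drop compiled executables and other internal caches between chunks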

For us, there are a few things we should do:

  1. [P0] Mitigate the memory issue by adding an option to invoke jax.clear_caches() in our high-level API, for example:
... = blackjax.window_adaptation(..., loop_in_python=True, clear_cache_each_step=True)

This would also mean adding a Python for-loop implementation of scan (the loop_in_python kwarg above, which clear_cache_each_step would force to True); I think numpyro has something similar. A rough sketch of such a loop is included after this list.

  2. [P1] Make sure blackjax does not produce objects that accrue in some of jax's caches. Per @froystig, "these are functions, or type signatures to those functions. For instance, jax.jit(f) sets up a cache that maps f, plus any types (shapes, dtypes) for which it was compiled, to the executable that jax produced. there's a similar cache for jax.cond and jax.while_loop, since those jit their branch functions."

blackjax also calls jit, cond, etc. often, so this is definitely an area to investigate. Overall, if the memory issue is related to jax's caches, it suggests that we might be generating fresh functions. That might be unavoidable (e.g., caused by the user's log_density implementation), but we should make sure it is not caused by blackjax.
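
To make item 1 above concrete, a rough sketch of what a Python-loop replacement for lax.scan with per-step cache clearing could look like (the function and keyword names here are hypothetical, not an agreed API):

import jax

def python_scan(step_fn, init_state, num_steps, clear_cache_each_step=False):
    # Hypothetical Python-loop stand-in for lax.scan over num_steps iterations.
    step_fn = jax.jit(step_fn)  # the body is still jitted; only the loop runs in Python
    state = init_state
    outputs = []
    for i in range(num_steps):
        state, out = step_fn(state, i)
        outputs.append(jax.device_get(out))  # pull the per-step output back to the host
        if clear_cache_each_step:
            jax.clear_caches()  # bounds cache growth, at the cost of recompiling each step
    return state, outputs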
