
Reducing GPU memory usage #1689

Closed
dbobrovskiy opened this issue Nov 27, 2023 · 3 comments · Fixed by #1707
Labels
enhancement New feature or request good first issue Good for newcomers

Comments

@dbobrovskiy

I'm opening this issue following the discussion on the forum: https://forum.pyro.ai/t/reducing-mcmc-memory-usage/5639/6.

The problem is that the out-of-place array copying that happens in mcmc.run after the actual sampling can result in an out-of-memory exception even though the sampling itself succeeded. First of all, it would be nice if this could be avoided by transferring the arrays to CPU before any out-of-place operations.
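For context, moving a pytree of device arrays to host can be done with jax.device_get; a minimal sketch, applied here to the dict returned by get_samples() on an MCMC instance mcmc:

import jax

# Copy every array in the samples pytree from GPU to host memory as
# NumPy arrays, so that the device buffers can be released.
samples_on_host = jax.device_get(mcmc.get_samples())

The difficulty is that the out-of-place copies inside mcmc.run happen on device before user code has a chance to do this.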

More generally, GPU memory usage can be controlled by sampling sequentially using post_warmup_state and transferring each batch of samples to CPU before running the next one. However, this doesn't seem to work as expected: subsequent batches require more memory than the first one (see the output for the code below).

mcmc_samples = [None] * (n_samples // 1000)
# set up MCMC
self.mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=1000, num_chains=n_chains)
for i in range(n_samples // 1000):
    print(f"Batch {i+1}")
    # run MCMC for 1000 samples
    self.mcmc.run(jax.random.PRNGKey(0), self.spliced, self.unspliced)
    # store samples transferred to CPU
    mcmc_samples[i] = jax.device_put(self.mcmc.get_samples(), jax.devices("cpu")[0])
    # reset the mcmc before running the next batch
    self.mcmc.post_warmup_state = self.mcmc.last_state

The code above results in:

Running MCMC in batches of 1000 samples, 2 batches in total.
First batch will include 1000 warmup samples.
Batch 1
sample: 100%|██████████| 2000/2000 [11:18<00:00,  2.95it/s, 1023 steps of size 5.13e-06. acc. prob=0.85]
Batch 2
sample: 100%|██████████| 1000/1000 [05:48<00:00,  2.87it/s, 1023 steps of size 5.13e-06. acc. prob=0.85]
2023-11-24 14:43:23.854505: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.56GiB (rounded to 2750440192)requested by op 

To summarise:

  1. Could the out-of-place operations at the end of sampling optionally be performed on CPU, after transferring the arrays there?
  2. How should one sample sequentially so that memory usage does not grow in the process?
@fehiepsi fehiepsi added the enhancement New feature or request label Nov 28, 2023
@fehiepsi
Member

Could you try replacing those two lines by

        self._states = jax.device_get(states)
        self._states_flat = jax.device_get(states_flat)

If it works, then we can introduce a method named transfer_states_to_host() to perform those device_get operations.
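For illustration, a minimal sketch of what such a helper could look like as a method on the MCMC class, assuming the _states and _states_flat attributes referenced above (the actual implementation may differ):

def transfer_states_to_host(self):
    """Move the collected MCMC states from device (GPU) to host (CPU) memory."""
    # jax.device_get copies every array leaf of a pytree to host as a
    # NumPy array, allowing the corresponding device buffers to be freed.
    self._states = jax.device_get(self._states)
    self._states_flat = jax.device_get(self._states_flat)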

@dbobrovskiy
Author

Yep, it works!
Neither requires extra GPU memory after sampling nor leads to memory increase throughout sequential samples.
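For reference, the batched loop from the original post with that change applied, assuming a transfer_states_to_host() helper like the one proposed above (the method name is hypothetical until the PR lands):

mcmc_samples = [None] * (n_samples // 1000)
mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=1000, num_chains=n_chains)
for i in range(n_samples // 1000):
    mcmc.run(jax.random.PRNGKey(0), spliced, unspliced)
    # pull the collected states to host so the next batch starts from a
    # clean GPU memory footprint
    mcmc.transfer_states_to_host()
    # get_samples() now returns host-resident (NumPy) arrays
    mcmc_samples[i] = mcmc.get_samples()
    # resume sampling from the last state instead of re-running warmup
    mcmc.post_warmup_state = mcmc.last_state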

@fehiepsi
Member

Thanks! Do you want to make a PR to add a helper doing such work? :)
