diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 977343802..c2a4972cb 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -9,7 +9,7 @@ import numpy as np -from jax import jit, lax, local_device_count, pmap, random, vmap +from jax import jit, lax, local_device_count, pmap, random, vmap, device_get import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_map @@ -721,6 +721,13 @@ def print_summary(self, prob=0.9, exclude_deterministic=True): "Number of divergences: {}".format(jnp.sum(extra_fields["diverging"])) ) + def transfer_states_to_host(self): + """ + Reduce the memory footprint of collected samples by transfering them to the host device. + """ + self._states = device_get(self._states) + self._states_flat = device_get(self._states_flat) + def __getstate__(self): state = self.__dict__.copy() state["_cache"] = {}