From 255deeddbe4193a5cbf2e9fd33ea227b0e13fadb Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sat, 23 Dec 2023 20:33:03 -0500 Subject: [PATCH] transfer_states_to_host convenience function --- numpyro/infer/mcmc.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 977343802..c2a4972cb 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -9,7 +9,7 @@ import numpy as np -from jax import jit, lax, local_device_count, pmap, random, vmap +from jax import jit, lax, local_device_count, pmap, random, vmap, device_get import jax.numpy as jnp from jax.tree_util import tree_flatten, tree_map @@ -721,6 +721,13 @@ def print_summary(self, prob=0.9, exclude_deterministic=True): "Number of divergences: {}".format(jnp.sum(extra_fields["diverging"])) ) + def transfer_states_to_host(self): + """ + Reduce the memory footprint of collected samples by transfering them to the host device. + """ + self._states = device_get(self._states) + self._states_flat = device_get(self._states_flat) + def __getstate__(self): state = self.__dict__.copy() state["_cache"] = {}