Skip to content

Commit

Permalink
allow for more general chain method (#1825)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Jul 1, 2024
1 parent d40f0e9 commit f4e69eb
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
15 changes: 12 additions & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def model(X, y):
sample values returned from the sampler to constrained values that lie within the support
of the sample sites. Additionally, this is used to return values at deterministic sites in
the model.
:param str chain_method: One of 'parallel' (default), 'sequential', 'vectorized'. The method
:param str chain_method: A callable jax transform like `jax.vmap` or one of
'parallel' (default), 'sequential', 'vectorized'. The method
'parallel' is used to execute the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs),
If there are not enough devices for 'parallel', we fall back to 'sequential' method to draw
chains sequentially. 'vectorized' method is an experimental feature which vectorizes the
Expand Down Expand Up @@ -340,7 +341,11 @@ def __init__(
raise ValueError("thinning must be a positive integer")
self.thinning = thinning
self.postprocess_fn = postprocess_fn
if chain_method not in ["parallel", "vectorized", "sequential"]:
if not callable(chain_method) and chain_method not in [
"parallel",
"vectorized",
"sequential",
]:
raise ValueError(
"Only supporting the following methods to draw chains:"
' "sequential", "parallel", or "vectorized"'
Expand Down Expand Up @@ -471,7 +476,9 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
collection_size=collection_size,
progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
diagnostics_fn=diagnostics,
num_chains=self.num_chains if self.chain_method == "parallel" else 1,
num_chains=self.num_chains
if (callable(self.chain_method) or self.chain_method == "parallel")
else 1,
)
states, last_val = collect_vals
# Get first argument of type `HMCState`
Expand Down Expand Up @@ -679,6 +686,8 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
states, last_state = _laxmap(partial_map_fn, map_args)
elif self.chain_method == "parallel":
states, last_state = pmap(partial_map_fn)(map_args)
elif callable(self.chain_method):
states, last_state = self.chain_method(partial_map_fn)(map_args)
else:
assert self.chain_method == "vectorized"
states, last_state = partial_map_fn(map_args)
Expand Down
21 changes: 21 additions & 0 deletions test/infer/test_hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,24 @@ def model(data):
kernel = HMCECS(NUTS(model), proxy=proxy_fn)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(random.PRNGKey(0), data)


def test_callable_chain_method():
def model():
x = numpyro.sample("x", dist.Normal(0.0, 2.0))
y = numpyro.sample("y", dist.Normal(0.0, 2.0))
numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
y = hmc_sites["y"]
new_x = dist.Normal(0.8 * (1 - y), jnp.sqrt(0.8)).sample(rng_key)
return {"x": new_x}

hmc_kernel = NUTS(model)
kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=["x"])
mcmc = MCMC(
kernel, num_warmup=100, num_chains=2, num_samples=100, chain_method=vmap
)
mcmc.run(random.PRNGKey(0))
samples = mcmc.get_samples()
assert set(samples.keys()) == {"x", "y"}

0 comments on commit f4e69eb

Please sign in to comment.