Skip to content

Commit

Permalink
Update test_barker.py
Browse files Browse the repository at this point in the history
Add invariance test
  • Loading branch information
AdrienCorenflos authored Oct 2, 2024
1 parent ad59aba commit 6e50160
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tests/mcmc/test_barker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import itertools

import chex
import jax
Expand Down Expand Up @@ -128,6 +129,51 @@ def scaled_logdensity(x_scaled, data, metric):
states2_trans = jnp.array(states2_trans)
assert jnp.allclose(states1, states2_trans)

@parameterized.parameters(
itertools.product([1234, 5678], ["gaussian", "riemannian"])
)
def test_invariance(self, seed, metric):
logpdf = lambda x: -0.5 * jnp.sum(x**2)

n_samples, m_steps = 10_000, 50

key = jax.random.key(seed)
init_key, inference_key = jax.random.split(key, 2)
inference_keys = jax.random.split(inference_key, n_samples)
if metric == "gaussian":
inv_mass_matrix = jnp.ones((2,))
metric = metrics.default_metric(inv_mass_matrix)
else:
# bit of a random metric but we are testing invariance, not efficiency
metric = metrics.gaussian_riemannian(
lambda x: 1 / jnp.sum(1 + jnp.sum(x**2)) * jnp.eye(2)
)

barker = blackjax.barker_proposal(logpdf, 0.5, metric)
init_samples = jax.random.normal(init_key, shape=(n_samples, 2))

def loop(carry, key_):
state, accepted = carry
state, info = barker.step(key_, state)
accepted += info.is_accepted
return (state, accepted), None

def get_samples(init_sample, key_):
init = (barker.init(init_sample), 0)
(out, n_accepted), _ = jax.lax.scan(
loop, init, jax.random.split(key_, m_steps)
)
return out.position, n_accepted / m_steps

samples, total_accepted = jax.vmap(get_samples)(init_samples, inference_keys)
# now we test the distance versus a Gaussian
chex.assert_trees_all_close(
jnp.mean(samples, 0), jnp.zeros((2,)), atol=1e-1, rtol=1e-1
)
chex.assert_trees_all_close(
jnp.cov(samples.T), jnp.eye(2), atol=1e-1, rtol=1e-1
)


if __name__ == "__main__":
absltest.main()

0 comments on commit 6e50160

Please sign in to comment.