diff --git a/blackjax/mcmc/barker.py b/blackjax/mcmc/barker.py
index 9923bd5f3..86ad595db 100644
--- a/blackjax/mcmc/barker.py
+++ b/blackjax/mcmc/barker.py
@@ -16,6 +16,6 @@
 import jax
 import jax.numpy as jnp
+import jax.scipy as jscipy
 from jax.flatten_util import ravel_pytree
 from jax.scipy import stats
-from jax.tree_util import tree_leaves, tree_map
 
@@ -86,36 +86,56 @@ def _compute_acceptance_probability(
+    inverse_mass_matrix: jax.Array,
 ) -> float:
     """Compute the acceptance probability of the Barker's proposal kernel."""
-
-    def ratio_proposal_nd(y, x, log_y, log_x):
-        num = -_log1pexp(-log_y * (x - y))
-        den = -_log1pexp(-log_x * (y - x))
-
-        return jnp.sum(num - den)
-
-    ratios_proposals = tree_map(
-        ratio_proposal_nd,
-        proposal.position,
-        state.position,
-        proposal.logdensity_grad,
-        state.logdensity_grad,
-    )
-    ratio_proposal = sum(tree_leaves(ratios_proposals))
+    mass_matrix_sqrt, inv_mass_matrix_sqrt = _get_mass_matrix_sqrt(inverse_mass_matrix)
+
+    y, _ = ravel_pytree(proposal.position)
+    x, _ = ravel_pytree(state.position)
+    log_y, _ = ravel_pytree(proposal.logdensity_grad)
+    log_x, _ = ravel_pytree(state.logdensity_grad)
+
+    # z = C^{-1} (y - x) and c = C^T (gradient of the log-density), where C
+    # is a square root of the inverse mass matrix, i.e. C C^T = M^{-1}.
+    if jnp.ndim(inverse_mass_matrix) == 1:  # diagonal inverse mass matrix
+        z = mass_matrix_sqrt * (y - x)
+        c_x = inv_mass_matrix_sqrt * log_x
+        c_y = inv_mass_matrix_sqrt * log_y
+    else:  # dense: C = L, the Cholesky factor, so C^{-1} = mass_matrix_sqrt.T
+        z = mass_matrix_sqrt.T.dot(y - x)
+        c_x = inv_mass_matrix_sqrt.T.dot(log_x)
+        c_y = inv_mass_matrix_sqrt.T.dot(log_y)
+
+    num = _log1pexp(-z * c_x)
+    den = _log1pexp(z * c_y)
+    ratio_proposal = jnp.sum(num - den)
     return proposal.logdensity - state.logdensity + ratio_proposal
 
 
 def kernel(
-    rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
+    rng_key: PRNGKey,
+    state: BarkerState,
+    logdensity_fn: Callable,
+    step_size: float,
+    inverse_mass_matrix: jax.Array,
 ) -> tuple[BarkerState, BarkerInfo]:
-    """Generate a new sample with the MALA kernel."""
+    """Generate a new sample with the Barker kernel."""
     grad_fn = jax.value_and_grad(logdensity_fn)
-
     key_sample, key_rmh = jax.random.split(rng_key)
+
+    _, inv_mass_matrix_sqrt = _get_mass_matrix_sqrt(inverse_mass_matrix)
+
     proposed_pos = _barker_sample(
-        key_sample, state.position, state.logdensity_grad, step_size
+        key_sample,
+        state.position,
+        state.logdensity_grad,
+        step_size,
+        inv_mass_matrix_sqrt,
     )
+
     proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
     proposed_state = BarkerState(
         proposed_pos, proposed_logdensity, proposed_logdensity_grad
     )
-    log_p_accept = _compute_acceptance_probability(state, proposed_state)
+    log_p_accept = _compute_acceptance_probability(
+        state, proposed_state, inverse_mass_matrix
+    )
 
@@ -129,8 +149,7 @@ def kernel(
 
 
 def as_top_level_api(
-    logdensity_fn: Callable,
-    step_size: float,
+    logdensity_fn: Callable, step_size: float, inverse_mass_matrix: jax.Array
 ) -> SamplingAlgorithm:
     """Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a Gaussian base kernel.
@@ -175,6 +194,8 @@ def as_top_level_api(
         The log-density function we wish to draw samples from.
     step_size
-        The value to use for the step size in the symplectic integrator.
+        The value to use for the step size of the proposal.
+    inverse_mass_matrix
+        The inverse mass matrix to use for pre-conditioning: a vector (the diagonal of a diagonal matrix) or a dense positive-definite matrix (see Appendix G of :cite:p:`Livingstone2022Barker`).
 
     Returns
     -------
@@ -189,12 +210,12 @@ def as_top_level_api(
     def init_fn(position: ArrayLikeTree, rng_key=None):
         return init(position, logdensity_fn)
 
     def step_fn(rng_key: PRNGKey, state):
-        return kernel(rng_key, state, logdensity_fn, step_size)
+        return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)
 
     return SamplingAlgorithm(init_fn, step_fn)
 
 
-def _barker_sample_nd(key, mean, a, scale):
+def _barker_sample_nd(key, mean, a, scale, C):
     """
     Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:
@@ -214,8 +235,10 @@ def _barker_sample_nd(key, mean, a, scale):
     a
         The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
     scale
-        The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above.
+        The global scale, a scalar. This corresponds to :math:`\\sigma` in the equation above. It encodes the step size of the proposal.
+    C
+        A square root of the inverse mass matrix, an Array: the Cholesky factor of a dense inverse mass matrix, or the elementwise square root of a diagonal one given as a vector. It does not appear in the 1D equation above; there it is absorbed into the scale.
 
     Returns
     -------
@@ -225,17 +248,24 @@ def _barker_sample_nd(key, mean, a, scale):
     key1, key2 = jax.random.split(key)
     z = scale * jax.random.normal(key1, shape=mean.shape)
+    if jnp.ndim(C) == 1:  # diagonal pre-conditioner stored as a vector
+        c = C * a
+    else:  # c = C^T a
+        c = C.T.dot(a)
 
     # Sample b=1 with probability p and 0 with probability 1 - p where
-    # p = 1 / (1 + exp(-a * (z - mean)))
-    log_p = -_log1pexp(-a * z)
+    # p = 1 / (1 + exp(-c * z))
+    log_p = -_log1pexp(-c * z)
     b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)
 
     # return mean + z if b == 1 else mean - z
-    return mean + b * z - (1 - b) * z
+    bz = b * z - (1 - b) * z
+    if jnp.ndim(C) == 1:
+        return mean + C * bz
+    return mean + C.dot(bz)
 
 
-def _barker_sample(key, mean, a, scale):
+def _barker_sample(key, mean, a, scale, C):
     r"""
     Sample from a multivariate Barker's proposal distribution for PyTrees.
@@ -248,14 +278,15 @@ def _barker_sample(key, mean, a, scale):
     a
         The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
     scale
-        The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
-
+        The global scale, a scalar. This corresponds to :math:`\sigma` in the equation above. It encodes the step size of the proposal.
+    C
+        A square root of the inverse mass matrix, an Array (see `_barker_sample_nd`).
     """
     flat_mean, unravel_fn = ravel_pytree(mean)
     flat_a, _ = ravel_pytree(a)
-    flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
+    flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale, C)
     return unravel_fn(flat_sample)
 
 
@@ -263,6 +294,32 @@ def _log1pexp(a):
     return jnp.log1p(jnp.exp(a))
 
 
+def _get_mass_matrix_sqrt(inverse_mass_matrix):
+    # Square roots of the mass matrix M and of its inverse: the inverse mass
+    # matrix is factored as L L^T (see Appendix G of Livingstone and Zanella, 2022).
+    ndim = jnp.ndim(inverse_mass_matrix)  # type: ignore[arg-type]
+    shape = jnp.shape(inverse_mass_matrix)[:1]  # type: ignore[arg-type]
+    if ndim == 1:  # diagonal inverse mass matrix stored as a vector
+        inv_mass_matrix_sqrt = jnp.sqrt(inverse_mass_matrix)
+        mass_matrix_sqrt = jnp.reciprocal(inv_mass_matrix_sqrt)
+    elif ndim == 2:
+        # The inverse mass matrix can be factored into L*L.T. We want the
+        # Cholesky factor (inverse of L.T) of the mass matrix.
+        L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True)
+        identity = jnp.identity(shape[0])
+        mass_matrix_sqrt = jscipy.linalg.solve_triangular(
+            L, identity, lower=True, trans=True
+        )
+        inv_mass_matrix_sqrt = L
+    else:
+        raise ValueError(
+            "The mass matrix has the wrong number of dimensions:"
+            f" expected 1 or 2, got {ndim}."
+        )
+    return mass_matrix_sqrt, inv_mass_matrix_sqrt
+
+
+# TODO: update these density helpers to account for the pre-conditioning
 def _barker_logpdf(x, mean, a, scale):
     logpdf = jnp.log(2) + stats.norm.logpdf(x, mean, scale) - _log1pexp(-a * (x - mean))
     return logpdf
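
Not part of the diff: a minimal usage sketch for reviewers who want to exercise the new argument locally. It assumes this branch is installed and the module is importable as blackjax.mcmc.barker; the target density and all numbers are made up:

    import jax
    import jax.numpy as jnp

    from blackjax.mcmc import barker


    def logdensity_fn(x):
        # Stand-in target: a standard normal in three dimensions.
        return -0.5 * jnp.sum(x**2)


    rng_key = jax.random.PRNGKey(0)

    # A diagonal inverse mass matrix is passed as a vector; a dense one as a
    # 2D positive-definite matrix.
    algorithm = barker.as_top_level_api(
        logdensity_fn, step_size=0.5, inverse_mass_matrix=jnp.ones(3)
    )

    state = algorithm.init(jnp.zeros(3))
    for _ in range(100):
        rng_key, step_key = jax.random.split(rng_key)
        state, info = algorithm.step(step_key, state)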
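A quick property check for _get_mass_matrix_sqrt (again not part of the diff; the matrix below is arbitrary). With L the Cholesky factor of the inverse mass matrix, the two returned factors should reproduce the inverse mass matrix and the mass matrix respectively:

    import jax.numpy as jnp

    from blackjax.mcmc.barker import _get_mass_matrix_sqrt

    inverse_mass_matrix = jnp.array([[2.0, 0.5], [0.5, 1.0]])
    mass_matrix_sqrt, inv_mass_matrix_sqrt = _get_mass_matrix_sqrt(inverse_mass_matrix)

    # inv_mass_matrix_sqrt is L, with L @ L.T == M^{-1} ...
    assert jnp.allclose(
        inv_mass_matrix_sqrt @ inv_mass_matrix_sqrt.T, inverse_mass_matrix, atol=1e-5
    )
    # ... and mass_matrix_sqrt is L^{-T}, a square root of M itself.
    mass_matrix = jnp.linalg.inv(inverse_mass_matrix)
    assert jnp.allclose(mass_matrix_sqrt @ mass_matrix_sqrt.T, mass_matrix, atol=1e-5)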