add preconditioning matrix to barker proposal #730

Closed

Changes from all commits
95 changes: 75 additions & 20 deletions blackjax/mcmc/barker.py
@@ -16,6 +16,7 @@

import jax
import jax.numpy as jnp
+ import jax.scipy as jscipy
from jax.flatten_util import ravel_pytree
from jax.scipy import stats
from jax.tree_util import tree_leaves, tree_map
@@ -33,13 +34,15 @@ class BarkerState(NamedTuple):
The Barker algorithm takes one position of the chain and returns another
position. In order to make computations more efficient, we also store
the current log-probability density as well as the current gradient of the
- log-probability density.
+ log-probability density. Finally, it stores the step_size and the inverse_mass_matrix.

"""

position: ArrayTree
logdensity: float
logdensity_grad: ArrayTree
+ step_size: float
+ inverse_mass_matrix: jax.Array


class BarkerInfo(NamedTuple):
@@ -86,36 +89,58 @@ def _compute_acceptance_probability(
) -> float:
"""Compute the acceptance probability of the Barker's proposal kernel."""

- def ratio_proposal_nd(y, x, log_y, log_x):
- num = -_log1pexp(-log_y * (x - y))
- den = -_log1pexp(-log_x * (y - x))
+ def ratio_proposal_nd(y, x, log_y, log_x, C_t, C_t_inv):
+ z = C_t_inv.dot(y - x)
+ c_x = log_x.dot(C_t)
+ c_y = log_y.dot(C_t)
+
+ num = _log1pexp(-z * c_x)
+ den = _log1pexp(z * c_y)

return jnp.sum(num - den)

+ C_t, C_t_inv = _get_mass_matrix_sqrt(state.inverse_mass_matrix)

ratios_proposals = tree_map(
ratio_proposal_nd,
proposal.position,
state.position,
proposal.logdensity_grad,
state.logdensity_grad,
+ C_t,
+ C_t_inv,
)
ratio_proposal = sum(tree_leaves(ratios_proposals))
return proposal.logdensity - state.logdensity + ratio_proposal
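
# A hedged sanity sketch (illustrative aside, not part of the diff): with an
# identity preconditioner, C_t = C_t_inv = jnp.eye(d) and the correction above
# reduces to the unpreconditioned Barker ratio,
# sum(log1pexp(-(y - x) * g_x) - log1pexp((y - x) * g_y)).
# For a standard Gaussian target, whose log-density gradient is g(x) = -x:
import jax.numpy as jnp

x, y = jnp.array([0.5]), jnp.array([1.0])
g_x, g_y = -x, -y  # gradients of the log-density at x and y
C = jnp.eye(1)  # identity preconditioner, so C.dot is the identity map
z = C.dot(y - x)
ratio = jnp.sum(jnp.log1p(jnp.exp(-z * g_x.dot(C))) - jnp.log1p(jnp.exp(z * g_y.dot(C))))
# Here ratio = log1pexp(0.25) - log1pexp(-0.5), exactly the scalar formula.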

def kernel(
- rng_key: PRNGKey, state: BarkerState, logdensity_fn: Callable, step_size: float
+ rng_key: PRNGKey,
+ state: BarkerState,
+ logdensity_fn: Callable,
+ step_size: float,
+ inverse_mass_matrix: jax.Array,
) -> tuple[BarkerState, BarkerInfo]:
"""Generate a new sample with the MALA kernel."""
"""Generate a new sample with the Barker kernel."""
grad_fn = jax.value_and_grad(logdensity_fn)

key_sample, key_rmh = jax.random.split(rng_key)

+ mass_matrix_sqrt, _ = _get_mass_matrix_sqrt(inverse_mass_matrix)

proposed_pos = _barker_sample(
- key_sample, state.position, state.logdensity_grad, step_size
+ key_sample,
+ state.position,
+ state.logdensity_grad,
+ step_size,
+ mass_matrix_sqrt,
)

proposed_logdensity, proposed_logdensity_grad = grad_fn(proposed_pos)
proposed_state = BarkerState(
- proposed_pos, proposed_logdensity, proposed_logdensity_grad
+ proposed_pos,
+ proposed_logdensity,
+ proposed_logdensity_grad,
+ step_size,
+ inverse_mass_matrix,
)

log_p_accept = _compute_acceptance_probability(state, proposed_state)
@@ -129,8 +154,7 @@ def kernel(


def as_top_level_api(
- logdensity_fn: Callable,
- step_size: float,
+ logdensity_fn: Callable, step_size: float, inverse_mass_matrix: jax.Array
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Barker's proposal :cite:p:`Livingstone2022Barker` kernel with a
Gaussian base kernel.
Expand Down Expand Up @@ -175,6 +199,8 @@ def as_top_level_api(
The log-density function we wish to draw samples from.
step_size
The value to use for the step size of the proposal distribution.
+ inverse_mass_matrix
+ The matrix to use for pre-conditioning (see Appendix G of :cite:p:`Livingstone2022Barker`).

Returns
-------
@@ -189,12 +215,12 @@ def init_fn(position: ArrayLikeTree, rng_key=None):
return init(position, logdensity_fn)

def step_fn(rng_key: PRNGKey, state):
- return kernel(rng_key, state, logdensity_fn, step_size)
+ return kernel(rng_key, state, logdensity_fn, step_size, inverse_mass_matrix)

return SamplingAlgorithm(init_fn, step_fn)
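
# A hedged usage sketch (illustrative aside, not part of the diff). It assumes
# this API is re-exported as `blackjax.barker_proposal` and that the updated
# `init` populates the new state fields; a dense 2x2 inverse mass matrix is
# used because the `.dot` calls in this module expect a 2-D factor.
import jax
import jax.numpy as jnp
import blackjax

logdensity_fn = lambda x: -0.5 * jnp.sum(x**2 / jnp.array([1.0, 10.0]))
inverse_mass_matrix = jnp.diag(jnp.array([1.0, 10.0]))  # roughly the target covariance
barker = blackjax.barker_proposal(logdensity_fn, 0.5, inverse_mass_matrix)

state = barker.init(jnp.zeros(2))
rng_key = jax.random.PRNGKey(0)
for _ in range(100):
    rng_key, step_key = jax.random.split(rng_key)
    state, info = barker.step(step_key, state)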


- def _barker_sample_nd(key, mean, a, scale):
+ def _barker_sample_nd(key, mean, a, scale, C):
"""
Sample from a multivariate Barker's proposal distribution. In 1D, this has the following probability density function:

@@ -214,8 +240,10 @@ def _barker_sample_nd(key, mean, a, scale):
a
The parameter :math:`a` in the equation above, an Array. This is a skewness parameter.
scale
- The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\\sigma` in the equation above.
+ The global scale, a scalar. This corresponds to :math:`\\sigma` in the equation above.
+ It encodes the step size of the proposal.
+ C
+ The matrix square-root factor used to precondition the proposal, an Array. It does not appear in the 1D equation above (technically, it is absorbed into the scale).

Returns
-------
@@ -225,17 +253,18 @@

key1, key2 = jax.random.split(key)
z = scale * jax.random.normal(key1, shape=mean.shape)
+ c = a.dot(C)

# Sample b=1 with probability p and 0 with probability 1 - p where
# p = 1 / (1 + exp(-c * z))
- log_p = -_log1pexp(-a * z)
+ log_p = -_log1pexp(-c * z)
b = jax.random.bernoulli(key2, p=jnp.exp(log_p), shape=mean.shape)

# return mean + z if b == 1 else mean - z
- return mean + b * z - (1 - b) * z
+ return mean + C.dot(b * z - (1 - b) * z)
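
# A hedged 1D illustration (aside, not part of the diff): with a persistent
# positive gradient a, the Bernoulli flip keeps +z with probability
# sigmoid(c * z), so draws are skewed toward the gradient direction. A 2-D
# identity preconditioner is used since the `.dot` calls expect a matrix.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 1000)
draws = jax.vmap(
    lambda k: _barker_sample_nd(k, jnp.zeros(1), jnp.full(1, 2.0), 1.0, jnp.eye(1))
)(keys)
print(draws.mean())  # noticeably positive, unlike a symmetric random-walk proposal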


- def _barker_sample(key, mean, a, scale):
+ def _barker_sample(key, mean, a, scale, C):
r"""
Sample from a multivariate Barker's proposal distribution for PyTrees.

@@ -248,21 +277,47 @@
a
The parameter :math:`a` in the equation above, the same PyTree as `mean`. This is a skewness parameter.
scale
- The standard deviation of the normal distribution, a scalar. This corresponds to :math:`\sigma` in the equation above.
+ The global scale, a scalar. This corresponds to :math:`\sigma` in the equation above.
+ It encodes the step size of the proposal.
+
+ C
+ The matrix square-root factor used to precondition the proposal, an Array.
"""

flat_mean, unravel_fn = ravel_pytree(mean)
flat_a, _ = ravel_pytree(a)
- flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale)
+ flat_sample = _barker_sample_nd(key, flat_mean, flat_a, scale, C)
return unravel_fn(flat_sample)


def _log1pexp(a):
# log(1 + exp(a)); logaddexp avoids overflow of exp(a) for large positive a
return jnp.logaddexp(0.0, a)


+ def _get_mass_matrix_sqrt(inverse_mass_matrix):
+ # want the Cholesky decomposition of the mass matrix (see Appendix G of the paper)
+ ndim = jnp.ndim(inverse_mass_matrix)  # type: ignore[arg-type]
+ shape = jnp.shape(inverse_mass_matrix)[:1]  # type: ignore[arg-type]
+ if ndim == 1:  # diagonal
+ inv_mass_matrix_sqrt = jnp.sqrt(inverse_mass_matrix)
+ mass_matrix_sqrt = jnp.reciprocal(inv_mass_matrix_sqrt)  # sqrt(M) = 1 / sqrt(M^{-1})
+ elif ndim == 2:
+ # the inverse mass matrix can be factored into L @ L.T. We want the Cholesky
+ # factor (inverse of L.T) of the mass matrix.
+ L = jscipy.linalg.cholesky(inverse_mass_matrix, lower=True)
+ identity = jnp.identity(shape[0])
+ mass_matrix_sqrt = jscipy.linalg.solve_triangular(
+ L, identity, lower=True, trans=True
+ )
+ inv_mass_matrix_sqrt = L
+ else:
+ raise ValueError(
+ "The mass matrix has the wrong number of dimensions:"
+ f" expected 1 or 2, got {ndim}."
+ )
+ return mass_matrix_sqrt, inv_mass_matrix_sqrt
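
# A hedged sanity check (aside, not part of the diff): for a dense SPD input,
# the factors returned above reconstruct the inverse mass matrix and the mass
# matrix respectively.
import jax.numpy as jnp

inv_mm = jnp.array([[2.0, 0.5], [0.5, 1.0]])
mm_sqrt, inv_mm_sqrt = _get_mass_matrix_sqrt(inv_mm)
assert jnp.allclose(inv_mm_sqrt @ inv_mm_sqrt.T, inv_mm, atol=1e-6)
assert jnp.allclose(mm_sqrt @ mm_sqrt.T, jnp.linalg.inv(inv_mm), atol=1e-5)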


+ # TODO: update these for the preconditioned proposal
def _barker_logpdf(x, mean, a, scale):
logpdf = jnp.log(2) + stats.norm.logpdf(x, mean, scale) - _log1pexp(-a * (x - mean))
return logpdf