Neural network layers #16

Closed · wants to merge 8 commits

Changes from all commits
4 changes: 2 additions & 2 deletions mcx/__init__.py
@@ -1,13 +1,13 @@
 from mcx.model import model
 from mcx.model import sample_forward, seed
 from mcx.execution import sample, generate
-from mcx.hmc import HMC
+from mcx.inference.hmc import HMC
 
 __all__ = [
     "model",
     "seed",
-    "HMC",
     "sample_forward",
+    "HMC",
     "sample",
     "generate",
 ]
9 changes: 9 additions & 0 deletions mcx/distributions/distribution.py
@@ -5,6 +5,7 @@
 from jax import numpy as np
 
 from .constraints import Constraint
+from .utils import broadcast_batch_shape
 
 
 class Distribution(ABC):
@@ -82,6 +83,14 @@ def sample(
         """
         pass
 
+    def broadcast_to(self, destination_shape):
+        """Broadcast the distribution to a destination shape.
+
+        This is used, for instance, in neural networks where you want to
+        broadcast the distribution to match the layer size.
+        """
+        self.batch_shape = broadcast_batch_shape(self.batch_shape, destination_shape)
+
     def forward(
         self,
         rng_key: jax.random.PRNGKey,
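A minimal sketch of what `broadcast_to` is for, assuming `mcx.distributions` exposes a `Normal` distribution with scalar parameters and that `broadcast_batch_shape` follows NumPy broadcasting rules (neither is shown in this diff):

```python
import mcx.distributions as md  # assumed to expose a Normal distribution

# A prior declared with scalar parameters starts with an empty batch shape.
weights_prior = md.Normal(0.0, 1.0)

# A dense layer with 784 inputs and 128 units needs one weight per connection,
# so the same prior is broadcast to cover a (784, 128) array of weights.
weights_prior.broadcast_to((784, 128))

# weights_prior.batch_shape should now be (784, 128), and sampling from the
# distribution is expected to return an array of that shape.
```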
182 changes: 162 additions & 20 deletions mcx/execution.py
@@ -1,53 +1,112 @@
 import jax
 import jax.numpy as np
 from jax.flatten_util import ravel_pytree
+from tqdm import tqdm
 
 from mcx import sample_forward
 from mcx.core import compile_to_logpdf
 
 
 __all__ = ["sample", "generate"]
 
-class sample(object):
-
-    def __init__(self, rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
+class sample(object):
+    def __init__(
+        self, rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs
+    ):
         """ Initialize the sampling runtime.
         """
-        self.runtime = runtime
+        self.program = program
         self.num_chains = num_chains
-        self.num_warmup = num_warmup
         self.rng_key = rng_key
 
-        initialize, build_kernel, to_trace = self.runtime
-        loglikelihood, initial_state, parameters, unravel_fn = initialize(rng_key, self.num_chains, **kwargs)
+        init, warmup, build_kernel, to_trace = self.program
+
+        print("Initialize the sampler\n")
+
+        validate_conditioning_variables(model, **kwargs)
+        loglikelihood = build_loglikelihood(model, **kwargs)
+
+        print("Find initial states...")
+        initial_position, unravel_fn = get_initial_position(
+            rng_key, model, num_chains, **kwargs
+        )
+        loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn)
+        initial_state = jax.vmap(init, in_axes=(0, None))(
+            initial_position, jax.value_and_grad(loglikelihood)
+        )
+
+        print("Warmup the chains...")
+        parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps)
+
+        print("Compile the log-likelihood...")
+        loglikelihood = jax.jit(loglikelihood)
 
         print("Build and compile the inference kernel...")
         kernel = build_kernel(loglikelihood, parameters)
-        self.kernel = jax.jit(kernel)
+        kernel = jax.jit(kernel)
 
-        self.state = initial_state
-        self.unravel_fn = unravel_fn
+        self.kernel = kernel
+        self.state = state
         self.to_trace = to_trace
+        self.unravel_fn = unravel_fn
 
-    def take(self, num_samples=1000):
+    def run(self, num_samples=1000):
         _, self.rng_key = jax.random.split(self.rng_key)
 
         @jax.jit
-        def update_chains(states, rng_key):
+        def update_chains(state, rng_key):
             keys = jax.random.split(rng_key, self.num_chains)
-            new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, states)
-            return new_states, states
+            new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, state)
+            return new_states
 
-        rng_keys = jax.random.split(self.rng_key, num_samples)
-        last_state, states = jax.lax.scan(update_chains, self.state, rng_keys)
-        self.state = last_state
+        state = self.state
+        chain = []
 
-        trace = self.to_trace(states, self.unravel_fn)
+        rng_keys = jax.random.split(self.rng_key, num_samples)
+        with tqdm(rng_keys, unit="samples") as progress:
+            progress.set_description(
+                "Collecting {:,} samples across {:,} chains".format(
+                    num_samples, self.num_chains
+                ),
+                refresh=False,
+            )
+            for key in progress:
+                state = update_chains(state, key)
+                chain.append(state)
+        self.state = state
+
+        trace = self.to_trace(chain, self.unravel_fn)
 
         return trace
 
 
-def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
+def generate(rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs):
     """ The generator runtime """
 
-    initialize, build_kernel, to_trace = runtime
+    init, warmup, build_kernel, to_trace = program
+
+    print("Initialize the sampler\n")
+
+    validate_conditioning_variables(model, **kwargs)
+    loglikelihood = build_loglikelihood(model, **kwargs)
 
-    loglikelihood, initial_state, parameters, unravel_fn = initialize(rng_key, num_chains, **kwargs)
+    print("Find initial states...")
+    initial_position, unravel_fn = get_initial_position(
+        rng_key, model, num_chains, **kwargs
+    )
+    loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn)
+    initial_state = jax.vmap(init, in_axes=(0, None))(
+        initial_position, jax.value_and_grad(loglikelihood)
+    )
+
+    print("Warmup the chains...")
+    parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps)
+
+    print("Compile the log-likelihood...")
+    loglikelihood = jax.jit(loglikelihood)
 
     print("Build and compile the inference kernel...")
     kernel = build_kernel(loglikelihood, parameters)
     kernel = jax.jit(kernel)
 
@@ -59,3 +118,86 @@ def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
         new_states = jax.vmap(kernel)(keys, state)
 
         yield new_states
+
+
+def validate_conditioning_variables(model, **kwargs):
+    """ Check that all the variables passed as arguments to the sampler
+    are either random variables or arguments to the model definition, and
+    conversely that all of the model definition's positional arguments
+    are given a value.
+    """
+    conditioning_vars = set(kwargs.keys())
+    model_randvars = set(model.random_variables)
+    model_args = set(model.arguments)
+    available_vars = model_randvars.union(model_args)
+
+    # The variables passed as arguments to the initialization (variables
+    # on which the logpdf is conditioned) must be either a random variable
+    # or an argument to the model definition.
+    if not available_vars.issuperset(conditioning_vars):
+        unknown_vars = list(conditioning_vars.difference(available_vars))
+        unknown_str = ", ".join(unknown_vars)
+        raise AttributeError(
+            "You passed a value for {} which are neither random variables nor arguments to the model definition.".format(
+                unknown_str
+            )
+        )
+
+    # The user must provide a value for all of the model definition's
+    # positional arguments.
+    model_posargs = set(model.posargs)
+    if model_posargs.difference(conditioning_vars):
+        missing_vars = model_posargs.difference(conditioning_vars)
+        missing_str = ", ".join(missing_vars)
+        raise AttributeError(
+            "You need to specify a value for the following arguments: {}".format(
+                missing_str
+            )
+        )
+
+
+def build_loglikelihood(model, **kwargs):
+    artifact = compile_to_logpdf(model.graph, model.namespace)
+    logpdf = artifact.compiled_fn
+    loglikelihood = jax.partial(logpdf, **kwargs)
+    return loglikelihood
+
+
+def get_initial_position(rng_key, model, num_chains, **kwargs):
+    conditioning_vars = set(kwargs.keys())
+    model_randvars = set(model.random_variables)
+    to_sample_vars = model_randvars.difference(conditioning_vars)
+
+    samples = sample_forward(rng_key, model, num_samples=num_chains, **kwargs)
+    initial_positions = dict((var, samples[var]) for var in to_sample_vars)
+
+    # A naive way to go about flattening the positions is to transform the
+    # dictionary of arrays that contains the parameter values into a list of
+    # dictionaries, one per position, and then unravel the dictionaries.
+    # However, this approach takes more time than getting the samples in the
+    # first place.
+    #
+    # Luckily, JAX first sorts dictionaries by keys
+    # (https://github.com/google/jax/blob/master/jaxlib/pytree.cc) when
+    # raveling pytrees. We can thus ravel and stack parameter values in an
+    # array, sorting by key; this gives our flattened positions. We then build
+    # a single dictionary that contains the parameter values and use it to get
+    # the unraveling function using `ravel_pytree`.
+    positions = np.stack(
+        [np.ravel(samples[s]) for s in sorted(initial_positions.keys())], axis=1
+    )
+
+    sample_position_dict = {
+        parameter: values[0] for parameter, values in initial_positions.items()
+    }
+    _, unravel_fn = ravel_pytree(sample_position_dict)
+
+    return positions, unravel_fn
+
+
+def flatten_loglikelihood(logpdf, unravel_fn):
+    def flattened_logpdf(array):
+        kwargs = unravel_fn(array)
+        return logpdf(**kwargs)
+
+    return flattened_logpdf
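The comment in `get_initial_position` above relies on `ravel_pytree` flattening dictionary leaves in key-sorted order. A small, standalone check of that behaviour, independent of mcx:

```python
import jax.numpy as np
from jax.flatten_util import ravel_pytree

# A single chain position with two parameters; keys are deliberately unsorted.
position = {"sigma": np.array(1.0), "mu": np.array([0.3, -0.2])}

flat, unravel_fn = ravel_pytree(position)
# Leaves are concatenated in sorted key order, "mu" before "sigma",
# so flat is [0.3, -0.2, 1.0].

# Round-tripping restores the dictionary; this is what flatten_loglikelihood
# uses to call the keyword-based logpdf from a flat position array.
restored = unravel_fn(flat)
assert set(restored.keys()) == {"mu", "sigma"}
```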
113 changes: 0 additions & 113 deletions mcx/hmc.py

This file was deleted.

5 changes: 0 additions & 5 deletions mcx/inference/adaptive.py
@@ -8,11 +8,6 @@
 The Stan Manual [1]_ is a very good reference on automatic tuning of
 parameters used in Hamiltonian Monte Carlo.
 
-.. note:
-    This is a "flat zone": values used to update the step size or the mass
-    matrix are 1D arrays. Raveling/unraveling logic should happen at a higher
-    level.
-
 .. [1]: "HMC Algorithm Parameters", Stan Manual
     https://mc-stan.org/docs/2_20/reference-manual/hmc-algorithm-parameters.html
 """