Find the right boundary between runtimes and programs (#13)
The boundary between programs (sampling algorithms) and runtimes
(executors) was not very clear. I removed any dependence on the model
from the program, so the responsibilities are now clear: most of the
initialization has been transferred to the runtime. I also improved the
performance of initial-state creation.
Rémi Louf authored and rlouf committed Apr 3, 2020
1 parent db62a78 commit 9299a7c
Showing 5 changed files with 232 additions and 138 deletions.
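To make the new boundary concrete, here is a minimal sketch of what a program must now provide and what the runtime expects of each piece. The tuple layout (init, warmup, build_kernel, to_trace) and the call signatures are taken from the diffs that follow; the function bodies are hypothetical placeholders, not MCX's actual HMC implementation.

def init(position, value_and_grad_fn):
    # Build an algorithm-specific chain state from a flat position array.
    logprob, logprob_grad = value_and_grad_fn(position)
    return (position, logprob, logprob_grad)

def warmup(initial_state, loglikelihood, num_warmup_steps):
    # Adapt the kernel parameters (e.g. a step size); a no-op placeholder here.
    parameters = ()
    return parameters, initial_state

def build_kernel(loglikelihood, parameters):
    # Return one Markov transition step; this placeholder keeps the state.
    def kernel(rng_key, state):
        info = None
        return state, info
    return kernel

def to_trace(chain, unravel_fn):
    # Map raw chain states back to named samples; identity here.
    return chain

program = (init, warmup, build_kernel, to_trace)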
4 changes: 2 additions & 2 deletions mcx/__init__.py
@@ -1,13 +1,13 @@
 from mcx.model import model
 from mcx.model import sample_forward, seed
 from mcx.execution import sample, generate
-from mcx.hmc import HMC
+from mcx.inference.hmc import HMC

 __all__ = [
     "model",
     "seed",
-    "HMC",
     "sample_forward",
+    "HMC",
     "sample",
     "generate",
 ]
180 changes: 161 additions & 19 deletions mcx/execution.py
@@ -1,53 +1,112 @@
 import jax
 import jax.numpy as np
+from jax.flatten_util import ravel_pytree
+from tqdm import tqdm

+from mcx import sample_forward
+from mcx.core import compile_to_logpdf


 __all__ = ["sample", "generate"]

 class sample(object):
-    def __init__(self, rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
+    def __init__(
+        self, rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs
+    ):
         """ Initialize the sampling runtime.
         """
-        self.runtime = runtime
+        self.program = program
         self.num_chains = num_chains
-        self.num_warmup = num_warmup
         self.rng_key = rng_key

-        initialize, build_kernel, to_trace = self.runtime
-        loglikelihood, initial_state, parameters, unravel_fn = initialize(rng_key, self.num_chains, **kwargs)
+        init, warmup, build_kernel, to_trace = self.program
+
+        print("Initialize the sampler\n")
+
+        validate_conditioning_variables(model, **kwargs)
+        loglikelihood = build_loglikelihood(model, **kwargs)
+
+        print("Find initial states...")
+        initial_position, unravel_fn = get_initial_position(
+            rng_key, model, num_chains, **kwargs
+        )
+        loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn)
+        initial_state = jax.vmap(init, in_axes=(0, None))(
+            initial_position, jax.value_and_grad(loglikelihood)
+        )
+
+        print("Warmup the chains...")
+        parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps)
+
+        print("Compile the log-likelihood...")
+        loglikelihood = jax.jit(loglikelihood)
+
+        print("Build and compile the inference kernel...")
         kernel = build_kernel(loglikelihood, parameters)
-        self.kernel = jax.jit(kernel)
+        kernel = jax.jit(kernel)

-        self.state = initial_state
-        self.unravel_fn = unravel_fn
+        self.kernel = kernel
+        self.state = state
         self.to_trace = to_trace
+        self.unravel_fn = unravel_fn

     def run(self, num_samples=1000):
         _, self.rng_key = jax.random.split(self.rng_key)

         @jax.jit
-        def update_chains(states, rng_key):
+        def update_chains(state, rng_key):
             keys = jax.random.split(rng_key, self.num_chains)
-            new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, states)
-            return new_states, states
+            new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, state)
+            return new_states

-        rng_keys = jax.random.split(self.rng_key, num_samples)
-        last_state, states = jax.lax.scan(update_chains, self.state, rng_keys)
-        self.state = last_state
+        state = self.state
+        chain = []

-        trace = self.to_trace(states, self.unravel_fn)
+        rng_keys = jax.random.split(self.rng_key, num_samples)
+        with tqdm(rng_keys, unit="samples") as progress:
+            progress.set_description(
+                "Collecting {:,} samples across {:,} chains".format(
+                    num_samples, self.num_chains
+                ),
+                refresh=False,
+            )
+            for key in progress:
+                state = update_chains(state, key)
+                chain.append(state)
+        self.state = state
+
+        trace = self.to_trace(chain, self.unravel_fn)

         return trace
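A hypothetical driver for the class above: linear_regression, hmc_program, x_data and y_data are placeholder names, while the constructor arguments and run() follow the code in this diff.

import jax
import mcx

rng_key = jax.random.PRNGKey(0)

# Placeholder model, program and data; only the call pattern is real.
sampler = mcx.sample(
    rng_key,
    linear_regression,  # an mcx model
    hmc_program,        # an (init, warmup, build_kernel, to_trace) program
    num_warmup_steps=1000,
    num_chains=4,
    x=x_data,           # conditioning variables are passed as keywords
    predictions=y_data,
)
trace = sampler.run(num_samples=1000)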


-def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
+def generate(rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs):
     """ The generator runtime """

-    initialize, build_kernel, to_trace = runtime
-    loglikelihood, initial_state, parameters, unravel_fn = initialize(rng_key, num_chains, **kwargs)
+    init, warmup, build_kernel, to_trace = program
+
+    print("Initialize the sampler\n")
+
+    validate_conditioning_variables(model, **kwargs)
+    loglikelihood = build_loglikelihood(model, **kwargs)
+
+    print("Find initial states...")
+    initial_position, unravel_fn = get_initial_position(
+        rng_key, model, num_chains, **kwargs
+    )
+    loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn)
+    initial_state = jax.vmap(init, in_axes=(0, None))(
+        initial_position, jax.value_and_grad(loglikelihood)
+    )
+
+    print("Warmup the chains...")
+    parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps)
+
+    print("Compile the log-likelihood...")
+    loglikelihood = jax.jit(loglikelihood)
+
+    print("Build and compile the inference kernel...")
     kernel = build_kernel(loglikelihood, parameters)
     kernel = jax.jit(kernel)

@@ -59,3 +118,86 @@ def generate(rng_key, num_warmup=1000, num_chains=4, **kwargs):
         new_states = jax.vmap(kernel)(keys, state)

         yield new_states
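Because generate yields new states each time it is resumed, sampling can be driven lazily with standard iterator tools. A hypothetical driver, with the same placeholder names as above:

import itertools

states = generate(rng_key, linear_regression, hmc_program, x=x_data, predictions=y_data)

# Pull, say, the first 100 transitions on demand.
for state in itertools.islice(states, 100):
    print(state)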


+def validate_conditioning_variables(model, **kwargs):
+    """ Check that all the variables passed as arguments to the sampler
+    are either random variables or arguments to the model definition, and
+    conversely that all of the model definition's positional arguments
+    are given a value.
+    """
+    conditioning_vars = set(kwargs.keys())
+    model_randvars = set(model.random_variables)
+    model_args = set(model.arguments)
+    available_vars = model_randvars.union(model_args)
+
+    # The variables passed as arguments to the initialization (variables
+    # on which the logpdf is conditioned) must be either random variables
+    # or arguments to the model definition.
+    if not available_vars.issuperset(conditioning_vars):
+        unknown_vars = list(conditioning_vars.difference(available_vars))
+        unknown_str = ", ".join(unknown_vars)
+        raise AttributeError(
+            "You passed a value for {} which are neither random variables nor arguments to the model definition.".format(
+                unknown_str
+            )
+        )
+
+    # The user must provide a value for all of the model definition's
+    # positional arguments.
+    model_posargs = set(model.posargs)
+    if model_posargs.difference(conditioning_vars):
+        missing_vars = model_posargs.difference(conditioning_vars)
+        missing_str = ", ".join(missing_vars)
+        raise AttributeError(
+            "You need to specify a value for the following arguments: {}".format(
+                missing_str
+            )
+        )
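A quick illustration of both failure modes, using a stand-in object that carries only the three attributes the function reads (random_variables, arguments, posargs); a real MCX model exposes more than this.

from types import SimpleNamespace

# Stand-in for a model with random variables a, b, predictions and one
# positional argument x.
toy_model = SimpleNamespace(
    random_variables=["a", "b", "predictions"],
    arguments=["x"],
    posargs=["x"],
)

validate_conditioning_variables(toy_model, x=1.0, predictions=2.0)  # passes

# Unknown name: raises AttributeError ("neither random variables nor arguments")
# validate_conditioning_variables(toy_model, z=1.0)

# Missing positional argument x: raises AttributeError ("specify a value for")
# validate_conditioning_variables(toy_model, predictions=2.0)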


+def build_loglikelihood(model, **kwargs):
+    artifact = compile_to_logpdf(model.graph, model.namespace)
+    logpdf = artifact.compiled_fn
+    loglikelihood = jax.partial(logpdf, **kwargs)
+    return loglikelihood
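In the JAX of this era, jax.partial was a re-export of functools.partial, so conditioning simply fixes the observed values by keyword. A toy sketch with a hypothetical log-density:

import functools

def logpdf(a, b, predictions, x):  # hypothetical compiled log-density
    return -((predictions - a * x - b) ** 2)

# Fixing the observed variables leaves a function of the latents only.
loglikelihood = functools.partial(logpdf, predictions=2.0, x=1.0)
print(loglikelihood(a=1.0, b=0.5))  # -0.25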


+def get_initial_position(rng_key, model, num_chains, **kwargs):
+    conditioning_vars = set(kwargs.keys())
+    model_randvars = set(model.random_variables)
+    to_sample_vars = model_randvars.difference(conditioning_vars)
+
+    samples = sample_forward(rng_key, model, num_samples=num_chains, **kwargs)
+    initial_positions = dict((var, samples[var]) for var in to_sample_vars)
+
+    # A naive way to go about flattening the positions is to transform the
+    # dictionary of arrays that contains the parameter values into a list
+    # of dictionaries, one per position, and then unravel the dictionaries.
+    # However, this approach takes more time than getting the samples in
+    # the first place.
+    #
+    # Luckily, JAX sorts dictionaries by key
+    # (https://github.com/google/jax/blob/master/jaxlib/pytree.cc) when
+    # raveling pytrees. We can thus ravel and stack the parameter values in
+    # an array, sorting by key; this gives our flattened positions. We then
+    # build a single dictionary that contains the parameter values and use
+    # it to get the unraveling function with `ravel_pytree`.
+    positions = np.stack(
+        [np.ravel(samples[s]) for s in sorted(initial_positions.keys())], axis=1
+    )
+
+    sample_position_dict = {
+        parameter: values[0] for parameter, values in initial_positions.items()
+    }
+    _, unravel_fn = ravel_pytree(sample_position_dict)
+
+    return positions, unravel_fn
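The key-sorting behaviour that the comment above relies on is easy to check; in this sketch mu comes out before sigma regardless of insertion order.

import jax.numpy as np
from jax.flatten_util import ravel_pytree

position = {"sigma": np.array(2.0), "mu": np.array(1.0)}
flat, unravel_fn = ravel_pytree(position)

print(flat)              # [1. 2.], keys are visited in sorted order
print(unravel_fn(flat))  # {'mu': 1.0, 'sigma': 2.0}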


+def flatten_loglikelihood(logpdf, unravel_fn):
+    def flattened_logpdf(array):
+        kwargs = unravel_fn(array)
+        return logpdf(**kwargs)
+
+    return flattened_logpdf
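Putting the last two helpers together: once the log-density takes a single flat array, jax.value_and_grad and jax.vmap apply directly, which is exactly how the runtimes above consume it. A self-contained sketch with a toy log-density:

import jax
import jax.numpy as np
from jax.flatten_util import ravel_pytree

def logpdf(mu, sigma):  # toy log-density, for illustration only
    return -(mu ** 2) - sigma

_, unravel_fn = ravel_pytree({"mu": np.array(0.5), "sigma": np.array(1.0)})
flat_logpdf = flatten_loglikelihood(logpdf, unravel_fn)

value, grads = jax.value_and_grad(flat_logpdf)(np.array([0.5, 1.0]))
print(value, grads)  # -1.25, [-1. -1.]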
114 changes: 0 additions & 114 deletions mcx/hmc.py

This file was deleted.

