From 9299a7ca51ae34a0c85307c767aec4bb487d4cf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Mon, 30 Mar 2020 17:23:29 +0200 Subject: [PATCH] Find the right boundary between runtimes and programs (#13) The boundary between programs (sampling algorithms) and runtimes (executors) was not very clear. I remove any dependence on the model from the program and responsibilities are now clear. Most of the initialization has been transferred to the runtime. I also improved the performance of the creation of initial states. --- mcx/__init__.py | 4 +- mcx/execution.py | 180 ++++++++++++++++++++++++++++++++++++++----- mcx/hmc.py | 114 --------------------------- mcx/inference/hmc.py | 67 ++++++++++++++++ mcx/model.py | 5 +- 5 files changed, 232 insertions(+), 138 deletions(-) delete mode 100644 mcx/hmc.py create mode 100644 mcx/inference/hmc.py diff --git a/mcx/__init__.py b/mcx/__init__.py index e94d6f41..15b9c9de 100644 --- a/mcx/__init__.py +++ b/mcx/__init__.py @@ -1,13 +1,13 @@ from mcx.model import model from mcx.model import sample_forward, seed from mcx.execution import sample, generate -from mcx.hmc import HMC +from mcx.inference.hmc import HMC __all__ = [ "model", "seed", - "HMC", "sample_forward", + "HMC", "sample", "generate", ] diff --git a/mcx/execution.py b/mcx/execution.py index 1cec2e3e..c1031473 100644 --- a/mcx/execution.py +++ b/mcx/execution.py @@ -1,53 +1,112 @@ import jax +import jax.numpy as np +from jax.flatten_util import ravel_pytree +from tqdm import tqdm +from mcx import sample_forward +from mcx.core import compile_to_logpdf + + +__all__ = ["sample", "generate"] -class sample(object): - def __init__(self, rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs): +class sample(object): + def __init__( + self, rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs + ): """ Initialize the sampling runtime. 
""" - self.runtime = runtime + self.program = program self.num_chains = num_chains - self.num_warmup = num_warmup self.rng_key = rng_key - initialize, build_kernel, to_trace = self.runtime - loglikelihood, initial_state, parameters, unravel_fn = initialize(rng_key, self.num_chains, **kwargs) + init, warmup, build_kernel, to_trace = self.program + + print("Initialize the sampler\n") + + validate_conditioning_variables(model, **kwargs) + loglikelihood = build_loglikelihood(model, **kwargs) + + print("Find initial states...") + initial_position, unravel_fn = get_initial_position( + rng_key, model, num_chains, **kwargs + ) + loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn) + initial_state = jax.vmap(init, in_axes=(0, None))( + initial_position, jax.value_and_grad(loglikelihood) + ) + + print("Warmup the chains...") + parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps) + + print("Compile the log-likelihood...") loglikelihood = jax.jit(loglikelihood) + print("Build and compile the inference kernel...") kernel = build_kernel(loglikelihood, parameters) - self.kernel = jax.jit(kernel) + kernel = jax.jit(kernel) - self.state = initial_state - self.unravel_fn = unravel_fn + self.kernel = kernel + self.state = state self.to_trace = to_trace + self.unravel_fn = unravel_fn def run(self, num_samples=1000): _, self.rng_key = jax.random.split(self.rng_key) @jax.jit - def update_chains(states, rng_key): + def update_chains(state, rng_key): keys = jax.random.split(rng_key, self.num_chains) - new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, states) - return new_states, states + new_states, info = jax.vmap(self.kernel, in_axes=(0, 0))(keys, state) + return new_states - rng_keys = jax.random.split(self.rng_key, num_samples) - last_state, states = jax.lax.scan(update_chains, self.state, rng_keys) - self.state = last_state + state = self.state + chain = [] - trace = self.to_trace(states, self.unravel_fn) + rng_keys = jax.random.split(self.rng_key, num_samples) + with tqdm(rng_keys, unit="samples") as progress: + progress.set_description( + "Collecting {:,} samples across {:,} chains".format( + num_samples, self.num_chains + ), + refresh=False, + ) + for key in progress: + state = update_chains(state, key) + chain.append(state) + self.state = state + + trace = self.to_trace(chain, self.unravel_fn) return trace -def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs): +def generate(rng_key, model, program, num_warmup_steps=1000, num_chains=4, **kwargs): """ The generator runtime """ - initialize, build_kernel, to_trace = runtime + init, warmup, build_kernel, to_trace = program + + print("Initialize the sampler\n") + + validate_conditioning_variables(model, **kwargs) + loglikelihood = build_loglikelihood(model, **kwargs) - loglikelihood, initial_state, parameters, unravel_fn = initialize(rng_key, num_chains, **kwargs) + print("Find initial states...") + initial_position, unravel_fn = get_initial_position( + rng_key, model, num_chains, **kwargs + ) + loglikelihood = flatten_loglikelihood(loglikelihood, unravel_fn) + initial_state = jax.vmap(init, in_axes=(0, None))( + initial_position, jax.value_and_grad(loglikelihood) + ) + + print("Warmup the chains...") + parameters, state = warmup(initial_state, loglikelihood, num_warmup_steps) + + print("Compile the log-likelihood...") loglikelihood = jax.jit(loglikelihood) + print("Build and compile the inference kernel...") kernel = build_kernel(loglikelihood, parameters) kernel = jax.jit(kernel) @@ -59,3 
+118,86 @@ def generate(rng_key, runtime, num_warmup=1000, num_chains=4, **kwargs):
         new_states = jax.vmap(kernel)(keys, state)
         yield new_states
+
+
+def validate_conditioning_variables(model, **kwargs):
+    """ Check that all the variables passed as arguments to the sampler
+    are either random variables or arguments to the model definition, and
+    conversely that all of the model definition's positional arguments are
+    given a value.
+    """
+    conditioning_vars = set(kwargs.keys())
+    model_randvars = set(model.random_variables)
+    model_args = set(model.arguments)
+    available_vars = model_randvars.union(model_args)
+
+    # The variables passed as an argument to the initialization (variables
+    # on which the logpdf is conditioned) must be either a random variable
+    # or an argument to the model definition.
+    if not available_vars.issuperset(conditioning_vars):
+        unknown_vars = list(conditioning_vars.difference(available_vars))
+        unknown_str = ", ".join(unknown_vars)
+        raise AttributeError(
+            "You passed a value for {} which are neither random variables nor arguments to the model definition.".format(
+                unknown_str
+            )
+        )
+
+    # The user must provide a value for all of the model definition's
+    # positional arguments.
+    model_posargs = set(model.posargs)
+    if model_posargs.difference(conditioning_vars):
+        missing_vars = model_posargs.difference(conditioning_vars)
+        missing_str = ", ".join(missing_vars)
+        raise AttributeError(
+            "You need to specify a value for the following arguments: {}".format(
+                missing_str
+            )
+        )
+
+
+def build_loglikelihood(model, **kwargs):
+    artifact = compile_to_logpdf(model.graph, model.namespace)
+    logpdf = artifact.compiled_fn
+    loglikelihood = jax.partial(logpdf, **kwargs)
+    return loglikelihood
+
+
+def get_initial_position(rng_key, model, num_chains, **kwargs):
+    conditioning_vars = set(kwargs.keys())
+    model_randvars = set(model.random_variables)
+    to_sample_vars = model_randvars.difference(conditioning_vars)
+
+    samples = sample_forward(rng_key, model, num_samples=num_chains, **kwargs)
+    initial_positions = dict((var, samples[var]) for var in to_sample_vars)
+
+    # A naive way to go about flattening the positions is to transform the
+    # dictionary of arrays that contains the parameter values into a list of
+    # dictionaries, one per position, and then ravel each dictionary.
+    # However, this approach takes more time than getting the samples in the
+    # first place.
+    #
+    # Luckily, JAX first sorts dictionaries by keys
+    # (https://github.com/google/jax/blob/master/jaxlib/pytree.cc) when
+    # raveling pytrees. We can thus ravel and stack the parameter values in an
+    # array, sorting by key; this gives our flattened positions. We then build
+    # a single dictionary that contains the parameter values and use it to get
+    # the unraveling function with `ravel_pytree`.
+ positions = np.stack( + [np.ravel(samples[s]) for s in sorted(initial_positions.keys())], axis=1 + ) + + sample_position_dict = { + parameter: values[0] for parameter, values in initial_positions.items() + } + _, unravel_fn = ravel_pytree(sample_position_dict) + + return positions, unravel_fn + + +def flatten_loglikelihood(logpdf, unravel_fn): + def flattened_logpdf(array): + kwargs = unravel_fn(array) + return logpdf(**kwargs) + + return flattened_logpdf diff --git a/mcx/hmc.py b/mcx/hmc.py deleted file mode 100644 index 7385fe85..00000000 --- a/mcx/hmc.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import NamedTuple - -import jax -from jax import numpy as np -from jax.flatten_util import ravel_pytree - -from mcx import sample_forward -from mcx.core import compile_to_logpdf -from mcx.inference.integrators import velocity_verlet, hmc_proposal -from mcx.inference.kernels import hmc_kernel, HMCState -from mcx.inference.metrics import gaussian_euclidean_metric - - -class HMCParameters(NamedTuple): - step_size: float - num_integration_steps: float - mass_matrix_sqrt: np.DeviceArray - inverse_mass_matrix: np.DeviceArray - - -def HMC(model, step_size=None, num_integration_steps=None, mass_matrix_sqrt=None, inverse_mass_matrix=None, is_mass_matrix_diagonal=True): - - artifact = compile_to_logpdf(model.graph, model.namespace) - logpdf = artifact.compiled_fn - parameters = HMCParameters(step_size, num_integration_steps, mass_matrix_sqrt, inverse_mass_matrix) - - def _flatten_logpdf(logpdf, unravel_fn): - def flattened_logpdf(array): - kwargs = unravel_fn(array) - return logpdf(**kwargs) - return flattened_logpdf - - def initialize(rng_key, num_chains, **kwargs): - """ - kwargs: a dictionary of arguments and variables we condition on and - their value. - """ - - conditioning_vars = set(kwargs.keys()) - model_randvars = set(model.random_variables) - model_args = set(model.arguments) - available_vars = model_randvars.union(model_args) - - # The variables passed as an argument to the initialization (variables - # on which the logpdf is conditionned) must be either a random variable - # or an argument to the model definition. - if not available_vars.issuperset(conditioning_vars): - unknown_vars = list(conditioning_vars.difference(available_vars)) - unknown_str = ", ".join(unknown_vars) - raise AttributeError("You passed a value for {} which are neither random variables nor arguments to the model definition.".format(unknown_str)) - - # The user must provide a value for all of the model definition's - # positional arguments. 
- model_posargs = set(model.posargs) - if model_posargs.difference(conditioning_vars): - missing_vars = (model_posargs.difference(conditioning_vars)) - missing_str = ", ".join(missing_vars) - raise AttributeError("You need to specify a value for the following arguments: {}".format(missing_str)) - - # Condition on data to obtain the model's log-likelihood - loglikelihood = jax.partial(logpdf, **kwargs) - - # Sample one initial position per chain from the prior - to_sample_vars = model_randvars.difference(conditioning_vars) - samples = sample_forward(rng_key, model, num_samples=num_chains, **kwargs) - initial_positions = dict((var, samples[var]) for var in to_sample_vars) - - positions = [] - for i in range(num_chains): - position = {k: value[i] for k, value in initial_positions.items()} - flat_position, unravel_fn = ravel_pytree(position) - positions.append(flat_position) - positions = np.stack(positions) - - # Transform the likelihood to use a flat array as single argument - flat_loglikelihood = _flatten_logpdf(loglikelihood, unravel_fn) - - # Compute the log probability and gradient to define initial state - logprobs, logprob_grads = jax.vmap(jax.value_and_grad(flat_loglikelihood))(positions) - initial_state = HMCState(positions, logprobs, logprob_grads) - - return flat_loglikelihood, initial_state, parameters, unravel_fn - - def build_kernel(logpdf, parameters): - """Builds the kernel that moves the chain from one point - to the next. - """ - - try: - mass_matrix_sqrt = parameters.mass_matrix_sqrt - inverse_mass_matrix = parameters.inverse_mass_matrix - num_integration_steps = parameters.num_integration_steps - step_size = parameters.step_size - except AttributeError: - AttributeError( - "The Hamiltonian Monte Carlo algorithm requires the following parameters: mass matrix, inverse mass matrix and step size." - ) - - momentum_generator, kinetic_energy = gaussian_euclidean_metric( - mass_matrix_sqrt, inverse_mass_matrix, - ) - integrator = velocity_verlet(logpdf, kinetic_energy) - proposal = hmc_proposal(integrator, step_size, num_integration_steps) - kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, logpdf) - - return kernel - - def to_trace(states_chain, ravel_fn): - """ Translate the raw chains to a format that can be understood by and - is useful to humans. 
-        """
-        return states_chain
-
-    return initialize, build_kernel, to_trace
diff --git a/mcx/inference/hmc.py b/mcx/inference/hmc.py
new file mode 100644
index 00000000..fd809d9a
--- /dev/null
+++ b/mcx/inference/hmc.py
@@ -0,0 +1,67 @@
+from typing import NamedTuple
+
+from jax import numpy as np
+
+from mcx.inference.integrators import velocity_verlet, hmc_proposal
+from mcx.inference.kernels import HMCState, hmc_kernel
+from mcx.inference.metrics import gaussian_euclidean_metric
+
+
+class HMCParameters(NamedTuple):
+    step_size: float
+    num_integration_steps: float
+    mass_matrix_sqrt: np.DeviceArray
+    inverse_mass_matrix: np.DeviceArray
+
+
+def HMC(
+    step_size=None,
+    num_integration_steps=None,
+    mass_matrix_sqrt=None,
+    inverse_mass_matrix=None,
+    integrator=velocity_verlet,
+    is_mass_matrix_diagonal=False,
+):
+
+    parameters = HMCParameters(
+        step_size, num_integration_steps, mass_matrix_sqrt, inverse_mass_matrix
+    )
+
+    def init(position, value_and_grad):
+        log_prob, log_prob_grad = value_and_grad(position)
+        return HMCState(position, log_prob, log_prob_grad)
+
+    def warmup(initial_state, logpdf, num_warmup_steps):
+        return parameters, initial_state
+
+    def build_kernel(logpdf, parameters):
+        """Builds the kernel that moves the chain from one point
+        to the next.
+        """
+
+        try:
+            mass_matrix_sqrt = parameters.mass_matrix_sqrt
+            inverse_mass_matrix = parameters.inverse_mass_matrix
+            num_integration_steps = parameters.num_integration_steps
+            step_size = parameters.step_size
+        except AttributeError:
+            raise AttributeError(
+                "The Hamiltonian Monte Carlo algorithm requires the following parameters: mass matrix, inverse mass matrix and step size."
+            )
+
+        momentum_generator, kinetic_energy = gaussian_euclidean_metric(
+            mass_matrix_sqrt, inverse_mass_matrix,
+        )
+        integrator_step = integrator(logpdf, kinetic_energy)
+        proposal = hmc_proposal(integrator_step, step_size, num_integration_steps)
+        kernel = hmc_kernel(proposal, momentum_generator, kinetic_energy, logpdf)
+
+        return kernel
+
+    def to_trace(states_chain, ravel_fn):
+        """ Translate the raw chains to a format that can be understood by and
+        is useful to humans.
+        """
+        return states_chain
+
+    return init, warmup, build_kernel, to_trace
diff --git a/mcx/model.py b/mcx/model.py
index 7f61fed4..a46d1f54 100644
--- a/mcx/model.py
+++ b/mcx/model.py
@@ -357,9 +357,8 @@ def sample_forward(rng_key, model: model, num_samples=1000, **kwargs) -> Dict:
     sampler_fn = jax.jit(sampler_fn)
     samples = jax.vmap(sampler_fn, in_axes=in_axes, out_axes=out_axes)(*sampler_args)
-    trace = {}
-    for arg, arg_samples in zip(model.variables, samples):
-        trace[arg] = numpy.asarray(arg_samples).T.squeeze()
+    trace = {arg: numpy.asarray(arg_samples).T.squeeze() for arg, arg_samples in zip(model.variables, samples)}
+
    return trace
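
A usage sketch may make the new boundary concrete: the program (here HMC) is built without ever touching the model, while the runtime (sample) receives the model, the program and the observed data, and handles validation, initialization and warmup itself. The sketch is not part of the patch; the model `linear_model`, the observations `x_data`/`y_data` and the parameter values are made-up placeholders, and since this HMC program's warmup is still a pass-through, the mass matrix and step size are supplied explicitly.

import jax
import jax.numpy as np

from mcx import HMC, sample

rng_key = jax.random.PRNGKey(0)

# The program only describes the sampling algorithm; it never sees the model.
# Warmup is currently a no-op, so the kernel parameters are given up front
# (placeholder values for a hypothetical model with two scalar parameters).
program = HMC(
    step_size=1e-3,
    num_integration_steps=10,
    mass_matrix_sqrt=np.array([1.0, 1.0]),
    inverse_mass_matrix=np.array([1.0, 1.0]),
)

# The runtime owns the model: it validates the conditioning variables, builds
# and flattens the log-likelihood, draws one initial position per chain from
# the prior, runs the (for now trivial) warmup, then collects samples.
sampler = sample(
    rng_key, linear_model, program, num_chains=4, x=x_data, y=y_data
)
trace = sampler.run(num_samples=1000)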