From 5da6fa50c697e53f8e830b4c779ef2ce8a961f1e Mon Sep 17 00:00:00 2001
From: Tim Reichelt
Date: Thu, 14 Mar 2024 18:56:17 +0000
Subject: [PATCH] Add initial SDVI implementation (#1758)

---
 docs/source/contrib.rst                       |  12 ++
 .../contrib/stochastic_support/__init__.py    |   2 +
 numpyro/contrib/stochastic_support/dcc.py     | 147 ++++++++++-----
 numpyro/contrib/stochastic_support/sdvi.py    | 122 ++++++++++++
 test/contrib/stochastic_support/test_sdvi.py  | 174 ++++++++++++++++++
 5 files changed, 408 insertions(+), 49 deletions(-)
 create mode 100644 numpyro/contrib/stochastic_support/sdvi.py
 create mode 100644 test/contrib/stochastic_support/test_sdvi.py

diff --git a/docs/source/contrib.rst b/docs/source/contrib.rst
index 7d225c850..67eb07113 100644
--- a/docs/source/contrib.rst
+++ b/docs/source/contrib.rst
@@ -77,7 +77,19 @@ SteinVI Kernels
 Stochastic Support
 ~~~~~~~~~~~~~~~~~~
 
+.. autoclass:: numpyro.contrib.stochastic_support.dcc.StochasticSupportInference
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
 .. autoclass:: numpyro.contrib.stochastic_support.dcc.DCC
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
+.. autoclass:: numpyro.contrib.stochastic_support.sdvi.SDVI
     :members:
     :undoc-members:
     :show-inheritance:
diff --git a/numpyro/contrib/stochastic_support/__init__.py b/numpyro/contrib/stochastic_support/__init__.py
index 6c1dc37f9..8b99025b6 100644
--- a/numpyro/contrib/stochastic_support/__init__.py
+++ b/numpyro/contrib/stochastic_support/__init__.py
@@ -2,7 +2,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from numpyro.contrib.stochastic_support.dcc import DCC
+from numpyro.contrib.stochastic_support.sdvi import SDVI
 
 __all__ = [
     "DCC",
+    "SDVI",
 ]
diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py
index 7f7c68109..13a4d5ce6 100644
--- a/numpyro/contrib/stochastic_support/dcc.py
+++ b/numpyro/contrib/stochastic_support/dcc.py
@@ -1,6 +1,7 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+from abc import ABC, abstractmethod
 from collections import OrderedDict, namedtuple
 
 import jax
@@ -16,13 +17,15 @@
 DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])
 
 
-class DCC:
+class StochasticSupportInference(ABC):
     """
-    Implements the Divide, Conquer, and Combine (DCC) algorithm for models with
-    stochastic support from [1].
+    Base class for running inference in programs with stochastic support. Each subclass
+    decomposes the input model into so-called straight-line programs (SLPs), which are
+    the different control-flow paths in the model. Inference is then run in each SLP
+    separately and the results are combined to produce an overall posterior.
 
     .. note:: This implementation assumes that all stochastic branching is done based on the
-        outcomes of discrete sampling sites that are annotated with `infer={"branching": True}`.
+        outcomes of discrete sampling sites that are annotated with ``infer={"branching": True}``.
         For example,
 
         .. code-block:: python
 
@@ -35,41 +38,18 @@ def model():
                 model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
                 if model1 == 0:
                     mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
                 else:
                     mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
                 numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
 
-
-
-    **References:**
-
-    1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support*,
-       Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth
-
     :param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
-    :param dict mcmc_kwargs: Dictionary of arguments passed to :data:`~numpyro.infer.MCMC`.
- :param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for local inference. Defaults to :class:`~numpyro.infer.NUTS`. :param int num_slp_samples: Number of samples to draw from the prior to discover the straight-line programs (SLPs). :param int max_slps: Maximum number of SLPs to discover. DCC will not run inference on more than `max_slps`. - :param float proposal_scale: Scale parameter for the proposal distribution for - estimating the normalization constant of an SLP. """ - def __init__( - self, - model, - mcmc_kwargs, - kernel_cls=NUTS, - num_slp_samples=1000, - max_slps=124, - proposal_scale=1.0, - ): + def __init__(self, model, num_slp_samples, max_slps): self.model = model - self.kernel_cls = kernel_cls - self.mcmc_kwargs = mcmc_kwargs - self.num_slp_samples = num_slp_samples self.max_slps = max_slps - self.proposal_scale = proposal_scale def _find_slps(self, rng_key, *args, **kwargs): """ @@ -111,7 +91,95 @@ def _get_branching_trace(self, tr): branching_trace[site["name"]] = int(site["value"]) return branching_trace - def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs): + @abstractmethod + def _run_inference(self, rng_key, branching_trace, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def _combine_inferences( + self, rng_key, inferences, branching_traces, *args, **kwargs + ): + raise NotImplementedError + + def run(self, rng_key, *args, **kwargs): + """ + Run inference on each SLP separately and combine the results. + + :param jax.random.PRNGKey rng_key: Random number generator key. + :param args: Arguments to the model. + :param kwargs: Keyword arguments to the model. + """ + rng_key, subkey = random.split(rng_key) + branching_traces = self._find_slps(subkey, *args, **kwargs) + + inferences = dict() + for key, bt in branching_traces.items(): + rng_key, subkey = random.split(rng_key) + inferences[key] = self._run_inference(subkey, bt, *args, **kwargs) + + rng_key, subkey = random.split(rng_key) + return self._combine_inferences( + subkey, inferences, branching_traces, *args, **kwargs + ) + + +class DCC(StochasticSupportInference): + """ + Implements the Divide, Conquer, and Combine (DCC) algorithm for models with + stochastic support from [1]. + + **References:** + + 1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support*, + Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth + + **Example:** + + .. code-block:: python + + def model(): + model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True}) + if model1 == 0: + mean = numpyro.sample("a1", dist.Normal(0.0, 1.0)) + else: + mean = numpyro.sample("a2", dist.Normal(1.0, 1.0)) + numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2) + + mcmc_kwargs = dict( + num_warmup=500, num_samples=1000 + ) + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs) + dcc_result = dcc.run(random.PRNGKey(0)) + + :param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`. + :param dict mcmc_kwargs: Dictionary of arguments passed to :data:`~numpyro.infer.MCMC`. + :param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for + local inference. Defaults to :class:`~numpyro.infer.NUTS`. + :param int num_slp_samples: Number of samples to draw from the prior to discover the + straight-line programs (SLPs). + :param int max_slps: Maximum number of SLPs to discover. DCC will not run inference + on more than `max_slps`. 
+ :param float proposal_scale: Scale parameter for the proposal distribution for + estimating the normalization constant of an SLP. + """ + + def __init__( + self, + model, + mcmc_kwargs, + kernel_cls=NUTS, + num_slp_samples=1000, + max_slps=124, + proposal_scale=1.0, + ): + self.kernel_cls = kernel_cls + self.mcmc_kwargs = mcmc_kwargs + + self.proposal_scale = proposal_scale + + super().__init__(model, num_slp_samples, max_slps) + + def _run_inference(self, rng_key, branching_trace, *args, **kwargs): """ Run MCMC on the model conditioned on the given branching trace. """ @@ -122,7 +190,7 @@ def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs): return mcmc.get_samples() - def _combine_samples(self, rng_key, samples, branching_traces, *args, **kwargs): + def _combine_inferences(self, rng_key, samples, branching_traces, *args, **kwargs): """ Weight each SLP proportional to its estimated normalization constant. The normalization constants are estimated using importance sampling with @@ -159,22 +227,3 @@ def log_weight(rng_key, i, slp_model, slp_samples): normalizer = jax.scipy.special.logsumexp(jnp.array(list(log_Zs.values()))) slp_weights = {k: jnp.exp(v - normalizer) for k, v in log_Zs.items()} return DCCResult(samples, slp_weights) - - def run(self, rng_key, *args, **kwargs): - """ - Run DCC and collect samples for all SLPs. - - :param jax.random.PRNGKey rng_key: Random number generator key. - :param args: Arguments to the model. - :param kwargs: Keyword arguments to the model. - """ - rng_key, subkey = random.split(rng_key) - branching_traces = self._find_slps(subkey, *args, **kwargs) - - samples = dict() - for key, bt in branching_traces.items(): - rng_key, subkey = random.split(rng_key) - samples[key] = self._run_mcmc(subkey, bt, *args, **kwargs) - - rng_key, subkey = random.split(rng_key) - return self._combine_samples(subkey, samples, branching_traces, *args, **kwargs) diff --git a/numpyro/contrib/stochastic_support/sdvi.py b/numpyro/contrib/stochastic_support/sdvi.py new file mode 100644 index 000000000..82c0f2d01 --- /dev/null +++ b/numpyro/contrib/stochastic_support/sdvi.py @@ -0,0 +1,122 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +import jax +import jax.numpy as jnp + +from numpyro.contrib.stochastic_support.dcc import StochasticSupportInference +from numpyro.handlers import condition +from numpyro.infer import ( + SVI, + Trace_ELBO, + TraceEnum_ELBO, + TraceGraph_ELBO, + TraceMeanField_ELBO, +) +from numpyro.infer.autoguide import AutoNormal + +SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"]) + +VALID_ELBOS = (Trace_ELBO, TraceMeanField_ELBO, TraceEnum_ELBO, TraceGraph_ELBO) + + +class SDVI(StochasticSupportInference): + """ + Implements the Support Decomposition Variational Inference (SDVI) algorithm for models with + stochastic support from [1]. This implementation creates a separate guide for each SLP, trains + the guides separately, and then combines the guides by weighting them proportional to their ELBO + estimates. + + **References:** + + 1. *Rethinking Variational Inference for Probabilistic Programs with Stochastic Support*, + Tim Reichelt, Luke Ong, Tom Rainforth + + **Example:** + + .. 
code-block:: python
+
+        def model():
+            model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
+            if model1 == 0:
+                mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
+            else:
+                mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
+            numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
+
+        sdvi = SDVI(model, numpyro.optim.Adam(step_size=0.001))
+        sdvi_result = sdvi.run(random.PRNGKey(0))
+
+    :param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
+    :param optimizer: An instance of :class:`~numpyro.optim._NumpyroOptim`, a
+        ``jax.example_libraries.optimizers.Optimizer`` or an Optax
+        ``GradientTransformation``. Gets passed to :class:`~numpyro.infer.SVI`.
+    :param int svi_num_steps: Number of steps to run SVI for each SLP.
+    :param int combine_elbo_particles: Number of particles used to estimate the ELBO
+        when computing SLP weights.
+    :param guide_init: A constructor for the guide. This should be a callable that returns a
+        :class:`~numpyro.infer.autoguide.AutoGuide` instance. Defaults to
+        :class:`~numpyro.infer.autoguide.AutoNormal`.
+    :param loss: ELBO loss for SVI. Defaults to :class:`~numpyro.infer.Trace_ELBO`.
+    :param bool svi_progress_bar: Whether to use a progress bar for SVI.
+    :param int num_slp_samples: Number of samples to draw from the prior to discover the
+        straight-line programs (SLPs).
+    :param int max_slps: Maximum number of SLPs to discover. SDVI will not run inference
+        on more than `max_slps`.
+    """
+
+    def __init__(
+        self,
+        model,
+        optimizer,
+        svi_num_steps=1000,
+        combine_elbo_particles=1000,
+        guide_init=AutoNormal,
+        loss=Trace_ELBO(),
+        svi_progress_bar=False,
+        num_slp_samples=1000,
+        max_slps=124,
+    ):
+        self.guide_init = guide_init
+        self.optimizer = optimizer
+        self.svi_num_steps = svi_num_steps
+        self.svi_progress_bar = svi_progress_bar
+
+        if not isinstance(loss, VALID_ELBOS):
+            err_str = ", ".join(x.__name__ for x in VALID_ELBOS)
+            raise ValueError(f"loss must be an instance of: ({err_str})")
+        self.loss = loss
+        self.combine_elbo_particles = combine_elbo_particles
+
+        super().__init__(model, num_slp_samples, max_slps)
+
+    def _run_inference(self, rng_key, branching_trace, *args, **kwargs):
+        """
+        Run SVI on a given SLP defined by its branching trace.
+        """
+        slp_model = condition(self.model, branching_trace)
+        guide = self.guide_init(slp_model)
+        svi = SVI(slp_model, guide, self.optimizer, loss=self.loss)
+        svi_result = svi.run(
+            rng_key,
+            self.svi_num_steps,
+            *args,
+            progress_bar=self.svi_progress_bar,
+            **kwargs,
+        )
+        return guide, svi_result.params
+
+    def _combine_inferences(self, rng_key, guides, branching_traces, *args, **kwargs):
+        """Weight each SLP proportional to its estimated ELBO."""
+        elbos = {}
+        for bt, (guide, param_map) in guides.items():
+            slp_model = condition(self.model, branching_traces[bt])
+            elbos[bt] = -Trace_ELBO(num_particles=self.combine_elbo_particles).loss(
+                rng_key, param_map, slp_model, guide, *args, **kwargs
+            )
+
+        normalizer = jax.scipy.special.logsumexp(jnp.array(list(elbos.values())))
+        slp_weights = {k: jnp.exp(v - normalizer) for k, v in elbos.items()}
+        return SDVIResult(guides, slp_weights)
diff --git a/test/contrib/stochastic_support/test_sdvi.py b/test/contrib/stochastic_support/test_sdvi.py
new file mode 100644
index 000000000..41f8331f4
--- /dev/null
+++ b/test/contrib/stochastic_support/test_sdvi.py
@@ -0,0 +1,174 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from jax import random + +import numpyro +from numpyro import handlers +from numpyro.contrib.stochastic_support.sdvi import SDVI +import numpyro.distributions as dist +from numpyro.infer import ( + RenyiELBO, + Trace_ELBO, + TraceEnum_ELBO, + TraceGraph_ELBO, + TraceMeanField_ELBO, +) +from numpyro.infer.autoguide import ( + AutoBNAFNormal, + AutoDAIS, + AutoDelta, + AutoDiagonalNormal, + AutoGuideList, + AutoIAFNormal, + AutoLaplaceApproximation, + AutoLowRankMultivariateNormal, + AutoMultivariateNormal, + AutoNormal, +) + + +@pytest.mark.parametrize( + "auto_class", + [ + AutoDiagonalNormal, + AutoDAIS, + AutoIAFNormal, + AutoBNAFNormal, + AutoMultivariateNormal, + AutoLaplaceApproximation, + AutoLowRankMultivariateNormal, + AutoNormal, + AutoDelta, + AutoGuideList, + ], +) +def test_autoguides(auto_class): + dim = 2 + + def model(y): + z = numpyro.sample("z", dist.Normal(0.0, 1.0).expand([dim]).to_event()) + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = 1.0 if model1 == 0 else 2.0 + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma).to_event(), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y = dist.Normal(0, 1).sample(subkey, (200, dim)) + if auto_class == AutoGuideList: + + def guide_init_fn(model): + guide = AutoGuideList(model) + guide.append(AutoNormal(handlers.block(model, hide=[]))) + return guide + + auto_class = guide_init_fn + + sdvi = SDVI( + model, + optimizer=numpyro.optim.Adam(0.01), + guide_init=auto_class, + svi_num_steps=10, + ) + + rng_key, subkey = random.split(rng_key) + sdvi.run(subkey, y) + + +@pytest.mark.parametrize( + "elbo_class", + [ + Trace_ELBO, + TraceMeanField_ELBO, + TraceEnum_ELBO, + TraceGraph_ELBO, + ], +) +@pytest.mark.parametrize("num_particles", [1, 4]) +def test_elbos(elbo_class, num_particles): + dim = 2 + + def model(y): + z = numpyro.sample("z", dist.Normal(0.0, 1.0).expand([dim]).to_event()) + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = 1.0 if model1 == 0 else 2.0 + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma).to_event(), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y = dist.Normal(0, 1).sample(subkey, (200, dim)) + sdvi = SDVI( + model, + optimizer=numpyro.optim.Adam(0.01), + guide_init=AutoNormal, + svi_num_steps=10, + loss=elbo_class(num_particles=num_particles), + ) + + rng_key, subkey = random.split(rng_key) + sdvi.run(subkey, y) + + +@pytest.mark.parametrize("elbo_class", [RenyiELBO]) +@pytest.mark.xfail(raises=ValueError) +def test_fail_elbos(elbo_class): + dim = 2 + + def model(y): + z = numpyro.sample("z", dist.Normal(0.0, 1.0).expand([dim]).to_event()) + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = 1.0 if model1 == 0 else 2.0 + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma).to_event(), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y = dist.Normal(0, 1).sample(subkey, (200, dim)) + sdvi = SDVI( + model, + optimizer=numpyro.optim.Adam(0.01), + svi_num_steps=10, + loss=elbo_class(), + ) + + rng_key, subkey = random.split(rng_key) + sdvi.run(subkey, y) + + +def test_progress_bar(): + dim = 2 + + def model(y): + z = numpyro.sample("z", dist.Normal(0.0, 1.0).expand([dim]).to_event()) + model1 = 
numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = 1.0 if model1 == 0 else 2.0 + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma).to_event(), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y = dist.Normal(0, 1).sample(subkey, (200, dim)) + sdvi = SDVI( + model, + optimizer=numpyro.optim.Adam(0.01), + svi_num_steps=10, + svi_progress_bar=True, + ) + rng_key, subkey = random.split(rng_key) + sdvi.run(subkey, y)
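
Usage note (editorial addendum, not part of the patch itself): below is a minimal sketch of how the pieces added in this diff compose end to end. It reuses the branching model from the SDVI docstring and relies only on the API introduced above — ``SDVI.run`` returns an ``SDVIResult`` whose ``guides`` field maps each SLP's branching-trace key to a ``(guide, params)`` pair and whose ``slp_weights`` sum to one across SLPs — plus the pre-existing ``AutoGuide.sample_posterior`` method. The weighted resampling loop at the end is one natural way to consume the result; it is this note's assumption, not something the patch prescribes.

    import jax.numpy as jnp
    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.stochastic_support import SDVI

    def model():
        # Discrete site annotated as a branching site; each outcome defines one SLP.
        model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
        if model1 == 0:
            mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
        else:
            mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

    # Train one guide per SLP, then weight the SLPs by their ELBO estimates.
    sdvi = SDVI(model, numpyro.optim.Adam(step_size=0.001), svi_num_steps=2000)
    result = sdvi.run(random.PRNGKey(0))

    # Approximate the overall posterior by drawing from each SLP's trained guide
    # in proportion to that SLP's weight.
    num_samples = 1000
    rng_key = random.PRNGKey(1)
    for bt, (guide, params) in result.guides.items():
        rng_key, subkey = random.split(rng_key)
        n = int(num_samples * result.slp_weights[bt])
        if n > 0:
            samples = guide.sample_posterior(subkey, params, sample_shape=(n,))
            print(bt, n, {name: jnp.shape(s) for name, s in samples.items()})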