diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index 127a616fe..872ac6622 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -42,7 +42,7 @@ class ELBO: """ Determines whether the ELBO objective can support inference of discrete latent variables. - Subclasses that are capable of inferring discrete latent variables should override to `True` + Subclasses that are capable of inferring discrete latent variables should override to `True`. """ can_infer_discrete = False @@ -50,7 +50,15 @@ def __init__(self, num_particles=1, vectorize_particles=True): self.num_particles = num_particles self.vectorize_particles = vectorize_particles - def loss(self, rng_key, param_map, model, guide, *args, **kwargs): + def loss( + self, + rng_key, + param_map, + model, + guide, + *args, + **kwargs, + ): """ Evaluates the ELBO with an estimator that uses num_particles many samples/particles. @@ -116,15 +124,30 @@ class Trace_ELBO(ELBO): :param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the num_particles-many particles in parallel. If False use `jax.lax.map`. Defaults to True. + :param multi_sample_guide: Whether to make an assumption that the guide proposes + multiple samples. """ + def __init__( + self, num_particles=1, vectorize_particles=True, multi_sample_guide=False + ): + self.multi_sample_guide = multi_sample_guide + super().__init__( + num_particles=num_particles, vectorize_particles=vectorize_particles + ) + def loss_with_mutable_state( - self, rng_key, param_map, model, guide, *args, **kwargs + self, + rng_key, + param_map, + model, + guide, + *args, + **kwargs, ): def single_particle_elbo(rng_key): params = param_map.copy() model_seed, guide_seed = random.split(rng_key) - seeded_model = seed(model, model_seed) seeded_guide = seed(guide, guide_seed) guide_log_density, guide_trace = log_density( seeded_guide, args, kwargs, param_map @@ -135,19 +158,57 @@ def single_particle_elbo(rng_key): if site["type"] == "mutable" } params.update(mutable_params) - seeded_model = replay(seeded_model, guide_trace) - model_log_density, model_trace = log_density( - seeded_model, args, kwargs, params - ) - check_model_guide_match(model_trace, guide_trace) - _validate_model(model_trace, plate_warning="loose") - mutable_params.update( - { + if self.multi_sample_guide: + plates = { name: site["value"] - for name, site in model_trace.items() - if site["type"] == "mutable" + for name, site in guide_trace.items() + if site["type"] == "plate" } - ) + + def get_model_density(key, latent): + with seed(rng_seed=key), substitute(data={**latent, **plates}): + model_log_density, model_trace = log_density( + model, args, kwargs, params + ) + _validate_model(model_trace, plate_warning="loose") + return model_log_density + + num_guide_samples = None + for name, site in guide_trace.items(): + if site["type"] == "sample": + num_guide_samples = site["value"].shape[0] + break + if num_guide_samples is None: + raise ValueError("guide is missing `sample` sites.") + seeds = random.split(model_seed, num_guide_samples) + latents = { + name: site["value"] + for name, site in guide_trace.items() + if (site["type"] == "sample" and site["value"].size > 0) + or (site["type"] == "deterministic") + } + model_log_density = vmap(get_model_density)(seeds, latents) + assert model_log_density.ndim == 1 + model_log_density = model_log_density.sum(0) + # log p(z) - log q(z) + elbo_particle = (model_log_density - guide_log_density) / seeds.shape[0] + else: + seeded_model = seed(model, model_seed) + replay_model = replay(seeded_model, guide_trace) + model_log_density, model_trace = log_density( + replay_model, args, kwargs, params + ) + check_model_guide_match(model_trace, guide_trace) + _validate_model(model_trace, plate_warning="loose") + mutable_params.update( + { + name: site["value"] + for name, site in model_trace.items() + if site["type"] == "mutable" + } + ) + # log p(z) - log q(z) + elbo_particle = model_log_density - guide_log_density # log p(z) - log q(z) elbo_particle = model_log_density - guide_log_density @@ -155,9 +216,10 @@ def single_particle_elbo(rng_key): if self.num_particles == 1: return elbo_particle, mutable_params else: - raise ValueError( - "Currently, we only support mutable states with num_particles=1." + warnings.warn( + "mutable state is currently ignored when num_particles > 1." ) + return elbo_particle, None else: return elbo_particle, None @@ -288,9 +350,10 @@ def single_particle_elbo(rng_key): if self.num_particles == 1: return elbo_particle, mutable_params else: - raise ValueError( - "Currently, we only support mutable states with num_particles=1." + warnings.warn( + "mutable state is currently ignored when num_particles > 1." ) + return elbo_particle, None else: return elbo_particle, None diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index b79d9da56..4b99302cc 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -7,6 +7,7 @@ import tqdm +import jax from jax import jit, lax, random from jax.example_libraries import optimizers import jax.numpy as jnp @@ -189,9 +190,25 @@ def init(self, rng_key, *args, init_params=None, **kwargs): } if init_params is not None: init_guide_params.update(init_params) - model_trace = trace( - substitute(replay(model_init, guide_trace), init_guide_params) - ).get_trace(*args, **kwargs, **self.static_kwargs) + if getattr(self.loss, "multi_sample_guide", False): + latents = { + name: site["value"][0] + for name, site in guide_trace.items() + if site["type"] == "sample" and site["value"].size > 0 + } + latents.update(init_guide_params) + with trace() as model_trace, substitute(data=latents): + model_init(*args, **kwargs, **self.static_kwargs) + for site in model_trace.values(): + if site["type"] == "mutable": + raise ValueError( + "mutable state in model is not supported for " + "multi-sample guide." + ) + else: + model_trace = trace( + substitute(replay(model_init, guide_trace), init_guide_params) + ).get_trace(*args, **kwargs, **self.static_kwargs) params = {} inv_transforms = {} @@ -363,7 +380,7 @@ def body_fn(svi_state, _): batch = max(num_steps // 20, 1) for i in t: svi_state, loss = jit(body_fn)(svi_state, None) - losses.append(loss) + losses.append(jax.device_get(loss)) if i % batch == 0: if stable_update: valid_losses = [x for x in losses[i - batch :] if x == x] diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py index f6f196ffe..0e62bd395 100644 --- a/test/infer/test_svi.py +++ b/test/infer/test_svi.py @@ -463,7 +463,7 @@ def guide(): svi = SVI(model, guide, optim.Adam(0.1), elbo(num_particles=num_particles)) if num_particles > 1: - with pytest.raises(ValueError, match="mutable state"): + with pytest.warns(UserWarning, match="mutable state"): svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update) return svi_result = svi.run(random.PRNGKey(0), 1000, stable_update=stable_update) @@ -738,3 +738,22 @@ def guide(difficulty=0.0): for i in range(3): assert_allclose(max_errors[i], 0, atol=atol) + + +def test_multi_sample_guide(): + actual_loc = 3.0 + actual_scale = 2.0 + + def model(): + numpyro.sample("x", dist.Normal(actual_loc, actual_scale)) + + def guide(): + loc = numpyro.param("loc", 0.0) + scale = numpyro.param("scale", 1.0, constraint=constraints.positive) + numpyro.sample("x", dist.Normal(loc, scale).expand([10])) + + svi = SVI(model, guide, optim.Adam(0.1), Trace_ELBO(multi_sample_guide=True)) + svi_results = svi.run(random.PRNGKey(0), 2000) + params = svi_results.params + assert_allclose(params["loc"], actual_loc, rtol=0.1) + assert_allclose(params["scale"], actual_scale, rtol=0.1)