diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index 872ac6622..eb4e11b68 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -174,7 +174,7 @@ def get_model_density(key, latent): return model_log_density num_guide_samples = None - for name, site in guide_trace.items(): + for site in guide_trace.values(): if site["type"] == "sample": num_guide_samples = site["value"].shape[0] break @@ -210,8 +210,6 @@ def get_model_density(key, latent): # log p(z) - log q(z) elbo_particle = model_log_density - guide_log_density - # log p(z) - log q(z) - elbo_particle = model_log_density - guide_log_density if mutable_params: if self.num_particles == 1: return elbo_particle, mutable_params