Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vectorized_particles to ELBO #1624

Merged
merged 2 commits into from
Aug 15, 2023

Conversation

fehiepsi
Copy link
Member

Allow to use lax.map instead of vmap in ELBO to reduce memory requirement

@@ -33,6 +34,8 @@ class ELBO:

:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this only for eval or also for training?

Copy link
Member Author

@fehiepsi fehiepsi Aug 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is for both, but typically used for eval, when we require a large number of particles.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

:param vectorize_particles: Whether to use jax.vmap to compute ELBOs over the num_particles-many particles in parallel. If False use jax.lax.map. Defaults to True.

Copy link
Collaborator

@martinjankowiak martinjankowiak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@@ -33,6 +34,8 @@ class ELBO:

:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

:param vectorize_particles: Whether to use jax.vmap to compute ELBOs over the num_particles-many particles in parallel. If False use jax.lax.map. Defaults to True.

@@ -108,11 +112,10 @@ class Trace_ELBO(ELBO):

:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@@ -311,6 +320,8 @@ class RenyiELBO(ELBO):
Here :math:`\alpha \neq 1`. Default is 0.
:param num_particles: The number of particles/samples
used to form the objective (gradient) estimator. Default is 2.
:param vectorize_particles: Whether to use `jax.vmap` to obtain elbos over the
particles. If False, we will use `jax.lax.map`. Defaults to True.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@martinjankowiak martinjankowiak merged commit 56b88c3 into pyro-ppl:master Aug 15, 2023
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants