Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include deterministic variables in AutoDelta's sample_posterior #1584

Merged
merged 1 commit into from
May 9, 2023

Conversation

nikmich1
Copy link
Contributor

@nikmich1 nikmich1 commented May 8, 2023

This attempts to fix #951, by including deterministic variables in the output of AutoDelta's sample_posterior

Changes made to AutoDelta's sample_posterior method:

  • After generating latent samples, generate a list of deterministic variables
  • If no deterministic variables are present in the model, return the latent samples as before
  • If there are deterministic variables, create a Predictive instance using the latent samples and generate samples of only the deterministic variables with the return_sites keyword. Since the Predictive instance in the new sample_posterior method requires the model's *args and **kwargs, these are now passed to sample_posterior as well
  • Add the deterministic samples to the return value of the sample_posterior method

Important: Since the Predictive instance is called in the new sample_posterior method, now the model needs to be callable without arguments, if no arguments are passed to sample_posterior.
Therefore, the following existing tests have been slightly modified, such that the model can also be called without passing data

  • test_autoguide_deterministic
  • test_logistic_regression

New tests added:

  • Test, if new sample_posterior method contains deterministic variables (test_autodelta_capture_deterministic_variables)
  • Test shapes, if sample_shape argument is used in the new sample_posterior method (test_autodelta_sample_posterior_with_sample_shape)

@nikmich1 nikmich1 force-pushed the autodelta branch 2 times, most recently from 9c2435c to ff3d75c Compare May 9, 2023 09:43
fehiepsi
fehiepsi previously approved these changes May 9, 2023
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!! Thanks for the clear implementation and adding tests, @nikmich1!

@@ -455,12 +456,25 @@ def __call__(self, *args, **kwargs):

return result

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, sample_shape=(), *args, **kwargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: *args, sample_shape=(), **kwargs

Copy link
Contributor Author

@nikmich1 nikmich1 May 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the review! Just corrected the order of arguments in sample_posterior.

@fehiepsi fehiepsi merged commit c6fb104 into pyro-ppl:master May 9, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Guide samples for AutoDelta don't contain deterministic variables
2 participants