Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enumerative inference to the inference library #545

Merged
merged 4 commits into from
Nov 4, 2024

Conversation

ztangent
Copy link
Member

@ztangent ztangent commented Oct 21, 2024

This PR introduces the enumerative_inference function, which can be used to perform:

  • Exact inference over models with a finite number of discrete random choices.
  • Grid approximation of continuous target densities (or mixed continuous-discrete targets)
  • Stratified sampling by enumeration over discrete choices and sampling of continuous choices

The motivations for this PR are:

  1. Making it easier to debug inference algorithms by having a (slower) gold-standard inference algorithm to compare against
  2. Providing a simpler entry point for beginner users of Gen (e.g. students used to webPPL and trying to write probmods-style models). Such users often default to importance_sampling with lots of samples, even though there's no guarantee that gives you a good approximation of the posterior.

To use enumerative_inference, users provide an iterator over choice maps and their associated log-volumes, i.e. the log of the volume of sample space that each choice map is associated with (the volume is 1 if all choices are discrete). enumerative_inference then returns a trace and normalized log. probability for each (choices, log_vol) pair in the iterator, along with an estimate of the log marginal likelihood, essentially performing Riemann integration over the space of traces.

A grid of (choices, log_vol) pairs can be constructed using choice_vol_iter, by specifying addresses and a grid of values / intervals to enumerate over for each address.

Documentation has been added (see https://www.gen.dev/docs/previews/PR545/ref/inference/enumerative/), but I still might want to add a simple tutorial showing how enumerative inference can be used to debug a model in cases where importance sampling with the default proposal might be a give a poor approximation to the posterior (e.g. in cases where the observations are highly unlikely under the prior).

@ztangent
Copy link
Member Author

ztangent commented Nov 4, 2024

Have decided to merge this for now and add the tutorial separately when I have more time! There is a usage example in the test cases (test/inference/enumerative.jl) that we can point people to in the meantime.

@ztangent ztangent merged commit 91d798f into master Nov 4, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant