It is common in Pyro to use mean-field variational distributions with TraceMeanField_ELBO to reduce ELBO estimator variance. However, TraceMeanField_ELBO is too conservative and cannot be used for only part of a model or combined with Pyro's other inference tools, notably enumeration of discrete variables, making it difficult to perform variational inference reliably in models like LDA.
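To make the variance argument concrete, here is a minimal standard-library sketch (not Pyro or Funsor code) comparing a Monte Carlo estimate of the KL divergence between two univariate Normals with its closed form. The function names and the parameter values are illustrative assumptions, not taken from the issue.

```python
import math
import random
import statistics


def normal_logpdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def analytic_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2))."""
    return (
        math.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
        - 0.5
    )


def monte_carlo_kl(mu_q, sigma_q, mu_p, sigma_p, num_samples, rng):
    """Estimate the same KL as an average of log q(z) - log p(z) with z ~ q."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu_q, sigma_q)
        total += normal_logpdf(z, mu_q, sigma_q) - normal_logpdf(z, mu_p, sigma_p)
    return total / num_samples


rng = random.Random(0)  # fixed seed so the comparison is repeatable
exact = analytic_kl(0.5, 0.8, 0.0, 1.0)

# Repeat a 10-sample estimate many times to see how much it fluctuates.
estimates = [monte_carlo_kl(0.5, 0.8, 0.0, 1.0, 10, rng) for _ in range(200)]
mc_mean = statistics.mean(estimates)
mc_std = statistics.stdev(estimates)

print(f"analytic KL: {exact:.4f}")
print(f"Monte Carlo: mean {mc_mean:.4f}, std across runs {mc_std:.4f}")
```

The analytic value is deterministic, while the sampled estimate scatters around it. That scatter is the variance an analytic KL term removes from the ELBO estimator.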
A better long-term approach would be to automatically identify ELBO fragments that admit analytic computation. Most such fragments are determined (almost) nonparametrically by conditional independence properties of the model and guide, so it should be possible in principle to cover a surprisingly wide range of Pyro models.
We can decompose this pattern-matching problem into two stages using Funsor: patterns for recognizing situations within a larger computation where analytic KL divergence and entropy computations may be used, and patterns for actually performing these computations using the backend distribution libraries' preexisting optimized implementations. These patterns could then be used seamlessly within any Funsor-based ELBO implementation, notably pyro.contrib.funsor.infer.TraceEnum_ELBO.
funsor.optimize.optimize already decomposes ELBO computations into conditionally independent fragments, although there are missing details like constant propagation that need to be handled with more generality (see also #163 and #109).
Thus, at a high level, for the first stage we just need to add patterns that rewrite Monte Carlo expectations back to their analytic versions. This rewrite is only valid when we can guarantee that the Monte Carlo samples were drawn from the same distribution that appears in the integrand, so these patterns would have to live in their own special interpretation:

    @dispatched_interpretation
    def analytic_recognizer(cls, *args):
        return analytic_recognizer.dispatch(cls, *args)(*args)

    @analytic_recognizer.register(Integrate, Delta, Distribution, frozenset)
    def recognize_analytic_entropy(log_measure, integrand, reduced_vars):
        ...  # check that the rewrite can be performed
        return Integrate(integrand, integrand, reduced_vars)

For added robustness, analytic_recognizer could be a StatefulInterpretation holding a mapping from Delta funsors to their sampling distribution funsors.
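The bookkeeping idea can be illustrated with a toy, standard-library sketch. It deliberately does not use Funsor: `Normal`, `Samples`, and `ToyAnalyticRecognizer` are hypothetical stand-ins for a distribution funsor, a Delta funsor, and a stateful interpretation, respectively. The recognizer remembers which distribution each batch of samples came from and only substitutes the analytic result when the integrand matches that distribution.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Normal:
    """Hypothetical stand-in for a distribution funsor."""
    loc: float
    scale: float

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * math.log(2 * math.pi) - math.log(self.scale) - 0.5 * z * z

    def entropy(self):
        return 0.5 * math.log(2 * math.pi * math.e * self.scale ** 2)


@dataclass(frozen=True)
class Samples:
    """Hypothetical stand-in for a Delta funsor: named points from some sampler."""
    name: str
    points: tuple


class ToyAnalyticRecognizer:
    """Maps sample batches to the distribution they were drawn from."""

    def __init__(self):
        self._sampled_from = {}

    def record(self, samples, dist):
        # Remember the sampling distribution at the time the samples are drawn.
        self._sampled_from[samples] = dist

    def expected_log_prob(self, samples, dist):
        # E_q[log q] = -H[q], valid only if the samples really came from `dist`.
        if self._sampled_from.get(samples) == dist:
            return -dist.entropy()
        # Otherwise fall back to the plain Monte Carlo average.
        return sum(dist.log_prob(x) for x in samples.points) / len(samples.points)


q = Normal(0.0, 2.0)
other = Normal(1.0, 2.0)
draws = Samples("z", (0.3, -1.2, 2.5))

recognizer = ToyAnalyticRecognizer()
recognizer.record(draws, q)

analytic = recognizer.expected_log_prob(draws, q)      # rewritten to -entropy
fallback = recognizer.expected_log_prob(draws, other)  # left as a sample average
```

Keeping the mapping explicit means the rewrite is refused whenever the provenance of the samples is unknown or does not match, which is the robustness property described above.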
For the second stage, we'll need eager patterns that are evaluated using the backend .entropy or kl implementations.
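For reference, in the univariate Gaussian case the quantities such backend calls would be expected to return have well-known closed forms. The symbols $q$, $p$, $\mu$, and $\sigma$ below are illustrative and do not come from the issue:

$$
\mathbb{H}\left[\mathcal{N}(\mu, \sigma^2)\right] = \tfrac{1}{2} \log\left(2 \pi e \sigma^2\right)
$$

$$
\mathrm{KL}\left(\mathcal{N}(\mu_q, \sigma_q^2) \,\|\, \mathcal{N}(\mu_p, \sigma_p^2)\right)
= \log\frac{\sigma_p}{\sigma_q} + \frac{\sigma_q^2 + (\mu_q - \mu_p)^2}{2 \sigma_p^2} - \frac{1}{2}
$$

An eager pattern for this stage would replace an `Integrate` term whose measure and integrand are both distributions with the corresponding closed-form value, rather than estimating it from samples.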
With these patterns in hand, computing analytic entropy or KL terms in pyro.contrib.funsor.infer.TraceEnum_ELBO shouldn't involve too much beyond using the new analytic_recognizer interpretation when evaluating the final ELBO funsor expression.
eb8680 changed the title from "Exploit opportunities for analytic KL and entropy computations" to "Exploit opportunities for analytic KL and entropy computations in Pyro ELBOs" on Sep 29, 2020.
eb8680 changed the title from "Exploit opportunities for analytic KL and entropy computations in Pyro ELBOs" to "Exploit opportunities for analytic KL and entropy computations" on Sep 29, 2020.
Motivated by @fehiepsi's work on TraceMeanField_ELBO in NumPyro (pyro-ppl/numpyro#748) and ongoing issues with the LDA example in Pyro. cc @fritzo @martinjankowiak