
New Gibbs sampler using condition #2099

Merged: 71 commits merged into master from torfjelde/new-gibbs on Apr 21, 2024

Conversation

@torfjelde (Member) commented on Oct 6, 2023:

This is an attempt at a new Gibbs sampler which makes use of the condition functionality from DynamicPPL.jl.

For models compatible with condition, this should make for a much more flexible approach to Gibbs sampling.

The idea is basically as follows: instead of a single varinfo shared between all the samplers, we keep a separate varinfo for every sampler involved in the composition, and for each "inner" step we condition the model on the values held by the other samplers' varinfos. A minimal sketch of the idea follows below.

The current implementation is somewhat rough, but it does seem to work!
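
To make the idea concrete, here's a minimal sketch of the conditioning step, assuming DynamicPPL's condition interface; the model and the values standing in for the other samplers' varinfos are made up for illustration, and this is not the actual implementation in this PR:

using Turing
using DynamicPPL

@model function demo_joint()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

model = demo_joint()

# Values standing in for the other samplers' varinfos (illustrative only).
# The component sampling `m` sees `s` as if it were observed, and vice versa:
model_for_m = DynamicPPL.condition(model, (s = 1.0,))
model_for_s = DynamicPPL.condition(model, (m = 0.5,))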

Open issues

TODO Issue with MH

using Turing
using AdvancedMH  # needed for `AdvancedMH.RandomWalkProposal`

@model demo() = x ~ Normal()

# Univariate random-walk proposal:
sampler = Turing.Experimental.Gibbs(
  Turing.OrderedDict(
    @varname(x) => MH(AdvancedMH.RandomWalkProposal(Normal(0, 0.003)))
  )
)
# sample(demo(), sampler, 100)  # currently errors

breaks, but if we wrap the proposal in a filldist (or similar multivariate construction) then it works:

# Wrapping the proposal in a `filldist` makes it multivariate, and then it works:
sampler = Turing.Experimental.Gibbs(
  Turing.OrderedDict(
    @varname(x) => MH(AdvancedMH.RandomWalkProposal(filldist(Normal(0, 0.003), 1)))
  )
)
# sample(demo(), sampler, 100)  # works

TODO Linking isn't quite there (see last comment below)

TODO

@torfjelde (Member, Author) commented:

@yebai

@codecov (bot) commented on Oct 6, 2023:

Codecov Report

Attention: Patch coverage is 0% with 145 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (c29d36e) to head (4bde75a).
Report is 1 commit behind head on master.

❗ Current head 4bde75a differs from pull request most recent head 0f30514. Consider uploading reports for the commit 0f30514 to get more accurate results

Files                       Patch %   Missing lines
src/experimental/gibbs.jl   0.00%     123 ⚠️
src/mcmc/particle_mcmc.jl   0.00%     22 ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #2099    +/-   ##
=======================================
  Coverage    0.00%   0.00%            
=======================================
  Files          21      22     +1     
  Lines        1367    1502   +135     
=======================================
- Misses       1367    1502   +135     


@yebai (Member) commented on Oct 8, 2023:

One step closer to Turing v1.0... we only need to fix the docs and particle Gibbs after this PR.

@torfjelde (Member, Author) commented:

This PR is not ready yet though (hence it's in draft mode). Some cases are functional, but it needs more work (which I'm currently doing :))

@github-actions (bot) commented on Oct 8, 2023:

Pull Request Test Coverage Report for Build 6930099476

  • 0 of 121 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage remained the same at 0.0%

File with missing coverage   Covered lines   Changed/added lines   %
src/mcmc/gibbs_new.jl        0               121                   0.0%

Totals:
  Change from base Build 6890657606: 0.0%
  Covered Lines: 0
  Relevant Lines: 1542

Comment on lines 112 to 118
# `GibbsV2` does not work with SMC samplers, e.g. `CSMC`.
# FIXME: Oooor it is (see tests below). Uncertain.
Random.seed!(100)
alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m))
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
@test_broken mean(chain[:s]) ≈ 49 / 24
@test_broken mean(chain[:m]) ≈ 7 / 6
@torfjelde (Member, Author):

@yebai Any idea why this fails? It does work with CSMC in the example below..

@torfjelde (Member, Author):

The main difference I see here is that this example uses CSMC on a continuous variable rather than a discrete one.

@torfjelde (Member, Author):

(I've also tried without the ESS(:m) locally, and it doesn't make a difference)

Reply (Member):

I don't see an obvious reason. CSMC / PG's accuracy can often be improved by increasing the number of particles in each SMC sweep. So, this might go away if you use CSMC(50) or larger values.

@torfjelde (Member, Author):

As @yebai and I discussed, this issue is likely caused by the fact that SMC samplers now also treat the conditioned variables as observations, which is not quite what we want.

But atm I don't quite see how we can work around this with the existing functionality we have in DynamicPPL 😕

Basically, we have fix and condition:

  • condition makes the variables be treated as observations, which, as seen above, is bad for likelihood-tempered samplers.
  • fix avoids the variables being treated as observations, but it also doesn't include the log-prob of those variables. This causes issues in models with hierarchical dependencies, e.g. gdemo, where the prior on m depends on the value of s, so changing s also changes the joint through the prior on m.

Effectively, we'd need something similar to fix which also computes the log-prob; a small sketch contrasting the two is below.
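
For concreteness, here is a minimal sketch of the difference, assuming DynamicPPL's condition, fix, and logjoint interfaces and the usual gdemo model; the particular values plugged in are made up for illustration:

using Turing
using DynamicPPL

@model function gdemo(x, y)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))
end

model = gdemo(1.5, 2.0)

# `condition`: `m` is treated as an observation, so log p(m | s) enters as a
# likelihood term, which is bad for likelihood-tempered samplers such as SMC.
conditioned = DynamicPPL.condition(model, (m = 0.5,))

# `fix`: `m` is held fixed but log p(m | s) is dropped entirely, which is bad
# here because the prior on m depends on s, so the joint is missing that term.
fixed = DynamicPPL.fix(model, (m = 0.5,))

logjoint(conditioned, (s = 1.0,))  # includes log p(m | s)
logjoint(fixed, (s = 1.0,))        # omits log p(m | s)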

Reply (Member):

I haven't followed the discussion in this thread but regarding the last point (ESS): This difference (log-likelihood vs log-joint) was one (the?) main reason for the current design of the re-evaluation in Gibbs, i.e., for gibbs_rerun:

Turing.jl/src/mcmc/gibbs.jl

Lines 134 to 142 in 6649f10

"""
gibbs_rerun(prev_alg, alg)
Check if the model should be rerun to recompute the log density before sampling with the
Gibbs component `alg` and after sampling from Gibbs component `prev_alg`.
By default, the function returns `true`.
"""
gibbs_rerun(prev_alg, alg) = true
We try to avoid re-evaluating varinfo.logp when we know that it is not needed, but to be safe (and e.g. for ESS) the default is to re-evaluate; components can opt out via method specialization, as illustrated below.
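
To illustrate the dispatch mechanism only (the concrete method below is made up and may not correspond to any specialization actually present in gibbs.jl):

# Hypothetical specialization: suppose `SomeAlg` is a component that is known to
# leave `varinfo.logp` equal to the full log joint after its step; then the
# re-evaluation before the next MH step could be skipped.
gibbs_rerun(::SomeAlg, ::MH) = false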

@torfjelde (Member, Author) commented on Nov 19, 2023:

Aaalrighty! Adopting the rerun mechanisms from the current Gibbs seems to have done the trick!

I've also tried an acclogp!!(context, varinfo, logp) implementation, and I think this is something we should adopt: it would immediately address the issue of SMC samplers not currently working with @acclogprob!!, and there are other scenarios where we might want to hook into this, e.g. DebugContext, where we want to "record" these hard-coded log-prob accumulations for debugging and model-checking purposes. A rough sketch of what such a hook could look like is below. I'll make separate PRs for this though. As mentioned above, CSMC should still work even if we're performing resampling for what is technically not observations.
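
For reference, a rough, hypothetical sketch of such a context-aware hook (the three-argument method is not an existing DynamicPPL API; only acclogp!!(varinfo, logp) and DefaultContext are):

# Hypothetical context-aware accumulation hook (illustrative only):
# by default, defer to the existing varinfo-level accumulation.
acclogp!!(::DynamicPPL.DefaultContext, varinfo, logp) = DynamicPPL.acclogp!!(varinfo, logp)

# A sampler- or debugging-specific context could then intercept this call,
# e.g. to accumulate into the task-local varinfo (for SMC) or to record the
# increment (for a DebugContext-style model check).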

@torfjelde (Member, Author):

Btw @devmotion, I didn't see your comment because I was working on a train without internet 🙃

@torfjelde (Member, Author):

One thing that confuses me slightly is that, technically, even if the SMC sampler sees one of the parameters as an observation, AFAIK, it shouldn't break things, no? It might increase the variance, etc. but it should still converge.

To make this absolutely clear, let's look at PG from Andrieu (2010):

[Screenshot: the Particle Gibbs (PG) algorithm from Andrieu (2010)]

Looking at this, it might seem as if we shouldn't include the "prior" probability of $m$ in gdemo, since $p_{\theta}(x_{1:T} \mid y_{1:T})$ does not include $p(\theta)$.

But this is all done under the assumption that the joint nicely factorizes as
$$p(x_{1:T}, \theta \mid y_{1:T}) \propto p_{\theta}(x_{1:T}, y_{1:T}) p(\theta)$$
In this scenario, $p\big(\theta(i)\big)$ will be the same value for all particles in step (b), and so there's no need to include this in the weights.

But in our scenario, the joint does not nicely factorize as above. We have $\theta = m$ and $x_{1:T} = s$, so
$$p(s, m \mid y_{1:T}) \propto p(s) p(m \mid s) p(y_{1:T} \mid m, s)$$
To get a similar factorization as above, we can of course just define
$$p_m(s, y_{1:T}) = \frac{p(s, m, y_{1:T})}{p(m)} \propto p(s, m, y_{1:T})$$
where we can just drop the contribution of $p(m)$ since it's constant w.r.t. $s$, which is what we're sampling.

With this notation, mapping this to the algorithm above is immediate, in which case it's clear that we're targeting the full joint, and so, in the general scenario, we need to include the log-prob contribution of latent variables that are not being sampled by conditional SMC.

And more generally speaking, PG is just a Gibbs sampler which alternates between:

  1. Sampling $\theta$ using some method (in Andrieu (2010) they assume it's available in closed form):
    $$\theta(i) \sim p\big(\theta \mid y_{1:T}, X_{1:T}\big) \propto p(\theta, x_{1:T}, y_{1:T})$$
  2. Sampling $X_{1:T}$ using CSMC:
    $$X_{1:T}(i) \sim p\big(X_{1:T} \mid y_{1:T}, \theta\big) \propto p(\theta, x_{1:T}, y_{1:T})$$

And the reason we're not including the contribution from the variable we're targeting with CSMC is that we would then have to remove it from the weights again, since it cancels in the importance weight; that is, our proposal for $s$ is $p(s)$ and we're targeting $p(s, m, y_{1:T})$, so our IS weight is
$$w = \frac{p(s, m, y_{1:T})}{p(s)} = \frac{p(s) p(m \mid s) p(y_{1:T} \mid s, m)}{p(s)} = p(m \mid s) p(y_{1:T} \mid s, m)$$

Does that make sense, @yebai?

And finally, whether or not we resample when hitting $\theta$ is not a correctness issue, since we're still targeting the same distribution (though it probably affects the variance of the estimator; not sure whether for better or worse 🤷).

@torfjelde (Member, Author):

This also explains why we need to make acclogp!! accumulate log-probs in the task-local varinfo and not the "global" one.


@torfjelde (Member, Author) commented:

Tests are at least passing now, but I have some local changes I'm currently working on, so no merge-y yet please.

@torfjelde (Member, Author) commented:

This is starting to take shape @devmotion @yebai @sunxd3

After TuringLang/DynamicPPL.jl#587 there's very little reason to use the "current" Gibbs sampler over this experimental one: it provides much more flexibility, in addition to being much easier to debug. I've also encountered issues with linking, etc., in the current implementation of Gibbs, which then just become a pain to debug.

@torfjelde (Member, Author) commented:

Tests are passing here now 👍

I think it's worth merging and making a release with this before we make a breaking release after #2197.

@yebai (Member) left a review comment:

Thanks @torfjelde -- happy to merge this since it is fairly self-contained.

One thing to consider is renaming experimental to from_future, but that is just a fun thing to do.

@torfjelde merged commit a022dc6 into master on Apr 21, 2024 (11 checks passed).
@torfjelde deleted the torfjelde/new-gibbs branch on Apr 21, 2024 at 09:44.