New Gibbs sampler using condition #2099
Conversation
Codecov Report

Attention: Patch coverage is

```
@@            Coverage Diff            @@
##           master    #2099    +/-   ##
========================================
  Coverage    0.00%    0.00%
========================================
  Files          21       22      +1
  Lines        1367     1502    +135
========================================
- Misses       1367     1502    +135
```

☔ View full report in Codecov by Sentry.
One more step closer to Turing.
This PR is not ready though (hence why I put it in draft mode). Some cases are functional, but it needs more work (which I'm currently doing :))
Pull Request Test Coverage Report for Build 6930099476
💛 - Coveralls
test/mcmc/gibbs_new.jl
Outdated
```julia
# `GibbsV2` does not work with SMC samplers, e.g. `CSMC`.
# FIXME: Oooor it is (see tests below). Uncertain.
Random.seed!(100)
alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m))
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
@test_broken mean(chain[:s]) ≈ 49 / 24
@test_broken mean(chain[:m]) ≈ 7 / 6
```
@yebai Any idea why this fails? It does work with `CSMC` in the example below...
The main difference I see here is that this example uses `CSMC` on a continuous rather than a discrete variable.
(I've also tried without the `ESS(:m)` locally, and it doesn't make a difference.)
I don't see an obvious reason. `CSMC`/`PG`'s accuracy can often be improved by increasing the number of particles in each SMC sweep. So, this might go away if you use `CSMC(50)` or larger values.
As @yebai and I discussed, this issue is likely caused by the fact that SMC samplers will now also treat the conditioned variables as observations, which is not quite what we want.

But atm I don't quite see how we can work around this with the existing functionality we have in DynamicPPL 😕

Basically, we have `fix` and `condition`:

- `condition` makes the variables be treated as observations, which, as seen above, is bad for likelihood-tempered samplers.
- `fix` avoids variables being treated as observations, but it also doesn't include the log-prob of that variable. This then causes issues in models with hierarchical dependencies, e.g. `gdemo`, where the prior on `m` depends on the value of `s`, and so changing `s` changes the joint also through the prior on `m`.

Effectively, we'd need something similar to `fix` which also computes the log-prob.
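To make the distinction concrete, here's a minimal sketch (the `gdemo` definition is the usual demo model, and the `condition`/`fix` calls are the DynamicPPL functions discussed above, written the way I'd use them rather than taken from this PR):

```julia
using DynamicPPL, Distributions

@model function gdemo(x, y)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))   # prior on `m` depends on `s`
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))
end

model = gdemo(1.5, 2.0)

# `condition`: `m` is treated as an observation, so `p(m | s)` moves into the
# likelihood, which likelihood-tempered samplers (e.g. SMC) will then temper.
conditioned_model = DynamicPPL.condition(model; m = 1.0)

# `fix`: `m` is held fixed but contributes no log-prob at all, so the coupling
# between `s` and the joint through `p(m | s)` is silently dropped.
fixed_model = DynamicPPL.fix(model; m = 1.0)
```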
I haven't followed the discussion in this thread, but regarding the last point (ESS): this difference (log-likelihood vs log-joint) was one (the?) main reason for the current design of the re-evaluation in `Gibbs`, i.e., for `gibbs_rerun`:

Lines 134 to 142 in 6649f10

```julia
"""
    gibbs_rerun(prev_alg, alg)

Check if the model should be rerun to recompute the log density before sampling with the
Gibbs component `alg` and after sampling from Gibbs component `prev_alg`.

By default, the function returns `true`.
"""
gibbs_rerun(prev_alg, alg) = true
```

The idea is to avoid re-evaluating the model and recomputing `varinfo.logp` when we know that it is not needed, but to be safe (and e.g. for `ESS`) the default is to re-evaluate.
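In other words, component pairs can opt out by specializing `gibbs_rerun`; a hypothetical example (illustrative only, not Turing's actual method table):

```julia
# Hypothetical: if the upcoming component never reads `varinfo.logp`,
# the re-evaluation between the two Gibbs steps can be skipped.
gibbs_rerun(prev_alg, ::MH) = false
```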
Aaalrighty! Adopting the rerun mechanism from the current `Gibbs` seems to have done the trick!

I've also tried an `acclogp!!(context, varinfo, logp)` implementation, and I think this is something we should adopt, as it will immediately address the issue of SMC samplers not currently working with `@acclogprob!!`. Plus there are other scenarios where we might want to hook into this, e.g. `DebugContext`, where we want to "record" these hard-coded acclogprobs for debugging and model-checking purposes. I'll make separate PRs for this though. As mentioned above, CSMC should still work even if we're performing resampling for what are technically not observations.
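A rough sketch of what I mean by a context-dispatched `acclogp!!` (the signatures and the `recorded_logps` field are my assumptions, not existing DynamicPPL API):

```julia
# Default behaviour: ignore the context and accumulate into the varinfo,
# just like the current `acclogp!!(varinfo, logp)`.
acclogp!!(::DynamicPPL.AbstractContext, varinfo, logp) =
    DynamicPPL.acclogp!!(varinfo, logp)

# A debugging context could record every manual accumulation (e.g. from
# hard-coded `@acclogprob!!` statements) before delegating to its child.
function acclogp!!(context::DebugContext, varinfo, logp)
    push!(context.recorded_logps, logp)  # hypothetical field on `DebugContext`
    return acclogp!!(DynamicPPL.childcontext(context), varinfo, logp)
end
```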
Btw @devmotion, I didn't see your comment because I was working on a train without internet 🙃
One thing that confuses me slightly is that, technically, even if the SMC sampler sees one of the parameters as an observation, AFAIK, it shouldn't break things, no? It might increase the variance, etc. but it should still converge.
To make this absolutely clear, let's look at PG from Andrieu (2010):

[Algorithm from Andrieu et al. (2010): Particle Gibbs alternates sampling $\theta$ given $X_{1:T}$ and $y_{1:T}$, and sampling $X_{1:T}$ via conditional SMC given $\theta$ and $y_{1:T}$.]

Looking at this, it might seem as if we shouldn't include the "prior" probability $p(m \mid s)$ of `gdemo`, since the weights in the SMC sweep only involve the likelihood terms.

But this is all done under the assumption that the joint nicely factorizes as

$$p(\theta, x_{1:T}, y_{1:T}) = p(\theta)\, p_{\theta}(x_{1:T})\, g_{\theta}(y_{1:T} \mid x_{1:T}).$$

In this scenario, the prior $p_{\theta}(x_{1:T})$ is exactly the proposal used in the SMC sweep, so it cancels in the importance weights, leaving only the likelihood terms $g_{\theta}$.

But in our scenario, the joint does not nicely factorize as above. We have

$$p(s, m, y_{1:T}) = p(s)\, p(m \mid s)\, p(y_{1:T} \mid m, s),$$

where the conditioned variable $m$ has a prior that depends on the variable $s$ we're targeting with CSMC.

To get a similar factorization as above, we can of course just define

$$\tilde{g}_{s}(m, y_{1:T}) := p(m \mid s)\, p(y_{1:T} \mid m, s),$$

where we can just drop the contribution of the prior $p(s)$, since it acts as the proposal.

With this notation, mapping this to the algorithm above is immediate, in which case it's clear that we're targeting the full joint, and so, in the general scenario, we need to include the log-prob contribution of latent variables that are not being sampled by conditional SMC.
And more generally speaking, PG is just a Gibbs sampler with:

- Using some method (in Andrieu (2010) they assume it's available in closed form): $$\theta(i) \sim p \big( \theta \mid y_{1:T}, X_{1:T} \big) \propto p(\theta, x_{1:T}, y_{1:T})$$
- Using CSMC: $$X_{1:T}(i) \sim p \big( X_{1:T} \mid y_{1:T}, \theta \big) \propto p(\theta, x_{1:T}, y_{1:T})$$
And the reason why we're not including the contribution from the variable we're targeting with CSMC is that we would then have to adjust the weights later by removing the prior contribution, since this is our importance weight; that is, our proposal for $X_{1:T}$ is its prior, so the prior cancels and the weights reduce to the likelihood terms.

Does that make sense @yebai?
And finally, whether we're resampling when hitting these conditioned-but-not-observed variables shouldn't affect correctness, only the variance.
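Spelling out the weight for the `gdemo` sweep over $s$ under the factorization above (my own arithmetic, not from the thread): with the prior $p(s)$ as the proposal, the importance weight is

$$w(s) = \frac{p(s)\, p(m \mid s)\, p(y_{1:T} \mid m, s)}{p(s)} = p(m \mid s)\, p(y_{1:T} \mid m, s),$$

so the conditioned variable's prior $p(m \mid s)$ has to enter the weights, which is exactly what `condition` (but not `fix`) gives us.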
This also explains why we need to make `acclogp!!` accumulate log-probs in the task-local `varinfo` and not the "global" one.
Tests are at least passing now, but I have some local changes I'm currently working on, so no merge-y yet please.
This is starting to take shape, @devmotion @yebai @sunxd3. After TuringLang/DynamicPPL.jl#587 there's very little reason to use the "current" Gibbs sampler over this experimental one: it provides much more flexibility, in addition to being much easier to debug, etc. I've also encountered issues with linking, etc. in the current impl of `Gibbs`, which then just become a pain to debug.
Tests are passing here now 👍 Think it's worth merging and making a release with this before we make a breaking release after #2197.
Thanks @torfjelde -- happy to merge this since it is fairly self-contained.

One thing to consider is to rename `experimental` to `from_future`, but that is just a fun thing to do.
This is an attempt at a new Gibbs sampler which makes use of the `condition` functionality from DynamicPPL.jl. This should, for the models which are compatible with `condition`, make for a much more flexible approach to Gibbs sampling.

The idea is basically as follows: instead of having a single `varinfo` which is shared between all the samplers, we instead have a different `varinfo` for every sampler involved in the composition. We then subsequently `condition` the `model` on the other `varinfo`s for each "inner" `step`.

The current implementation is somewhat rough, but it does seem to work!
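A rough sketch of the resulting sweep (illustrative pseudocode; `inner_step` is a hypothetical helper standing in for the component sampler's actual stepping function, and the `values_as`/`condition` calls are just how I'd express the idea, not the PR's actual code):

```julia
using DynamicPPL

# One Gibbs sweep: each component sampler has its own varinfo, and the model
# is conditioned on the values held by all the *other* varinfos before stepping.
function gibbs_sweep_sketch(rng, model, samplers, varinfos)
    for (i, sampler) in enumerate(samplers)
        # Collect the current values of every other component's variables.
        others = merge((DynamicPPL.values_as(vi, NamedTuple)
                        for (j, vi) in enumerate(varinfos) if j != i)...)
        # Condition the model on them, so this component only sees its own
        # variables as random.
        cmodel = DynamicPPL.condition(model; others...)
        # Advance this component using its own, private varinfo.
        varinfos[i] = inner_step(rng, cmodel, sampler, varinfos[i])  # hypothetical helper
    end
    return varinfos
end
```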
Open issues

- TODO: Issue with `MH`: breaks, but if we make it a filldist or something then it works.
- TODO: Linking isn't quite there (see last comment below).
- TODO