Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of step_warmup #117

Open
wants to merge 35 commits into
base: master
Choose a base branch
from
Open

Addition of step_warmup #117

wants to merge 35 commits into from

Conversation

torfjelde
Copy link
Member

For many samplers, it might be useful to separate between the warmup phase and the sampling phase, e.g. in AdvancedHMC we have an initial phase where we adapt the parameters to the parameters at hand.

Currently, the usual approach to implementing such a warmup stage is to keep track of the iteration + the adaptation stuff internally in the state, but sometimes that can be quite annoying and/or redundant to implement.

I would also argue that separating these is useful on a conceptual level, e.g. even if I interact directly with the stepper-interface, I would now do

for _ = 1:nwarmup
    state = last(AbstractMCMC.step_warmup(rng, model, sampler, state))
end

for i = 1:nsteps
    transition, state = AbstractMCMC.step(rng, model, sampler, state)
    # save
    ...
end

vs.

for i = 1:nwarmup + nsteps
    transition, state = AbstractMCMC.step(rng, model, sampler, state)
    # save
    if !iswarmup(state)
        ...
    end
end

With this PR, for something like MCMCTempering.jl where in the warmup phase we actually just want to take steps with the underlying sampler rather than also include the swaps, we can then just make the step_warmup do so without having to add any notion of iterations in the state, nor without telling the sampler itself about how many warm-up steps we want (instead it's just specified by kwarg in sample, as it should be).

src/interface.jl Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Mar 9, 2023

Codecov Report

Attention: 15 lines in your changes are missing coverage. Please review.

Comparison is base (dfb33b5) 96.87% compared to head (6e8f88e) 92.87%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #117      +/-   ##
==========================================
- Coverage   96.87%   92.87%   -4.00%     
==========================================
  Files           8        8              
  Lines         320      351      +31     
==========================================
+ Hits          310      326      +16     
- Misses         10       25      +15     
Files Coverage Δ
src/interface.jl 84.00% <0.00%> (-11.46%) ⬇️
src/sample.jl 90.99% <75.51%> (-4.89%) ⬇️

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, as the HMC code in Turing, this conflates warmup stages for the sampler with discarding initial samples. In the first case, (usually) you also want to discard these samples but you might even want to discard more samples even after tuning hyperparameters of a sampler.

I also wonder a bit whether AbstractMCMC is the right level for such an abstraction. Or whether, eg it could be done in AdvancedHMC.

@torfjelde
Copy link
Member Author

I think, as the HMC code in Turing, this conflates warmup stages for the sampler with discarding initial samples. In the first case, (usually) you also want to discard these samples but you might even want to discard more samples even after tuning hyperparameters of a sampler.

Oh I definitively agree with this. IMO this wouldn't be a catch-all solution, and you could still argue that in the case of HMC we should stick to the current approach.

But in many cases this sort of behavior would indeed be what the user expects (though I agree with you, I also don't want to remove the allowance of the current discard_initial behavior).

Is it worth introducing a new keyword argument then? Something that is separate from discard_initial, allowing you to define a "burn-in" period, and then a separate num_warmup that has this potentially special behavior?

I also wonder a bit whether AbstractMCMC is the right level for such an abstraction. Or whether, eg it could be done in AdvancedHMC.

I don't think we should deal with the adaptation, etc. itself in AbstractMCMC, but there are sooo many samplers that have some form of initial adaptation that it's IMO worth providing a simple hook that let's people do this.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@devmotion
Copy link
Member

Is it worth introducing a new keyword argument then? Something that is separate from discard_initial, allowing you to define a "burn-in" period, and then a separate num_warmup that has this potentially special behavior?

Yes, I think we should keep these options separate. I wonder if discard_initial should apply to both these warmup stages and the potential burn-in period to be able to keep warm-up samples as well, if desired. Or are we absolutely certain that you would never want to inspect these samples?

@torfjelde
Copy link
Member Author

Yes, I think we should keep these options separate. I wonder if discard_initial should apply to both these warmup stages and the potential burn-in period to be able to keep warm-up samples as well, if desired. Or are we absolutely certain that you would never want to inspect these samples?

No, I agree with you there; sometimes it's nice to keep them.

So are we thinking discard_initial = num_warmup with num_warmup = 0 by default?

@devmotion
Copy link
Member

Yes, I think these would be reasonable default values.

@torfjelde
Copy link
Member Author

torfjelde commented Mar 10, 2023

Doneso 👍

Nvm, I forgot we wanted to allow potentially keeping the warmup-samples around..

src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
@torfjelde
Copy link
Member Author

Aight, I've made an attempt at allowing the desired interaction between discard_initial and num_warmup, but it does complicate the mcmcsample a fair bit 😕

I've also added some docstring for mcmcsample. IMO this should be in the documentation as it specifies the default kwargs that will work with all implementers of step. Currently there's no way to figure out that discard_initial is a thing.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@devmotion
Copy link
Member

IMO this should be in the documentation

The standard keyword arguments are listed and explained in the documentation: https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments

src/sample.jl Outdated
Comment on lines 94 to 116
"""
mcmcsample(rng, model, sampler, N_or_is_done; kwargs...)

Default implementation of `sample` for a `model` and `sampler`.

# Arguments
- `rng::Random.AbstractRNG`: the random number generator to use.
- `model::AbstractModel`: the model to sample from.
- `sampler::AbstractSampler`: the sampler to use.
- `N::Integer`: the number of samples to draw.

# Keyword arguments
- `progress`: whether to display a progress bar. Defaults to `true`.
- `progressname`: the name of the progress bar. Defaults to `"Sampling"`.
- `callback`: a function that is called after each [`AbstractMCMC.step`](@ref).
Defaults to `nothing`.
- `num_warmup`: number of warmup samples to draw. Defaults to `0`.
- `discard_initial`: number of initial samples to discard. Defaults to `num_warmup`.
- `thinning`: number of samples to discard between samples. Defaults to `1`.
- `chain_type`: the type to pass to [`AbstractMCMC.bundle_samples`](@ref) at the
end of sampling to wrap up the resulting samples nicely. Defaults to `Any`.
- `kwargs...`: Additional keyword arguments to pass on to [`AbstractMCMC.step`](@ref).
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think nobody will look up the docstring for the unexported mcmcsample function, so it feels listing and explaining keyword arguments in https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments is the better approach? And possibly extending the docstring of sample?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aaah I was totally unaware!

So I removed this, and then I've just added a section to sample to tell people where to find docs on the default arguments. I personally rarely go to the docs of a package unless I "have" to, so I think it's at least nice to tell the user where to find the info. I'm even partial to putting the stuff about common keywords in the actual docstrings of sample but I'll leave as is for now.

src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated
thinning=1,
chain_type::Type=Any,
kwargs...,
)
# Check the number of requested samples.
N > 0 || error("the number of samples must be ≥ 1")
Ntotal = thinning * (N - 1) + discard_initial + 1
Ntotal = thinning * (N - 1) + discard_initial + num_warmup + 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? Shouldn't it just stay the same, possibly with some additional checks:

Suggested change
Ntotal = thinning * (N - 1) + discard_initial + num_warmup + 1
discard_initial >= 0 || throw(ArgumentError("number of discarded samples must be non-negative"))
num_warmup >= 0 || throw(ArgumentError("number of warm-up samples must be non-negative"))
Ntotal = thinning * (N - 1) + discard_initial + 1
Ntotal >= num_warmup || throw(ArgumentError("number of warm-up samples exceeds the total number of samples"))

I thought we would do the following:

  • If num_warmup = 0, we just do the same as currently: Sample discard_initial samples that are discarded + the N samples that are returned, possibly after thinning them.
  • If num_warmup > 0, we still return N samples but depending on discard_initial part of the N samples might be samples from the warm-up stage. For instance:
    • If num_warmup = 10, discard_initiial = 0, and N = 100, we would sample in total N samples and return them, whereof the first 10 are warm-up samples.
    • If num_warmup = 10, discard_initial = 10 (the default if you just specify num_warmup), and N = 100, then we would sample N + discard_initial = 110 samples in total and return the last N = 100 of them, so drop all warm-up samples.
    • If num_warmup = 10, discard_initial = 20, and N = 100, then we would sample N + discard_initial = 120 samples in total and return the last N = 100 of them, so we would drop all samples in the warm-up stage and the first 10 samples of the regular sampling.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry, yes this was left over from my initial implementation that treated discard_initial and num_warmup as "seperate phases".

I thought we would do the following:

Agreed, but isn't this what my impl is currently doing? With the exception of this line above of course.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know, I stopped reviewing at this point and didn't check the rest 😄

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha, aight. Well, the rest is supposed to implement exactly what you outlined 😅

I'll see if I can also add some tests.

torfjelde and others added 2 commits March 10, 2023 09:03
Co-authored-by: David Widmann <[email protected]>
src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
src/interface.jl Outdated Show resolved Hide resolved
src/interface.jl Outdated Show resolved Hide resolved
docs/src/design.md Outdated Show resolved Hide resolved
src/interface.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated
Comment on lines 155 to 158
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not needed? At least it's not present on the master branch, it seems.

Suggested change
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not present on the master branch because there we do progress-reporting in the beginning of the for-loop for the discarded samples, and then a final progress-report before we start the "main sampling loop".

Here I've just moved the "final progress-report before we start the main sampling loop" to before all of the loops, and then moved the progres-reporting to the end of the for-loop for the discarded samples (well, I see I have forgotten to move the progress-reporting inside the discarded samples for-loop 😬 Will fix now)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You happy with this @devmotion ?

src/sample.jl Outdated
next_update = itotal + threshold
# Step through remainder of warmup iterations and save.
i += 1
for _ in 1:keep_from_warmup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use the same approach here as well? And only add a if i <= num_warmup around the current step calls?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do that 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, actually one thing: should we also do thinning for num_warmup then? 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are different reasonable approaches but we could just keep doing what we've done so far: we discard discard_initial samples and then return the next N samples, thinned by a factor thinning. That is, the returned samples are samples in iteration discard_initial + 1, discard_initial + thinning + 1, discard_initial + 2 * thinning + 1 etc.

src/sample.jl Outdated Show resolved Hide resolved
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde
Copy link
Member Author

Would you mind having another look @devmotion ?:)

@torfjelde
Copy link
Member Author

Also, should I bump the minor version (currently bumped the patch version, which seems incorrect)?

Copy link
Member

@devmotion devmotion left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems tests are failing?

next_update = i + threshold
for j in 1:discard_initial
# Obtain the next sample and state.
sample, state = if j ≤ num_warmup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be

Suggested change
sample, state = if j num_warmup
sample, state = if j discard_num_warmup

shouldn't it?

Maybe it could even be split into two sequential for loops?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be

I think technically it doesn't matter, right? Since we have either

  1. discard_num_warmup = discard_initial when num_warmup >= discard_initial, or
  2. discard_num_warmup = num_warmup when num_warmup < discard_initial.

In both of those cases we get the same behavior in the above.

But I think for readability's sake, I agree we should make the change! Just pointing out it shouldn't been a cause of a bug.

Maybe it could even be split into two sequential for loops?

Wait what, wasn't that what I had before? 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait what, wasn't that what I had before? confused

Really? I think you used a different logic initially but maybe I misremember 😄 In any case, I guess it does not matter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean

for i = 1:discard_num_warmup
    # ...
end

for i = discard_num_warmup + 1:discard_initial
    # ...
end

?

Because you're probably right, I don't think I ever did this exactly 😬

I'm preferential to the current code for readability's sake because it means the discard stepping is looks the same as the proper stepping, code-wise.

src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated

# Step through the sampler.
for i in 2:N
while i N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason to switch to a while loop here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no! I'll revert it to for-loop 👍

src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved

while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...)
# Discard thinned samples.
for _ in 1:(thinning - 1)
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't i be incremented at the top of the loop? Before it was 2 here, now it is 1.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, nice catch!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just reverted the initialization of i to 2.

src/sample.jl Outdated Show resolved Hide resolved
src/sample.jl Outdated Show resolved Hide resolved
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
src/sample.jl Outdated Show resolved Hide resolved
torfjelde and others added 2 commits April 19, 2023 09:07
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@devmotion
Copy link
Member

@torfjelde Can you fix the merge conflict?

@torfjelde
Copy link
Member Author

Should be good to go now @devmotion

@torfjelde
Copy link
Member Author

But we're having issues with nightly on x86 so these two last ones won't finish.

Comment on lines +224 to +226
# Increment iteration counter.
i += 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is handled by the for i in 2:N above:

Suggested change
# Increment iteration counter.
i += 1

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if j ≤ discard_from_warmup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit inconsistent, you used

Suggested change
sample, state = if j discard_from_warmup
sample, state = if j num_warmup

above. Doesn't matter (both will work) but I think it would be good to be consistent.

thinning=1,
initial_state=nothing,
kwargs...,
)
# Determine how many samples to drop from `num_warmup` and the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the same/similar error checks as above?

end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could merge this with the for-loop above AFAICT?

@torfjelde
Copy link
Member Author

Ran into another place where this will be useful: https://github.com/torfjelde/AutomaticMALA.jl. Avoids having to put num_adapts in the sampler itself, and can be replaced with num_warmup as a kwarg.

(I need to do some more work on this PR; had sort of forgotten about this)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants