Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of step_warmup #117

Open
wants to merge 35 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0987f5f
added step_warmup which is can be overloaded when convenient
torfjelde Mar 9, 2023
30c9f12
added step_warmup to docs
torfjelde Mar 9, 2023
7faa73f
Update src/interface.jl
torfjelde Mar 9, 2023
bd0bdc7
introduce new kwarg `num_warmup` to `sample` which uses `step_warmup`
torfjelde Mar 10, 2023
c620cca
updated docs
torfjelde Mar 10, 2023
572a286
allow combination of discard_initial and num_warmup
torfjelde Mar 10, 2023
6b842ee
added docstring for mcmcsample
torfjelde Mar 10, 2023
ca03832
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
0441773
Apply suggestions from code review
torfjelde Mar 10, 2023
ea369ff
Apply suggestions from code review
torfjelde Mar 10, 2023
8e0ca53
Update src/sample.jl
torfjelde Mar 10, 2023
6877978
removed docstring and deferred description of keyword arguments to th…
torfjelde Mar 10, 2023
b3b3148
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ddc5254
Update src/sample.jl
torfjelde Mar 10, 2023
ffbd32f
Update src/sample.jl
torfjelde Mar 10, 2023
87480ff
added num_warmup to common keyword arguments docs
torfjelde Mar 10, 2023
76f2f23
also allow step_warmup for the initial step
torfjelde Mar 10, 2023
c00d0c9
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ef09c19
simplify logic for discarding fffinitial samples
torfjelde Mar 10, 2023
49b8406
Apply suggestions from code review
torfjelde Mar 10, 2023
f005746
also report progress for the discarded samples
torfjelde Mar 10, 2023
9dccd8a
Merge branch 'torfjelde/step-warmup' of github.com:TuringLang/Abstrac…
torfjelde Mar 10, 2023
ff00e6e
Apply suggestions from code review
torfjelde Mar 10, 2023
7ce9f6b
move progress-report to end of for-loop for discard samples
torfjelde Mar 10, 2023
3a217b2
move step_warmup to the inner while loops too
torfjelde Mar 13, 2023
de9bb2c
Update src/sample.jl
torfjelde Mar 13, 2023
85d938f
Apply suggestions from code review
torfjelde Apr 19, 2023
0a667a4
reverted to for-loop
torfjelde Apr 19, 2023
91f5a10
Update src/sample.jl
torfjelde Apr 19, 2023
7603171
added accidentanly removed comment
torfjelde Apr 19, 2023
ef68d04
Update src/sample.jl
torfjelde Apr 19, 2023
25afc66
Merge branch 'master' into torfjelde/step-warmup
torfjelde Oct 24, 2023
1886fa8
Merge branch 'master' into torfjelde/step-warmup
torfjelde Oct 25, 2023
0ea293a
fixed formatting
torfjelde Oct 26, 2023
6e8f88e
fix typo
torfjelde Oct 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.4.0"
version = "4.4.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
6 changes: 5 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ Common keyword arguments for regular and parallel sampling are:
- `callback` (default: `nothing`): if `callback !== nothing`, then
`callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
- `discard_initial` (default: `0`): number of initial samples that are discarded
- `num_warmup` (default: `0`): number of "warm-up" steps to take before the first "regular" step,
i.e. number of times to call [`AbstractMCMC.step_warmup`](@ref) before the first call to
[`AbstractMCMC.step`](@ref).
- `discard_initial` (default: `num_warmup`): number of initial samples that are discarded. Note that
if `discard_initial < num_warmup`, warm-up samples will also be included in the resulting samples.
- `thinning` (default: `1`): factor by which to thin samples.

!!! info
Expand Down
9 changes: 9 additions & 0 deletions docs/src/design.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ the sampling step of the inference method.
AbstractMCMC.step
```

If one also has some special handling of the warmup-stage of sampling, then this can be specified by overloading

```@docs
AbstractMCMC.step_warmup
```

which will be used for the first `num_warmup` iterations, as specified as a keyword argument to [`AbstractMCMC.sample`](@ref).
Note that this is optional; by default it simply calls [`AbstractMCMC.step`](@ref) from above.

## Collecting samples

!!! note
Expand Down
15 changes: 15 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ current `state` of the sampler.
"""
function step end

"""
step_warmup(rng, model, sampler[, state; kwargs...])

Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`.

When sampling using [`sample`](@ref), this takes the place of [`AbstractMCMC.step`](@ref) in the first
`num_warmup` number of iterations, as specified by the `num_warmup` keyword to [`sample`](@ref).
This is useful if the sampler has an initial "warmup"-stage that is different from the
standard iteration.

By default, this simply calls [`AbstractMCMC.step`](@ref).
"""
step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...)
step_warmup(rng, model, sampler, state; kwargs...) = step(rng, model, sampler, state; kwargs...)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

"""
samples(sample, model, sampler[, N; kwargs...])

Expand Down
134 changes: 102 additions & 32 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ isdone(rng, model, sampler, samples, state, iteration; kwargs...)
```
where `state` and `iteration` are the current state and iteration of the sampler, respectively.
It should return `true` when sampling should end, and `false` otherwise.

# Keyword arguments

See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword
arguments.
"""
function StatsBase.sample(
rng::Random.AbstractRNG,
Expand Down Expand Up @@ -77,6 +82,11 @@ end

Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel
using the `parallel` algorithm, and combine them into a single chain.

# Keyword arguments

See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword
arguments.
"""
function StatsBase.sample(
rng::Random.AbstractRNG,
Expand All @@ -91,7 +101,6 @@ function StatsBase.sample(
end

# Default implementations of regular and parallel sampling.

function mcmcsample(
rng::Random.AbstractRNG,
model::AbstractModel,
Expand All @@ -100,14 +109,27 @@ function mcmcsample(
progress=PROGRESS[],
progressname="Sampling",
callback=nothing,
discard_initial=0,
num_warmup::Int=0,
discard_initial::Int=num_warmup,
thinning=1,
chain_type::Type=Any,
kwargs...,
)
# Check the number of requested samples.
N > 0 || error("the number of samples must be ≥ 1")
discard_initial >= 0 ||
throw(ArgumentError("number of discarded samples must be non-negative"))
num_warmup >= 0 ||
throw(ArgumentError("number of warm-up samples must be non-negative"))
Ntotal = thinning * (N - 1) + discard_initial + 1
Ntotal >= num_warmup || throw(
ArgumentError("number of warm-up samples exceeds the total number of samples")
)

# Determine how many samples to drop from `num_warmup` and the
# main sampling process before we start saving samples.
discard_from_warmup = min(num_warmup, discard_initial)
keep_from_warmup = num_warmup - discard_from_warmup

# Start the timer
start = time()
Expand All @@ -122,40 +144,58 @@ function mcmcsample(
end

# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)
sample, state = if num_warmup > 0
step_warmup(rng, model, sampler; kwargs...)
else
step(rng, model, sampler; kwargs...)
end

# Update the progress bar.
itotal = 1
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not needed? At least it's not present on the master branch, it seems.

Suggested change
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not present on the master branch because there we do progress-reporting in the beginning of the for-loop for the discarded samples, and then a final progress-report before we start the "main sampling loop".

Here I've just moved the "final progress-report before we start the main sampling loop" to before all of the loops, and then moved the progres-reporting to the end of the for-loop for the discarded samples (well, I see I have forgotten to move the progress-reporting inside the discarded samples for-loop 😬 Will fix now)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You happy with this @devmotion ?


# Discard initial samples.
for i in 1:discard_initial
# Update the progress bar.
if progress && i >= next_update
ProgressLogging.@logprogress i / Ntotal
next_update = i + threshold
for j in 1:discard_initial
# Obtain the next sample and state.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these should be accounted for in the progress logger as well (as done currently).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be good now 👍

sample, state = if j ≤ num_warmup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be

Suggested change
sample, state = if j num_warmup
sample, state = if j discard_num_warmup

shouldn't it?

Maybe it could even be split into two sequential for loops?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be

I think technically it doesn't matter, right? Since we have either

  1. discard_num_warmup = discard_initial when num_warmup >= discard_initial, or
  2. discard_num_warmup = num_warmup when num_warmup < discard_initial.

In both of those cases we get the same behavior in the above.

But I think for readability's sake, I agree we should make the change! Just pointing out it shouldn't been a cause of a bug.

Maybe it could even be split into two sequential for loops?

Wait what, wasn't that what I had before? 😕

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait what, wasn't that what I had before? confused

Really? I think you used a different logic initially but maybe I misremember 😄 In any case, I guess it does not matter.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean

for i = 1:discard_num_warmup
    # ...
end

for i = discard_num_warmup + 1:discard_initial
    # ...
end

?

Because you're probably right, I don't think I ever did this exactly 😬

I'm preferential to the current code for readability's sake because it means the discard stepping is looks the same as the proper stepping, code-wise.

step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
# Update the progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end
end

# Initialize iteration counter.
i = 1

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...)
callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

# Update the progress bar.
itotal = 1 + discard_initial
if progress && itotal >= next_update
ProgressLogging.@logprogress itotal / Ntotal
next_update = itotal + threshold
end
# Step through remainder of warmup iterations and save.
i += 1

# Step through the sampler.
for i in 2:N
while i N
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason to switch to a while loop here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no! I'll revert it to for-loop 👍

# Discard thinned samples.
for _ in 1:(thinning - 1)
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Update progress bar.
if progress && (itotal += 1) >= next_update
Expand All @@ -165,7 +205,11 @@ function mcmcsample(
end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing ||
Expand All @@ -174,6 +218,9 @@ function mcmcsample(
# Save the sample.
samples = save!!(samples, sample, i, model, sampler, N; kwargs...)

# Increment iteration counter.
i += 1

Comment on lines +224 to +226
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is handled by the for i in 2:N above:

Suggested change
# Increment iteration counter.
i += 1

# Update the progress bar.
if progress && (itotal += 1) >= next_update
ProgressLogging.@logprogress itotal / Ntotal
Expand Down Expand Up @@ -209,44 +256,67 @@ function mcmcsample(
progress=PROGRESS[],
progressname="Convergence sampling",
callback=nothing,
discard_initial=0,
num_warmup=0,
discard_initial=num_warmup,
thinning=1,
kwargs...,
)
# Determine how many samples to drop from `num_warmup` and the
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the same/similar error checks as above?

# main sampling process before we start saving samples.
discard_from_warmup = min(num_warmup, discard_initial)
keep_from_warmup = num_warmup - discard_from_warmup

# Start the timer
start = time()
local state

@ifwithprogresslogger progress name = progressname begin
# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)
sample, state = if num_warmup > 0
step_warmup(rng, model, sampler; kwargs...)
else
step(rng, model, sampler; kwargs...)
end

# Discard initial samples.
for _ in 1:discard_initial
for j in 1:discard_initial
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if j ≤ discard_num_warmup
step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end
end

# Initialize iteration counter.
i = 1

# Run callback.
callback === nothing || callback(rng, model, sampler, sample, state, 1; kwargs...)
callback === nothing || callback(rng, model, sampler, sample, state, i; kwargs...)

# Save the sample.
samples = AbstractMCMC.samples(sample, model, sampler; kwargs...)
samples = save!!(samples, sample, 1, model, sampler; kwargs...)

# Step through the sampler until stopping.
i = 2
samples = save!!(samples, sample, i, model, sampler; kwargs...)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

torfjelde marked this conversation as resolved.
Show resolved Hide resolved
i += 1

while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...)
# Discard thinned samples.
for _ in 1:(thinning - 1)
# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't i be incremented at the top of the loop? Before it was 2 here, now it is 1.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, nice catch!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just reverted the initialization of i to 2.

step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end
end

# Obtain the next sample and state.
sample, state = step(rng, model, sampler, state; kwargs...)
sample, state = if i ≤ keep_from_warmup
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could merge this with the for-loop above AFAICT?

step_warmup(rng, model, sampler, state; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end

# Run callback.
callback === nothing ||
Expand Down