diff --git a/Project.toml b/Project.toml
index 60215cec..f57b1ff0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
 keywords = ["markov chain monte carlo", "probabilistic programming"]
 license = "MIT"
 desc = "A lightweight interface for common MCMC methods."
-version = "5.3.0"
+version = "5.4.0"

 [deps]
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
diff --git a/docs/src/api.md b/docs/src/api.md
index aabf8d6f..f6dd1cf8 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -71,9 +71,13 @@ Common keyword arguments for regular and parallel sampling are:
 - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
 - `chain_type` (default: `Any`): determines the type of the returned chain
 - `callback` (default: `nothing`): if `callback !== nothing`, then
-  `callback(rng, model, sampler, sample, state, iteration)` is called after every sampling step,
-  where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler
-- `discard_initial` (default: `0`): number of initial samples that are discarded
+  `callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
+  where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
+- `num_warmup` (default: `0`): number of "warm-up" steps to take before the first "regular" step,
+  i.e. number of times to call [`AbstractMCMC.step_warmup`](@ref) before the first call to
+  [`AbstractMCMC.step`](@ref).
+- `discard_initial` (default: `num_warmup`): number of initial samples that are discarded. Note that
+  if `discard_initial < num_warmup`, warm-up samples will also be included in the resulting samples.
 - `thinning` (default: `1`): factor by which to thin samples.
 - `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to
   [`AbstractMCMC.step`](@ref) is passed `initial_state` as the `state` argument.
diff --git a/docs/src/design.md b/docs/src/design.md
index 0cc524a3..f5becb45 100644
--- a/docs/src/design.md
+++ b/docs/src/design.md
@@ -63,6 +63,15 @@ the sampling step of the inference method.
 AbstractMCMC.step
 ```

+If the sampler has special behavior for the warm-up stage of sampling, this can be specified by overloading
+
+```@docs
+AbstractMCMC.step_warmup
+```
+
+which is used for the first `num_warmup` iterations, where `num_warmup` is specified as a keyword argument to [`AbstractMCMC.sample`](@ref).
+Note that overloading it is optional; by default it simply calls [`AbstractMCMC.step`](@ref) from above.
+
 ## Collecting samples

 !!! note
diff --git a/src/interface.jl b/src/interface.jl
index 928a933d..b58ced99 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -73,6 +73,23 @@ current `state` of the sampler.
 """
 function step end

+"""
+    step_warmup(rng, model, sampler[, state; kwargs...])
+
+Return a 2-tuple of the next sample and the next state of the MCMC `sampler` for `model`.
+
+When sampling using [`sample`](@ref), this takes the place of [`AbstractMCMC.step`](@ref) for the first
+`num_warmup` iterations, as specified by the `num_warmup` keyword argument to [`sample`](@ref).
+This is useful if the sampler has an initial warm-up stage that is different from the
+standard iterations.
+
+By default, this simply calls [`AbstractMCMC.step`](@ref).
+"""
+step_warmup(rng, model, sampler; kwargs...) = step(rng, model, sampler; kwargs...)
+function step_warmup(rng, model, sampler, state; kwargs...)
+    return step(rng, model, sampler, state; kwargs...)
+end
+
 """
     samples(sample, model, sampler[, N; kwargs...])

diff --git a/src/sample.jl b/src/sample.jl
index d60a0739..6e21f180 100644
--- a/src/sample.jl
+++ b/src/sample.jl
@@ -43,6 +43,11 @@ isdone(rng, model, sampler, samples, state, iteration; kwargs...)
 ```
 where `state` and `iteration` are the current state and iteration of the sampler, respectively.
 It should return `true` when sampling should end, and `false` otherwise.
+
+# Keyword arguments
+
+See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword
+arguments.
 """
 function StatsBase.sample(
     rng::Random.AbstractRNG,
@@ -80,6 +85,11 @@ end

 Sample `nchains` Monte Carlo Markov chains from the `model` with the `sampler` in parallel
 using the `parallel` algorithm, and combine them into a single chain.
+
+# Keyword arguments
+
+See https://turinglang.org/AbstractMCMC.jl/dev/api/#Common-keyword-arguments for common keyword
+arguments.
 """
 function StatsBase.sample(
     rng::Random.AbstractRNG,
@@ -94,7 +104,6 @@ function StatsBase.sample(
 end

 # Default implementations of regular and parallel sampling.
-
 function mcmcsample(
     rng::Random.AbstractRNG,
     model::AbstractModel,
@@ -103,7 +112,8 @@
     progress=PROGRESS[],
     progressname="Sampling",
     callback=nothing,
-    discard_initial=0,
+    num_warmup::Int=0,
+    discard_initial::Int=num_warmup,
     thinning=1,
     chain_type::Type=Any,
     initial_state=nothing,
@@ -111,7 +121,19 @@
 )
     # Check the number of requested samples.
     N > 0 || error("the number of samples must be ≥ 1")
+    discard_initial >= 0 ||
+        throw(ArgumentError("number of discarded samples must be non-negative"))
+    num_warmup >= 0 ||
+        throw(ArgumentError("number of warm-up samples must be non-negative"))
     Ntotal = thinning * (N - 1) + discard_initial + 1
+    Ntotal >= num_warmup || throw(
+        ArgumentError("number of warm-up samples exceeds the total number of samples")
+    )
+
+    # Determine how many samples to drop from `num_warmup` and the
+    # main sampling process before we start saving samples.
+    discard_from_warmup = min(num_warmup, discard_initial)
+    keep_from_warmup = num_warmup - discard_from_warmup

     # Start the timer
     start = time()
@@ -126,22 +148,41 @@ function mcmcsample(
     end

     # Obtain the initial sample and state.
-    sample, state = if initial_state === nothing
-        step(rng, model, sampler; kwargs...)
+    sample, state = if num_warmup > 0
+        if initial_state === nothing
+            step_warmup(rng, model, sampler; kwargs...)
+        else
+            step_warmup(rng, model, sampler, initial_state; kwargs...)
+        end
     else
-        step(rng, model, sampler, initial_state; kwargs...)
+        if initial_state === nothing
+            step(rng, model, sampler; kwargs...)
+        else
+            step(rng, model, sampler, initial_state; kwargs...)
+        end
+    end
+
+    # Update the progress bar.
+    itotal = 1
+    if progress && itotal >= next_update
+        ProgressLogging.@logprogress itotal / Ntotal
+        next_update = itotal + threshold
     end

     # Discard initial samples.
-    for i in 1:discard_initial
-        # Update the progress bar.
-        if progress && i >= next_update
-            ProgressLogging.@logprogress i / Ntotal
-            next_update = i + threshold
+    for j in 1:discard_initial
+        # Obtain the next sample and state.
+        sample, state = if j ≤ num_warmup
+            step_warmup(rng, model, sampler, state; kwargs...)
+        else
+            step(rng, model, sampler, state; kwargs...)
         end

-        # Obtain the next sample and state.
-        sample, state = step(rng, model, sampler, state; kwargs...)
+        # Update the progress bar.
+        if progress && (itotal += 1) >= next_update
+            ProgressLogging.@logprogress itotal / Ntotal
+            next_update = itotal + threshold
+        end
     end

     # Run callback.
@@ -151,19 +192,16 @@ function mcmcsample(
     samples = AbstractMCMC.samples(sample, model, sampler, N; kwargs...)
     samples = save!!(samples, sample, 1, model, sampler, N; kwargs...)

-    # Update the progress bar.
-    itotal = 1 + discard_initial
-    if progress && itotal >= next_update
-        ProgressLogging.@logprogress itotal / Ntotal
-        next_update = itotal + threshold
-    end
-
     # Step through the sampler.
     for i in 2:N
         # Discard thinned samples.
         for _ in 1:(thinning - 1)
             # Obtain the next sample and state.
-            sample, state = step(rng, model, sampler, state; kwargs...)
+            sample, state = if i ≤ keep_from_warmup
+                step_warmup(rng, model, sampler, state; kwargs...)
+            else
+                step(rng, model, sampler, state; kwargs...)
+            end

             # Update progress bar.
             if progress && (itotal += 1) >= next_update
@@ -173,7 +211,11 @@
         end

         # Obtain the next sample and state.
-        sample, state = step(rng, model, sampler, state; kwargs...)
+        sample, state = if i ≤ keep_from_warmup
+            step_warmup(rng, model, sampler, state; kwargs...)
+        else
+            step(rng, model, sampler, state; kwargs...)
+        end

         # Run callback.
         callback === nothing ||
@@ -217,11 +259,22 @@
     progress=PROGRESS[],
     progressname="Convergence sampling",
     callback=nothing,
-    discard_initial=0,
+    num_warmup=0,
+    discard_initial=num_warmup,
     thinning=1,
     initial_state=nothing,
     kwargs...,
 )
+    # Check the number of requested samples.
+    discard_initial >= 0 ||
+        throw(ArgumentError("number of discarded samples must be non-negative"))
+    num_warmup >= 0 ||
+        throw(ArgumentError("number of warm-up samples must be non-negative"))
+
+    # Determine how many samples to drop from `num_warmup` and the
+    # main sampling process before we start saving samples.
+    discard_from_warmup = min(num_warmup, discard_initial)
+    keep_from_warmup = num_warmup - discard_from_warmup

     # Start the timer
     start = time()
@@ -229,16 +282,28 @@

     @ifwithprogresslogger progress name = progressname begin
         # Obtain the initial sample and state.
-        sample, state = if initial_state === nothing
-            step(rng, model, sampler; kwargs...)
+        sample, state = if num_warmup > 0
+            if initial_state === nothing
+                step_warmup(rng, model, sampler; kwargs...)
+            else
+                step_warmup(rng, model, sampler, initial_state; kwargs...)
+            end
         else
-            step(rng, model, sampler, state; kwargs...)
+            if initial_state === nothing
+                step(rng, model, sampler; kwargs...)
+            else
+                step(rng, model, sampler, initial_state; kwargs...)
+            end
         end

         # Discard initial samples.
-        for _ in 1:discard_initial
+        for j in 1:discard_initial
             # Obtain the next sample and state.
-            sample, state = step(rng, model, sampler, state; kwargs...)
+            sample, state = if j ≤ num_warmup
+                step_warmup(rng, model, sampler, state; kwargs...)
+            else
+                step(rng, model, sampler, state; kwargs...)
+            end
         end

         # Run callback.
@@ -250,16 +315,23 @@

         # Step through the sampler until stopping.
         i = 2
-
         while !isdone(rng, model, sampler, samples, state, i; progress=progress, kwargs...)
             # Discard thinned samples.
             for _ in 1:(thinning - 1)
                 # Obtain the next sample and state.
-                sample, state = step(rng, model, sampler, state; kwargs...)
+                sample, state = if i ≤ keep_from_warmup
+                    step_warmup(rng, model, sampler, state; kwargs...)
+                else
+                    step(rng, model, sampler, state; kwargs...)
+                end
             end

             # Obtain the next sample and state.
-            sample, state = step(rng, model, sampler, state; kwargs...)
+            sample, state = if i ≤ keep_from_warmup
+                step_warmup(rng, model, sampler, state; kwargs...)
+            else
+                step(rng, model, sampler, state; kwargs...)
+            end

             # Run callback.
             callback === nothing ||
diff --git a/test/sample.jl b/test/sample.jl
index dcc87526..7599bd79 100644
--- a/test/sample.jl
+++ b/test/sample.jl
@@ -575,6 +575,45 @@
         @test all(chain[i].b == ref_chain[i + discard_initial].b for i in 1:N)
     end

+    @testset "Warm-up steps" begin
+        # Create a chain with warm-up steps.
+        Random.seed!(1234)
+        N = 100
+        num_warmup = 50
+
+        # By default, all warm-up samples are discarded.
+        chain = sample(MyModel(), MySampler(), N; num_warmup=num_warmup)
+        @test length(chain) == N
+        @test !ismissing(chain[1].a)
+
+        # Repeat sampling without discarding initial samples.
+        # On Julia < 1.6 progress logging changes the global RNG and hence is enabled here.
+        # https://github.com/TuringLang/AbstractMCMC.jl/pull/102#issuecomment-1142253258
+        Random.seed!(1234)
+        ref_chain = sample(
+            MyModel(), MySampler(), N + num_warmup; progress=VERSION < v"1.6"
+        )
+        @test all(chain[i].a == ref_chain[i + num_warmup].a for i in 1:N)
+        @test all(chain[i].b == ref_chain[i + num_warmup].b for i in 1:N)
+
+        # Keep some of the warm-up samples by discarding fewer than `num_warmup`.
+        Random.seed!(1234)
+        discard_initial = 10
+        chain_warmup = sample(
+            MyModel(),
+            MySampler(),
+            N;
+            num_warmup=num_warmup,
+            discard_initial=discard_initial,
+        )
+        @test length(chain_warmup) == N
+        @test all(chain_warmup[i].a == ref_chain[i + discard_initial].a for i in 1:N)
+        # Check that the first `num_warmup - discard_initial` samples are warm-up samples.
+        @test all(
+            chain_warmup[i].is_warmup == (i <= num_warmup - discard_initial) for i in 1:N
+        )
+    end
+
     @testset "Thin chain by a factor of `thinning`" begin
         # Run a thinned chain with `N` samples thinned by factor of `thinning`.
         Random.seed!(100)
diff --git a/test/utils.jl b/test/utils.jl
index 1e29a473..b041b3a7 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -3,8 +3,11 @@ struct MyModel <: AbstractMCMC.AbstractModel end

 struct MySample{A,B}
     a::A
     b::B
+    is_warmup::Bool
 end

+MySample(a, b) = MySample(a, b, false)
+
 struct MySampler <: AbstractMCMC.AbstractSampler end
 struct AnotherSampler <: AbstractMCMC.AbstractSampler end
@@ -16,6 +19,21 @@ end

 MyChain(a, b) = MyChain(a, b, NamedTuple())

+function AbstractMCMC.step_warmup(
+    rng::AbstractRNG,
+    model::MyModel,
+    sampler::MySampler,
+    state::Union{Nothing,Integer}=nothing;
+    loggers=false,
+    initial_params=nothing,
+    kwargs...,
+)
+    transition, state = AbstractMCMC.step(
+        rng, model, sampler, state; loggers, initial_params, kwargs...
+    )
+    return MySample(transition.a, transition.b, true), state
+end
+
 function AbstractMCMC.step(
     rng::AbstractRNG,
     model::MyModel,
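Taken together, these changes let a downstream sampler opt into a distinct warm-up phase by overloading `AbstractMCMC.step_warmup`, while callers control the number of warm-up iterations via the `num_warmup` keyword. The following is a minimal sketch of the intended usage, not part of this diff: `AdaptiveSampler` is a hypothetical type invented for illustration, and the adaptation logic is elided; only `step_warmup`, `step`, `num_warmup`, and `discard_initial` come from the changes above.

```julia
using AbstractMCMC
using Random

# Hypothetical sampler type; any AbstractSampler works the same way.
struct AdaptiveSampler <: AbstractMCMC.AbstractSampler end

# Opt into warm-up handling: this method is called instead of
# `AbstractMCMC.step` for the first `num_warmup` iterations. A real sampler
# would adapt tuning parameters (step sizes, a mass matrix, ...) here; this
# sketch just delegates to the regular step, which matches the default.
function AbstractMCMC.step_warmup(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.AbstractModel,
    sampler::AdaptiveSampler,
    state;
    kwargs...,
)
    return AbstractMCMC.step(rng, model, sampler, state; kwargs...)
end

# Callers request warm-up via the new keyword:
#   sample(model, AdaptiveSampler(), 1000; num_warmup=100)
# drops the 100 warm-up samples, since `discard_initial` defaults to
# `num_warmup`, while
#   sample(model, AdaptiveSampler(), 1000; num_warmup=100, discard_initial=0)
# keeps them in the returned chain.
```

Defaulting `discard_initial` to `num_warmup` preserves the previous behavior for existing callers: with `num_warmup = 0` (the default), `discard_initial` is `0`, exactly as before this patch.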