From 34eecab05a794921dbb80b716fb959bc7c528291 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Mar 2023 22:17:06 +0000 Subject: [PATCH 01/17] added initial_state as a kwarg --- docs/src/api.md | 2 ++ src/sample.jl | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index f0c2c158..de956535 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -75,6 +75,8 @@ Common keyword arguments for regular and parallel sampling are: where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. +- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref) + is passed `initial_state` as the `state` argument. !!! info The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). diff --git a/src/sample.jl b/src/sample.jl index 6c9c32ae..d08e449d 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -103,6 +103,7 @@ function mcmcsample( discard_initial=0, thinning=1, chain_type::Type=Any, + initial_state=nothing, kwargs..., ) # Check the number of requested samples. @@ -122,7 +123,11 @@ function mcmcsample( end # Obtain the initial sample and state. - sample, state = step(rng, model, sampler; kwargs...) + sample, state = if initial_state === nothing + step(rng, model, sampler; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end # Discard initial samples. for i in 1:discard_initial @@ -211,6 +216,7 @@ function mcmcsample( callback=nothing, discard_initial=0, thinning=1, + initial_state=nothing, kwargs..., ) @@ -220,7 +226,11 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. - sample, state = step(rng, model, sampler; kwargs...) + sample, state = if initial_state === nothing + step(rng, model, sampler; kwargs...) + else + step(rng, model, sampler, state; kwargs...) + end # Discard initial samples. for _ in 1:discard_initial From fbd24ca5827c268465f9963ca5ac7c09b13a5002 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Oct 2023 11:38:29 +0100 Subject: [PATCH 02/17] added support for initial_state kwarg in threaded sample --- src/sample.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/sample.jl b/src/sample.jl index d08e449d..1110194d 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -299,6 +299,7 @@ function mcmcsample( progress=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", init_params=nothing, + initial_state=nothing, kwargs..., ) # Check if actually multiple threads are used. @@ -379,6 +380,11 @@ function mcmcsample( else init_params[chainidx] end, + initial_state=if initial_state === nothing + nothing + else + initial_state[chainidx] + end kwargs..., ) @@ -408,6 +414,7 @@ function mcmcsample( progress=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) processes)", init_params=nothing, + initial_state=nothing, # TODO: Add support for this here. kwargs..., ) # Check if actually multiple processes are used. From e9bee369f8520d64c602b6625a13a08d39168d7a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Mar 2023 22:29:37 +0000 Subject: [PATCH 03/17] added support for initial_state in distributed sample --- src/sample.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 1110194d..f054598c 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -414,7 +414,7 @@ function mcmcsample( progress=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) processes)", init_params=nothing, - initial_state=nothing, # TODO: Add support for this here. + initial_state=nothing, kwargs..., ) # Check if actually multiple processes are used. @@ -433,6 +433,10 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) + # Create initial parameters. + _init_params = init_params === nothing ? fill(nothing, nchains) : init_params + _initial_state = initial_state === nothing ? fill(nothing, nchains) : initial_state + # Set up worker pool. pool = Distributed.CachingPool(Distributed.workers()) @@ -465,7 +469,7 @@ function mcmcsample( Distributed.@async begin try - function sample_chain(seed, init_params=nothing) + function sample_chain(seed, init_params, initial_state) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -477,6 +481,7 @@ function mcmcsample( N; progress=false, init_params=init_params, + initial_state=initial_state, kwargs..., ) @@ -486,11 +491,7 @@ function mcmcsample( # Return the new chain. return chain end - chains = if init_params === nothing - Distributed.pmap(sample_chain, pool, seeds) - else - Distributed.pmap(sample_chain, pool, seeds, init_params) - end + chains = Distributed.pmap(sample_chain, pool, seeds, _init_params, _initial_state) finally # Stop updating the progress bar. progress && put!(channel, false) From 207c77d0b7ee04fe956a770f764a46626ea20cd9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 13 Mar 2023 22:37:18 +0000 Subject: [PATCH 04/17] Update src/sample.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index f054598c..0f3bf997 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -491,7 +491,9 @@ function mcmcsample( # Return the new chain. return chain end - chains = Distributed.pmap(sample_chain, pool, seeds, _init_params, _initial_state) + chains = Distributed.pmap( + sample_chain, pool, seeds, _init_params, _initial_state + ) finally # Stop updating the progress bar. progress && put!(channel, false) From b67e51d09ef360765b9ec62156301866554b3709 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:07:52 +0100 Subject: [PATCH 05/17] removed references to _init_params and _initial_params --- src/sample.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 0f3bf997..10912354 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -384,7 +384,7 @@ function mcmcsample( nothing else initial_state[chainidx] - end + end, kwargs..., ) @@ -433,10 +433,6 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) - # Create initial parameters. - _init_params = init_params === nothing ? fill(nothing, nchains) : init_params - _initial_state = initial_state === nothing ? fill(nothing, nchains) : initial_state - # Set up worker pool. pool = Distributed.CachingPool(Distributed.workers()) @@ -492,7 +488,7 @@ function mcmcsample( return chain end chains = Distributed.pmap( - sample_chain, pool, seeds, _init_params, _initial_state + sample_chain, pool, seeds, init_params, initial_state ) finally # Stop updating the progress bar. From e63c46e268b013995647750e52a1e48fbfbe6ef2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:10:42 +0100 Subject: [PATCH 06/17] check correctness of initial states --- src/sample.jl | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 10912354..b03e2bbb 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -323,8 +323,9 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) - # Ensure that initial parameters are `nothing` or of the correct length + # Ensure that initial parameters and states are `nothing` or of the correct length check_initial_params(init_params, nchains) + check_initial_state(initial_state, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -427,8 +428,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters are `nothing` or of the correct length + # Ensure that initial parameters and states are `nothing` or of the correct length check_initial_params(init_params, nchains) + check_initial_state(initial_state, nchains) # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -518,8 +520,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters are `nothing` or of the correct length + # Ensure that initial parameters and states are `nothing` or of the correct length check_initial_params(init_params, nchains) + check_initial_state(initial_state, nchains) # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -559,7 +562,6 @@ tighten_eltype(x::Vector{Any}) = map(identity, x) "initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", ), ) - check_initial_params(::Nothing, n) = nothing function check_initial_params(x::AbstractArray, n) if length(x) != n @@ -572,3 +574,21 @@ function check_initial_params(x::AbstractArray, n) return nothing end + +@nospecialize check_initial_state(x, n) = throw( + ArgumentError( + "initial states must be specified as a vector of length equal to the number of chains or `nothing`", + ), +) +check_initial_state(::Nothing, n) = nothing +function check_initial_state(x::AbstractArray, n) + if length(x) != n + throw( + ArgumentError( + "incorrect number of initial states (expected $n, received $(length(x))" + ), + ) + end + + return nothing +end From 5083d9418d38eb97d236064e569d82ce758927bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:18:49 +0100 Subject: [PATCH 07/17] renamed init_params to initial_params --- src/sample.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index b03e2bbb..ddfad4f8 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -298,7 +298,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", - init_params=nothing, + initial_params=nothing, initial_state=nothing, kwargs..., ) @@ -324,7 +324,7 @@ function mcmcsample( seeds = rand(rng, UInt, nchains) # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(init_params, nchains) + check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) # Set up a chains vector. @@ -376,10 +376,10 @@ function mcmcsample( _sampler, N; progress=false, - init_params=if init_params === nothing + initial_params=if initial_params === nothing nothing else - init_params[chainidx] + initial_params[chainidx] end, initial_state=if initial_state === nothing nothing @@ -414,7 +414,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) processes)", - init_params=nothing, + initial_params=nothing, initial_state=nothing, kwargs..., ) @@ -429,7 +429,7 @@ function mcmcsample( end # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(init_params, nchains) + check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) # Create a seed for each chain using the provided random number generator. @@ -467,7 +467,7 @@ function mcmcsample( Distributed.@async begin try - function sample_chain(seed, init_params, initial_state) + function sample_chain(seed, initial_params, initial_state) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -478,7 +478,7 @@ function mcmcsample( sampler, N; progress=false, - init_params=init_params, + initial_params=initial_params, initial_state=initial_state, kwargs..., ) @@ -490,7 +490,7 @@ function mcmcsample( return chain end chains = Distributed.pmap( - sample_chain, pool, seeds, init_params, initial_state + sample_chain, pool, seeds, initial_params, initial_state ) finally # Stop updating the progress bar. @@ -512,7 +512,7 @@ function mcmcsample( N::Integer, nchains::Integer; progressname="Sampling", - init_params=nothing, + initial_params=nothing, kwargs..., ) # Check if the number of chains is larger than the number of samples @@ -521,14 +521,14 @@ function mcmcsample( end # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(init_params, nchains) + check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) # Sample the chains. - function sample_chain(i, seed, init_params=nothing) + function sample_chain(i, seed, initial_params=nothing) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -539,15 +539,15 @@ function mcmcsample( sampler, N; progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), - init_params=init_params, + initial_params=initial_params, kwargs..., ) end - chains = if init_params === nothing + chains = if initial_params === nothing map(sample_chain, 1:nchains, seeds) else - map(sample_chain, 1:nchains, seeds, init_params) + map(sample_chain, 1:nchains, seeds, initial_params) end # Concatenate the chains together. From 18ad4a54cf517e4877ff6999b5dee22c474ce404 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:20:06 +0100 Subject: [PATCH 08/17] renamed references for init_params to initial_params --- docs/src/api.md | 6 +++--- test/sample.jl | 46 +++++++++++++++++++++++----------------------- test/utils.jl | 6 +++--- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index de956535..aabf8d6f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -82,9 +82,9 @@ Common keyword arguments for regular and parallel sampling are: The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). There is no "official" way for providing initial parameter values yet. -However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. -To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): -- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`. +However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain. +To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): +- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`. Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. diff --git a/test/sample.jl b/test/sample.jl index 22f4b26d..0ede2a13 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -28,7 +28,7 @@ # initial parameters chain = sample( - MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8) + MyModel(), MySampler(), 3; progress=false, initial_params=(b=3.2, a=-1.8) ) @test chain[1].a == -1.8 @test chain[1].b == 3.2 @@ -163,7 +163,7 @@ # initial parameters nchains = 100 - init_params = [(b=randn(), a=rand()) for _ in 1:nchains] + initial_params = [(b=randn(), a=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -171,15 +171,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (a=randn(), b=rand()) + initial_params = (a=randn(), b=rand()) chains = sample( MyModel(), MySampler(), @@ -187,14 +187,14 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + initial_params=Iterators.repeated(initial_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -205,7 +205,7 @@ init_params=FillArrays.Fill(init_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -298,7 +298,7 @@ # initial parameters nchains = 100 - init_params = [(a=randn(), b=rand()) for _ in 1:nchains] + initial_params = [(a=randn(), b=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -306,15 +306,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (b=randn(), a=rand()) + initial_params = (b=randn(), a=rand()) chains = sample( MyModel(), MySampler(), @@ -326,10 +326,10 @@ ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -340,7 +340,7 @@ init_params=FillArrays.Fill(init_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -407,7 +407,7 @@ # initial parameters nchains = 100 - init_params = [(a=rand(), b=randn()) for _ in 1:nchains] + initial_params = [(a=rand(), b=randn()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -415,15 +415,15 @@ 3, nchains; progress=false, - init_params=init_params, + initial_params=initial_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, init_params) + (chain, params) in zip(chains, initial_params) ) - init_params = (b=rand(), a=randn()) + initial_params = (b=rand(), a=randn()) chains = sample( MyModel(), MySampler(), @@ -435,10 +435,10 @@ ) @test length(chains) == nchains @test all( - chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains ) - # Too many `init_params` + # Too many `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -449,7 +449,7 @@ init_params=FillArrays.Fill(init_params, nchains + 1), ) - # Too few `init_params` + # Too few `initial_params` @test_throws ArgumentError sample( MyModel(), MySampler(), diff --git a/test/utils.jl b/test/utils.jl index f69fcdab..1e29a473 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -22,12 +22,12 @@ function AbstractMCMC.step( sampler::MySampler, state::Union{Nothing,Integer}=nothing; loggers=false, - init_params=nothing, + initial_params=nothing, kwargs..., ) # sample `a` is missing in the first step if not provided - a, b = if state === nothing && init_params !== nothing - init_params.a, init_params.b + a, b = if state === nothing && initial_params !== nothing + initial_params.a, initial_params.b else (state === nothing ? missing : rand(rng)), randn(rng) end From a1c188159627d80b88b1d6a8074d992bd9a1d834 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:24:12 +0100 Subject: [PATCH 09/17] formatting --- test/sample.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index 0ede2a13..c6842b6f 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -191,7 +191,8 @@ ) @test length(chains) == nchains @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for + chain in chains ) # Too many `initial_params` @@ -326,7 +327,8 @@ ) @test length(chains) == nchains @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for + chain in chains ) # Too many `initial_params` @@ -435,7 +437,8 @@ ) @test length(chains) == nchains @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for chain in chains + chain[1].a == initial_params.a && chain[1].b == initial_params.b for + chain in chains ) # Too many `initial_params` From ca4f4b9940a14ec952f06b1ffd60173685856620 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Oct 2023 18:30:09 +0100 Subject: [PATCH 10/17] initial_state missing from one mcmcsample --- src/sample.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sample.jl b/src/sample.jl index ddfad4f8..b40d447e 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -513,6 +513,7 @@ function mcmcsample( nchains::Integer; progressname="Sampling", initial_params=nothing, + initial_state=nothing, kwargs..., ) # Check if the number of chains is larger than the number of samples From e442880e71fd612f1985a3740e202cbbb8e81411 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Oct 2023 23:34:23 +0100 Subject: [PATCH 11/17] fixed initial_params and initial_state for MCMCDistributed --- Project.toml | 2 ++ src/AbstractMCMC.jl | 1 + src/sample.jl | 5 ++++- 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 90117048..a796c429 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ version = "4.5.0" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" @@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" +FillArrays = "1" LogDensityProblems = "2" LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 64f20f97..dc464d42 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -8,6 +8,7 @@ using ProgressLogging: ProgressLogging using StatsBase: StatsBase using TerminalLoggers: TerminalLoggers using Transducers: Transducers +using FillArrays: FillArrays using Distributed: Distributed using Logging: Logging diff --git a/src/sample.jl b/src/sample.jl index b40d447e..a15655bf 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -432,6 +432,9 @@ function mcmcsample( check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) + _initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params + _initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -490,7 +493,7 @@ function mcmcsample( return chain end chains = Distributed.pmap( - sample_chain, pool, seeds, initial_params, initial_state + sample_chain, pool, seeds, _initial_params, _initial_state ) finally # Stop updating the progress bar. From fb4a1f613c1f4818e09252c6d98ee21a53a1b65e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Oct 2023 23:34:47 +0100 Subject: [PATCH 12/17] fixed typo in the initial step --- src/sample.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sample.jl b/src/sample.jl index a15655bf..7b4ff830 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -126,7 +126,7 @@ function mcmcsample( sample, state = if initial_state === nothing step(rng, model, sampler; kwargs...) else - step(rng, model, sampler, state; kwargs...) + step(rng, model, sampler, initial_state; kwargs...) end # Discard initial samples. From c6ec64e898f239137572f7510b90091af2087137 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Oct 2023 23:35:19 +0100 Subject: [PATCH 13/17] replaced init_params with initial_params in tests --- test/sample.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index c6842b6f..197c9003 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -187,7 +187,7 @@ 3, nchains; progress=false, - initial_params=Iterators.repeated(initial_params, nchains), + initial_params=FillArrays.Fill(initial_params, nchains), ) @test length(chains) == nchains @test all( @@ -203,7 +203,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), + initial_params=FillArrays.Fill(initial_params, nchains + 1), ) # Too few `initial_params` @@ -214,7 +214,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), + initial_params=FillArrays.Fill(initial_params, nchains - 1), ) end @@ -323,7 +323,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + initial_params=FillArrays.Fill(initial_params, nchains), ) @test length(chains) == nchains @test all( @@ -339,7 +339,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), + initial_params=FillArrays.Fill(initial_params, nchains + 1), ) # Too few `initial_params` @@ -350,7 +350,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), + initial_params=FillArrays.Fill(initial_params, nchains - 1), ) # Remove workers @@ -433,7 +433,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + initial_params=FillArrays.Fill(initial_params, nchains), ) @test length(chains) == nchains @test all( @@ -449,7 +449,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), + initial_params=FillArrays.Fill(initial_params, nchains + 1), ) # Too few `initial_params` @@ -460,7 +460,7 @@ 3, nchains; progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), + initial_params=FillArrays.Fill(initial_params, nchains - 1), ) end From 74465b4df6da581c22bb46c7ff83a20f6ad97619 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Oct 2023 23:35:33 +0100 Subject: [PATCH 14/17] disabled logging for large number of chains in tests where logging isnt tested --- test/sample.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index 197c9003..8cacd3e7 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -360,13 +360,13 @@ @testset "Serial sampling" begin # No dedicated chains type N = 10_000 - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; progress=false) @test chains isa Vector{<:Vector{<:MySample}} @test length(chains) == 1000 @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -382,7 +382,7 @@ # Test reproducibility. Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) + chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false) @test all(ismissing(c.as[1]) for c in chains2) @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) From 56456c93644742efb73a2d58d53b8e924c04487d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 3 Oct 2023 00:10:51 +0100 Subject: [PATCH 15/17] mroe fixes --- src/sample.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 7b4ff830..58c650e3 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -528,11 +528,14 @@ function mcmcsample( check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) + _initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params + _initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) # Sample the chains. - function sample_chain(i, seed, initial_params=nothing) + function sample_chain(i, seed, initial_params, initial_state) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -544,15 +547,12 @@ function mcmcsample( N; progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), initial_params=initial_params, + initial_state=initial_state, kwargs..., ) end - chains = if initial_params === nothing - map(sample_chain, 1:nchains, seeds) - else - map(sample_chain, 1:nchains, seeds, initial_params) - end + chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state) # Concatenate the chains together. return chainsstack(tighten_eltype(chains)) From 5d83ab406213b0c2009a84d214c4d8722d0c8e89 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 3 Oct 2023 00:10:58 +0100 Subject: [PATCH 16/17] added tests for initial state --- test/sample.jl | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/test/sample.jl b/test/sample.jl index 8cacd3e7..295c830e 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -658,4 +658,58 @@ ) @test it_array == collect(1:size(chain, 1)) end + + @testset "Providing initial state" begin + function record_state(rng, model, sampler, sample, state, i; states_channel, kwargs...) + put!(states_channel, state) + end + + initial_state = 10 + + @testset "sample" begin + n = 10 + states_channel = Channel{Int}(n) + chain = sample( + MyModel(), MySampler(), n; + initial_state=initial_state, + callback=record_state, + states_channel=states_channel + ) + + # Extract the states. + states = [take!(states_channel) for _ in 1:n] + @test length(states) == n + for i in 1:n + @test states[i] == initial_state + i + end + end + + @testset "sample with $mode" for mode in [ + MCMCSerial(), + MCMCThreads(), + MCMCDistributed(), + ] + nchains = 4 + initial_state = 10 + states_channel = if mode === MCMCDistributed() + # Need to use `RemoteChannel` for this. + RemoteChannel(() -> Channel{Int}(nchains)) + else + Channel{Int}(nchains) + end + chain = sample( + MyModel(), MySampler(), mode, 1, nchains; + initial_state=FillArrays.Fill(initial_state, nchains), + callback=record_state, + states_channel=states_channel + ) + + # Extract the states. + states = [take!(states_channel) for _ in 1:nchains] + @test length(states) == nchains + for i = 1:nchains + @test states[i] == initial_state + 1 + end + end + end end From 3ed53147802e80d674866499d5d400a8c12f60f1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 3 Oct 2023 00:33:02 +0100 Subject: [PATCH 17/17] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/sample.jl | 12 ++++++++---- test/sample.jl | 49 +++++++++++++++++++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index 58c650e3..58217f39 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -432,8 +432,10 @@ function mcmcsample( check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) - _initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params - _initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state + _initial_params = + initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params + _initial_state = + initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -528,8 +530,10 @@ function mcmcsample( check_initial_params(initial_params, nchains) check_initial_state(initial_state, nchains) - _initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params - _initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state + _initial_params = + initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params + _initial_state = + initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) diff --git a/test/sample.jl b/test/sample.jl index 295c830e..dcc87526 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -366,7 +366,15 @@ @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false) + chains = sample( + MyModel(), + MySampler(), + MCMCSerial(), + N, + 1000; + chain_type=MyChain, + progress=false, + ) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -382,7 +390,15 @@ # Test reproducibility. Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false) + chains2 = sample( + MyModel(), + MySampler(), + MCMCSerial(), + N, + 1000; + chain_type=MyChain, + progress=false, + ) @test all(ismissing(c.as[1]) for c in chains2) @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @@ -660,8 +676,10 @@ end @testset "Providing initial state" begin - function record_state(rng, model, sampler, sample, state, i; states_channel, kwargs...) - put!(states_channel, state) + function record_state( + rng, model, sampler, sample, state, i; states_channel, kwargs... + ) + return put!(states_channel, state) end initial_state = 10 @@ -670,10 +688,12 @@ n = 10 states_channel = Channel{Int}(n) chain = sample( - MyModel(), MySampler(), n; + MyModel(), + MySampler(), + n; initial_state=initial_state, callback=record_state, - states_channel=states_channel + states_channel=states_channel, ) # Extract the states. @@ -684,11 +704,8 @@ end end - @testset "sample with $mode" for mode in [ - MCMCSerial(), - MCMCThreads(), - MCMCDistributed(), - ] + @testset "sample with $mode" for mode in + [MCMCSerial(), MCMCThreads(), MCMCDistributed()] nchains = 4 initial_state = 10 states_channel = if mode === MCMCDistributed() @@ -698,16 +715,20 @@ Channel{Int}(nchains) end chain = sample( - MyModel(), MySampler(), mode, 1, nchains; + MyModel(), + MySampler(), + mode, + 1, + nchains; initial_state=FillArrays.Fill(initial_state, nchains), callback=record_state, - states_channel=states_channel + states_channel=states_channel, ) # Extract the states. states = [take!(states_channel) for _ in 1:nchains] @test length(states) == nchains - for i = 1:nchains + for i in 1:nchains @test states[i] == initial_state + 1 end end