Allow specification of initial state for sample #119

Merged
merged 17 commits into from
Oct 24, 2023
2 changes: 2 additions & 0 deletions Project.toml
@@ -9,6 +9,7 @@ version = "4.5.0"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Member

This means it could be removed from the [extras] section below. But do we actually have to depend on FillArrays? I think we should just forward the user-input or nothing, but not build any arrays explicitly?

LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
@@ -21,6 +22,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
[compat]
BangBang = "0.3.19"
ConsoleProgressMonitor = "0.1"
FillArrays = "1"
LogDensityProblems = "2"
LoggingExtras = "0.4, 0.5, 1"
ProgressLogging = "0.1"
8 changes: 5 additions & 3 deletions docs/src/api.md
@@ -75,14 +75,16 @@ Common keyword arguments for regular and parallel sampling are:
where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler
- `discard_initial` (default: `0`): number of initial samples that are discarded
- `thinning` (default: `1`): factor by which to thin samples.
- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref)
is passed `initial_state` as the `state` argument.
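
As a rough illustration of what `initial_state` enables, the toy sketch below stops a chain and later resumes it from its last state. It is not code from this PR: `ToyModel`, `ToySampler`, `toy_step`, and `toy_sample` are hypothetical stand-ins for a model, a sampler, `AbstractMCMC.step`, and the sampling loop, and only the branch on `initial_state === nothing` mirrors the change in `src/sample.jl`.

```julia
using Random

# Hypothetical stand-ins for a model and a sampler.
struct ToyModel end
struct ToySampler end

# First step: no state is available yet, so the random walk starts at 0.0.
toy_step(rng, ::ToyModel, ::ToySampler) = (0.0, 0.0)

# Subsequent steps: move from the previous state by a standard normal increment.
function toy_step(rng, ::ToyModel, ::ToySampler, state)
    x = state + randn(rng)
    return x, x
end

# Minimal sampling loop that honours `initial_state` in the same way as the diff:
# if a state is supplied, the very first call to the step function receives it.
function toy_sample(rng, model, sampler, N; initial_state=nothing)
    sample, state = if initial_state === nothing
        toy_step(rng, model, sampler)
    else
        toy_step(rng, model, sampler, initial_state)
    end
    samples = [sample]
    for _ in 2:N
        sample, state = toy_step(rng, model, sampler, state)
        push!(samples, sample)
    end
    return samples, state
end

# Run 4 steps, keep the final state, then resume for 6 more steps.
rng = MersenneTwister(1)
first_part, last_state = toy_sample(rng, ToyModel(), ToySampler(), 4)
second_part, _ = toy_sample(rng, ToyModel(), ToySampler(), 6; initial_state=last_state)
```

With the same random number generator, the two parts concatenated should match a single uninterrupted run of 10 steps, which is the property that makes resuming useful.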

!!! info
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).

There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.

Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.

1 change: 1 addition & 0 deletions src/AbstractMCMC.jl
@@ -8,6 +8,7 @@ using ProgressLogging: ProgressLogging
using StatsBase: StatsBase
using TerminalLoggers: TerminalLoggers
using Transducers: Transducers
using FillArrays: FillArrays

using Distributed: Distributed
using Logging: Logging
100 changes: 72 additions & 28 deletions src/sample.jl
@@ -103,6 +103,7 @@
discard_initial=0,
thinning=1,
chain_type::Type=Any,
initial_state=nothing,
kwargs...,
)
# Check the number of requested samples.
@@ -122,7 +123,11 @@
end

# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)
sample, state = if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end

# Discard initial samples.
for i in 1:discard_initial
@@ -211,6 +216,7 @@
callback=nothing,
discard_initial=0,
thinning=1,
initial_state=nothing,
kwargs...,
)

@@ -220,7 +226,11 @@

@ifwithprogresslogger progress name = progressname begin
# Obtain the initial sample and state.
sample, state = step(rng, model, sampler; kwargs...)
sample, state = if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end

# Discard initial samples.
for _ in 1:discard_initial
@@ -288,7 +298,8 @@
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if actually multiple threads are used.
@@ -312,8 +323,9 @@
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
@@ -364,10 +376,15 @@
_sampler,
N;
progress=false,
init_params=if init_params === nothing
initial_params=if initial_params === nothing
nothing
else
initial_params[chainidx]
end,
initial_state=if initial_state === nothing
nothing
else
init_params[chainidx]
initial_state[chainidx]
end,
kwargs...,
)
@@ -397,7 +414,8 @@
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(Distributed.nworkers()) processes)",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if actually multiple processes are used.
@@ -410,8 +428,14 @@
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
Comment on lines +435 to +438
Member

I think it would be nice to avoid these FillArrays. Maybe we could

  • move the function below to a callable struct (could be shared between all ensemble algorithms maybe?)
  • pass initial_params and initial_state to the constructor as well but only use it to define type parameters that allow us to distinguish between the four possible cases (no initial params and no initial state, only initial state, only initial params, and both initial params and state)
  • define the function of the callable struct depending on the type parameters, forwarding the versions with only the seed or only the seed and one additional argument to the three-argument version

Member Author

Is adding in callable structs here really an improvement? 😕 I agree it's more efficient, but it seems like this will be quite a bit more complex + the efficiency doesn't really matter here, right?

Member (@devmotion, Oct 3, 2023)

A callable struct should generally be better for the compiler than a closure, shouldn't it? Regardless of whether we change or add arguments as in this PR.

Member Author

Yep! But is this performance-critical code? And it seems to me that we'll need a callable struct for each scenario?

Member

🤷 Are you sure? To me it seems one struct is sufficient - both the multithreaded and the multicore version seem to use the same inner structure, and in the serial case we could set channel = nothing. If needed we could also dispatch on the type of the algorithm to handle minor differences in the function call.

Member Author

> To me it seems one struct is sufficient

To clarify, I don't question whether we can have a single callable struct with different call implementations; I meant more that it seems it won't be as simple as just doing

```julia
struct SampleFunc
    # ...
end

function (f::SampleFunc)(args...)
    # ...
end
```

> multithreaded and the multicore version seem to use the same inner structure

But if we put `initial_params` and `initial_state` in the callable struct, then we'll need to `pmap`, etc. over a range containing the corresponding indices, no? Which seems like it would lead to more allocations than the current impl using `Fill(nothing, nchains)`?

Or am I misunderstanding what you mean here?

Member Author

Just bumping this discussion:) Would be nice to get a version of this PR merged.
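
For readers following this thread, here is one possible reading of the callable-struct suggestion, written as a self-contained sketch rather than as the design that was merged. The names `ChainRunner` and `run_chains` and the `run(seed, initial_params, initial_state)` signature are assumptions made for illustration. The Boolean type parameters record which optional inputs were supplied, the shorter call methods forward to the three-argument one, and `map` only receives the collections that actually exist, so no `Fill(nothing, nchains)` placeholder is needed.

```julia
# Hypothetical callable struct; the Boolean type parameters record which
# optional inputs were supplied so that dispatch can pick the right method.
struct ChainRunner{HasParams,HasState,F}
    run::F  # assumed signature: run(seed, initial_params, initial_state)
end

function ChainRunner(run, initial_params, initial_state)
    HasParams = initial_params !== nothing
    HasState = initial_state !== nothing
    return ChainRunner{HasParams,HasState,typeof(run)}(run)
end

# Three-argument method that does the actual work.
(f::ChainRunner)(seed, params, state) = f.run(seed, params, state)

# Shorter methods forward to the three-argument version.
(f::ChainRunner{false,false})(seed) = f(seed, nothing, nothing)
(f::ChainRunner{true,false})(seed, params) = f(seed, params, nothing)
(f::ChainRunner{false,true})(seed, state) = f(seed, nothing, state)

# Map over the seeds plus only those collections that were provided.
function run_chains(run, seeds, initial_params=nothing, initial_state=nothing)
    f = ChainRunner(run, initial_params, initial_state)
    extras = Tuple(x for x in (initial_params, initial_state) if x !== nothing)
    return map(f, seeds, extras...)
end
```

Whether this is simpler than the `FillArrays` version is exactly the judgment call debated above; the sketch only shows that a single struct can cover all four cases.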


# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
@@ -448,7 +472,7 @@

Distributed.@async begin
try
function sample_chain(seed, init_params=nothing)
function sample_chain(seed, initial_params, initial_state)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

@@ -459,7 +483,8 @@
sampler,
N;
progress=false,
init_params=init_params,
initial_params=initial_params,
initial_state=initial_state,
kwargs...,
)

@@ -469,11 +494,9 @@
# Return the new chain.
return chain
end
chains = if init_params === nothing
Distributed.pmap(sample_chain, pool, seeds)
else
Distributed.pmap(sample_chain, pool, seeds, init_params)
end
chains = Distributed.pmap(
sample_chain, pool, seeds, _initial_params, _initial_state
)
finally
# Stop updating the progress bar.
progress && put!(channel, false)
@@ -494,22 +517,29 @@
N::Integer,
nchains::Integer;
progressname="Sampling",
init_params=nothing,
initial_params=nothing,
initial_state=nothing,
kwargs...,
)
# Check if the number of chains is larger than the number of samples
if nchains > N
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)
# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Sample the chains.
function sample_chain(i, seed, init_params=nothing)
function sample_chain(i, seed, initial_params, initial_state)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

@@ -520,16 +550,13 @@
sampler,
N;
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
init_params=init_params,
initial_params=initial_params,
initial_state=initial_state,
kwargs...,
)
end

chains = if init_params === nothing
map(sample_chain, 1:nchains, seeds)
else
map(sample_chain, 1:nchains, seeds, init_params)
end
chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state)

# Concatenate the chains together.
return chainsstack(tighten_eltype(chains))
@@ -543,7 +570,6 @@
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
),
)

check_initial_params(::Nothing, n) = nothing
function check_initial_params(x::AbstractArray, n)
if length(x) != n
@@ -556,3 +582,21 @@

return nothing
end

@nospecialize check_initial_state(x, n) = throw(

ArgumentError(
"initial states must be specified as a vector of length equal to the number of chains or `nothing`",
),
)
check_initial_state(::Nothing, n) = nothing
function check_initial_state(x::AbstractArray, n)
if length(x) != n
throw(

ArgumentError(
"incorrect number of initial states (expected $n, received $(length(x)))"
),
)
end

return nothing
end
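
The validation added above relies on dispatch: `nothing` is accepted, an `AbstractArray` is checked for length, and anything else is rejected. A minimal self-contained sketch of the same pattern follows, assuming hypothetical names (`check_state`, `throws_argument_error`) so that it does not depend on AbstractMCMC internals. It also exercises the two error paths, which are the branches the coverage check reported as untested.

```julia
# Hypothetical stand-in mirroring the dispatch pattern of `check_initial_state`.
function check_state(x, n)
    throw(ArgumentError("initial states must be a vector of length $n or `nothing`"))
end
check_state(::Nothing, n) = nothing
function check_state(x::AbstractArray, n)
    if length(x) != n
        throw(
            ArgumentError(
                "incorrect number of initial states (expected $n, received $(length(x)))"
            ),
        )
    end
    return nothing
end

# Small helper: does calling `f` throw an `ArgumentError`?
function throws_argument_error(f)
    try
        f()
        return false
    catch err
        return err isa ArgumentError
    end
end

ok_nothing = check_state(nothing, 3) === nothing                       # no states given
ok_vector = check_state([1, 2, 3], 3) === nothing                      # one state per chain
bad_length = throws_argument_error(() -> check_state([1, 2], 3))       # wrong length
bad_type = throws_argument_error(() -> check_state(:not_a_vector, 3))  # unsupported type
```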