
Support init_params in ensemble methods #94

Merged 14 commits into master from dw/init_params on Mar 7, 2022

Conversation

@devmotion (Member) commented on Feb 9, 2022

Implements the suggestion in #92 (comment).

With this PR, samplers that support init_params for single-chain sampling (such as AdvancedMH, EllipticalSliceSampling, or Turing) get support for init_params with ensemble methods for free. For example, with this PR and the latest release of EllipticalSliceSampling (no changes needed):

julia> using EllipticalSliceSampling, Distributions

julia> prior = Normal(0, 1);
                                                                                                                                              
julia> loglik(x) = logpdf(Normal(x, 0.5), 1.0);                                                                                               

julia> sample(ESSModel(prior, loglik), ESS(), 3; progress=false, init_params=0.5) # `init_params` supported for single chain sampling
3-element Vector{Float64}:
 0.5
 0.5569861002529894
 0.5590132914774115

julia> sample(ESSModel(prior, loglik), ESS(), MCMCSerial(), 3, 3; progress=false, init_params=[0.5, 0.4, 0.2])
3-element Vector{Vector{Float64}}:
 [0.5, 1.2604650145668184, 0.5342938313761859]
 [0.4, 0.6766414322931542, 0.6669188733609087]
 [0.2, 0.1991658951820099, 1.620050207287682]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCSerial(), 3, 3; progress=false, init_params=Iterators.repeated(0.0))
3-element Vector{Vector{Float64}}:
 [0.0, -0.23546962848230765, 0.28173464269157983]
 [0.0, 0.8450041861878778, 1.2969836926133524]
 [0.0, 0.2803330725356131, 1.5178894526699893]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCThreads(), 3, 3; progress=false, init_params=[0.5, 0.4, 0.2])
3-element Vector{Vector{Float64}}:
 [0.5, 0.683876765057849, 0.6000166327171864]
 [0.4, 0.18201765370396472, 0.2882589877099947]
 [0.2, 0.9548385848899827, 0.08087007696252635]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCThreads(), 3, 3; progress=false, init_params=Iterators.repeated(0.0))
3-element Vector{Vector{Float64}}:
 [0.0, 0.2830832021306832, -0.11155114222247048]
 [0.0, 0.22734708270317408, 0.908700798723544]
 [0.0, 0.5444555918079804, -0.07055391114874537]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCDistributed(), 3, 3; progress=false, init_params=[0.5, 0.4, 0.2])
3-element Vector{Vector{Float64}}:
 [0.5, 1.3010083566517667, 1.2849239490378277]
 [0.4, 0.5001641558540113, 0.4983091461364374]
 [0.2, 0.3023383923085534, 0.9800014107135103]

julia> sample(ESSModel(prior, loglik), ESS(), MCMCDistributed(), 3, 3; progress=false, init_params=Iterators.repeated(0.0))
3-element Vector{Vector{Float64}}:
 [0.0, -0.01643158241658191, 0.20824023548102533]
 [0.0, 0.46749889505515185, 0.3357075197741447]
 [0.0, 0.6078789038880229, 0.9077006973215425]

Samplers that do not support init_params are not affected by this PR.

codecov bot commented on Feb 9, 2022

Codecov Report

Merging #94 (019bb59) into master (4994a79) will decrease coverage by 0.25%.
The diff coverage is 95.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master      #94      +/-   ##
==========================================
- Coverage   97.74%   97.48%   -0.26%     
==========================================
  Files           7        7              
  Lines         222      239      +17     
==========================================
+ Hits          217      233      +16     
- Misses          5        6       +1     
Impacted Files Coverage Δ
src/AbstractMCMC.jl 100.00% <ø> (ø)
src/sample.jl 97.66% <95.00%> (-0.40%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4994a79...019bb59.

@devmotion (Member, Author) commented on Feb 9, 2022

The MCMCChains test errors are unrelated and already known.

@sethaxen (Member) left a comment:

Should there be a documentation update somewhere?

src/sample.jl (review thread resolved)
@sethaxen (Member) left a comment:

LGTM!

src/sample.jl (outdated):
return StatsBase.sample(
rng, model, sampler, N;
progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
init_params = init_params === nothing ? nothing : init_params[i],
Member:

This could fail both silently and loudly if I called

sample(..., MCMCThreads, N, 4, init_params=[1,2,3])

expecting init_params to be replicated across all the chains, which is roughly what the current behavior is in Turing/DynamicPPL.

The problem is that I don't see a great fix other than telling users to just replicate their initializers. Perhaps we could also have a keyword replicate_init=true/false that adds the following to the preamble in sample:

if replicate_init
    init_params = [init_params for _ in 1:nchains]
end

but this seems clumsy to me.

devmotion (Member, Author):

I think the only unambiguous and clean approach is to use a collection of parameters, and to use e.g. FillArrays if you want to use the same parameters for every chain.
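As a sketch of that suggestion (requires the FillArrays.jl package; the sample call is commented out because it assumes the model from the REPL examples above):

```julia
using FillArrays  # Fill(x, n) behaves like fill(x, n) without allocating n copies

# One set of initial parameters, lazily reused for every chain:
init = Fill(0.5, 3)

init[1] == init[2] == init[3] == 0.5  # indexable like a regular vector

# Hypothetical ensemble call mirroring the examples above:
# sample(ESSModel(prior, loglik), ESS(), MCMCThreads(), 3, 3; init_params=init)
```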

devmotion (Member, Author):

With a boolean switch one would have to rely on constant propagation for type inference. It's an option though.

devmotion (Member, Author):

In any case, IMO we should not use some heuristic to try to determine whether the provided init_params corresponds to a single set of values or to multiple ones.

Member:

Alright, then if we want to keep it as-is we should do a major release because the behavior is breaking, at least for Turing folks.

devmotion (Member, Author):

IMO users should just provide an indexable collection - if desired, they could even use first. I don't think we should use it internally since it will always copy and collect even if not needed (https://github.com/JuliaLang/julia/blob/bf534986350a991e4a1b29126de0342ffd76205e/base/abstractarray.jl#L448-453). Additionally, it would require dropping Julia < 1.6 or copying code for older versions - and given the unclear benefit (it's not clear what problem it would solve exactly), I don't think there's a compelling reason to do any of these.

devmotion (Member, Author):

I think the better approach, if we care about non-indexable collections, is to rewrite the serial and threaded implementations: pmap already supports iterators, and the serial code can be rewritten in the same way.
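A minimal sketch of that idea (the function and the returned values are illustrative, not the actual AbstractMCMC implementation): iterate chains and initial parameters in lockstep, so any iterator works, including Iterators.repeated.

```julia
# Illustrative only: pair each chain index with its initial parameters without
# ever indexing into init_params, so generic (non-indexable) iterators work.
function sample_serial_sketch(nchains, init_params)
    inits = init_params === nothing ? Iterators.repeated(nothing) : init_params
    return map(zip(1:nchains, inits)) do (i, init)
        # placeholder for the single-chain call:
        # StatsBase.sample(rng, model, sampler, N; init_params=init, ...)
        (chain=i, init=init)
    end
end
```

With init_params = Iterators.repeated(0.0), every chain then receives 0.0 without a vector ever being materialized.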

Member:

I think it's fine to require the collection to be indexable, as long as it's documented, which it is.

devmotion (Member, Author):

I changed the implementation - now all collections are supported with minor overhead in multi-threaded sampling.

devmotion (Member, Author):

And Iterators.repeated is not handled in a special way.

docs/src/api.md (outdated):
There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if set to `params !== nothing`, then the `i`th chain is sampled with keyword argument `params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.
Member:

Suggested change
- `init_params` (default: `nothing`): if set to `params !== nothing`, then the `i`th chain is sampled with keyword argument `params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th chain is sampled with keyword argument `init_params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.

devmotion (Member, Author):

I remember why I used params - I actually wanted to write:

Suggested change
- `init_params` (default: `nothing`): if set to `params !== nothing`, then the `i`th chain is sampled with keyword argument `params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.
- `init_params` (default: `nothing`): if set to `params !== nothing`, then the `i`th chain is sampled with keyword argument `init_params = params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.

devmotion (Member, Author):

I changed the formulation.

@devmotion (Member, Author) commented:
#99 should be merged first since it is non-breaking.

@yebai merged commit 3de7393 into master on Mar 7, 2022
@delete-merged-branch bot deleted the dw/init_params branch on March 7, 2022 07:39

4 participants