Support `init_params` in ensemble methods #94
Codecov Report
@@            Coverage Diff             @@
##           master      #94      +/-   ##
==========================================
- Coverage   97.74%   97.48%   -0.26%
==========================================
  Files           7        7
  Lines         222      239     +17
==========================================
+ Hits          217      233     +16
- Misses          5        6      +1
The MCMCChains errors are unrelated and known.
Should there be a documentation update somewhere?
LGTM!
src/sample.jl
Outdated
return StatsBase.sample(
    rng, model, sampler, N;
    progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
    init_params = init_params === nothing ? nothing : init_params[i],
This could fail both silently and loudly if I called `sample(..., MCMCThreads, N, 4, init_params=[1,2,3])` expecting `init_params` to be replicated across all the chains, which is roughly what the current behavior is in Turing/DynamicPPL.
The problem is that I don't see a great fix other than telling users to replicate their initializers. Perhaps we could also have a keyword `replicate_init=true/false` that adds the following to the preamble in `sample`:
if replicate_init
init_params = [init_params for _ in 1:nchains]
end
but this seems clumsy to me.
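To make the two failure modes concrete, here is a minimal sketch. `init_for_chain` is a hypothetical stand-in for the `init_params[i]` lookup in the diff above, not the package's actual code, and the values are purely illustrative.

```julia
# Hypothetical stand-in for the per-chain lookup shown in the diff above.
init_for_chain(init_params, i) = init_params === nothing ? nothing : init_params[i]

init_params = [1, 2, 3]  # intended as ONE parameter vector shared by all chains

# Silent failure: with 3 chains every index is valid, so each chain
# receives a single scalar instead of the full vector.
silent = [init_for_chain(init_params, i) for i in 1:3]

# Loud failure: with 4 chains the fourth lookup is out of bounds.
loud = try
    [init_for_chain(init_params, i) for i in 1:4]
catch err
    err
end

# Explicit replication hands every chain the full vector.
replicated = [init_params for _ in 1:4]
explicit = [init_for_chain(replicated, i) for i in 1:4]
```

Which failure a user hits depends only on whether the length of the vector happens to match the number of chains, which is why guessing the intent is fragile.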
I think the only unambiguous and clean approach is to use a collection of parameters, and to use e.g. `FillArrays` if you want to use the same parameter for every chain.
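As a rough illustration of passing "the same parameter for every chain" as a collection, the sketch below uses only Base so that it runs without extra packages. The values of `x` and `nchains` are made up for the example.

```julia
x = [0.5, -1.0]   # one set of initial parameters (illustrative values)
nchains = 4

# Eager: a Vector whose entries all refer to the same `x`.
eager = fill(x, nchains)

# Lazy: an iterator that yields `x` exactly `nchains` times.
lazy = Iterators.repeated(x, nchains)

# With FillArrays.jl installed, `Fill(x, nchains)` would be a lazy,
# indexable alternative (not loaded here to keep the sketch dependency-free).
```

Note that `fill` does not copy `x`, so all entries alias the same array; that is harmless as long as the sampler does not mutate its initial parameters in place.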
With a boolean switch one would have to rely on constant propagation for type inference. It's an option though.
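The type-inference concern can be sketched as follows. `maybe_replicate` is a hypothetical helper written for this illustration (it is not part of the PR); it returns different container types depending on the run-time value of the flag.

```julia
# Hypothetical helper: the Bool decides at run time which container type is returned.
maybe_replicate(init, nchains::Int, replicate::Bool) =
    replicate ? [init for _ in 1:nchains] : init

# Return type inferred from the argument types alone (flag value unknown).
inferred = first(Base.return_types(maybe_replicate, (Vector{Float64}, Int, Bool)))

# The two branches produce different concrete types.
a = maybe_replicate([1.0, 2.0], 3, true)
b = maybe_replicate([1.0, 2.0], 3, false)
```

Unless the compiler can propagate a literal `true` or `false` from the call site, the inferred result is not a single concrete type, and downstream code has to deal with that.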
In any case IMO we should not use some heuristic for trying to determine if the provided `init_params` corresponds to a single set of values or multiple ones.
Alright, then if we want to keep it as-is we should do a major release because the behavior is breaking, at least for Turing folks.
IMO users should just provide an indexable collection; if desired they could even use `first`. I don't think we should use it internally since it will always copy and collect even if not needed (https://github.com/JuliaLang/julia/blob/bf534986350a991e4a1b29126de0342ffd76205e/base/abstractarray.jl#L448-453). Additionally, it would require us to drop Julia < 1.6 or copy code for older versions, and given the unclear benefit (it's not clear what problem it would solve exactly) I don't think there's a compelling reason to do either.
I think the better approach if we care about non-indexable collections is to rewrite the serial and threaded implementation. `pmap` already supports iterators and the serial code can be rewritten in the same way.
I think it's fine to require the collection be indexed, so long as it's documented, which it is.
I changed the implementation - now all collections are supported with minor overhead in multi-threaded sampling.
And `Iterators.repeated` is not handled in a special way.
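One way such an implementation could look is sketched below. This is an assumption about the general shape, not the code merged in this PR: `fake_chain`, `sample_serial` and `sample_threaded` are hypothetical names, and `fake_chain` merely records its arguments instead of sampling.

```julia
# Stand-in for sampling one chain; it just records what it was given.
fake_chain(i; init_params = nothing) = (chain = i, init = init_params)

# Serial: zip the chain indices with the (possibly infinite) iterator.
function sample_serial(nchains, init_params)
    inits = init_params === nothing ? Iterators.repeated(nothing) : init_params
    return [fake_chain(i; init_params = p) for (i, p) in zip(1:nchains, inits)]
end

# Threaded: materialize the first `nchains` entries once so that tasks can
# index them. This `collect` is the small extra cost of accepting any iterator.
function sample_threaded(nchains, init_params)
    inits = if init_params === nothing
        fill(nothing, nchains)
    else
        collect(Iterators.take(init_params, nchains))
    end
    results = Vector{Any}(undef, nchains)
    Threads.@threads for i in 1:nchains
        results[i] = fake_chain(i; init_params = inits[i])
    end
    return results
end

x = [0.0, 1.0]
serial = sample_serial(3, Iterators.repeated(x))       # same `x` for every chain
threaded = sample_threaded(3, (x .+ i for i in 1:3))   # non-indexable generator
```

In this sketch `Iterators.repeated(x)` needs no special case: `zip` and `Iterators.take` simply stop after `nchains` elements.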
docs/src/api.md
Outdated
There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if set to `params !== nothing`, then the `i`th chain is sampled with keyword argument `params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.
Suggested change (current line, then proposed line):
- `init_params` (default: `nothing`): if set to `params !== nothing`, then the `i`th chain is sampled with keyword argument `params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th chain is sampled with keyword argument `init_params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.
I remember why I used `params`: I actually wanted to write the following (current line, then intended line):
- `init_params` (default: `nothing`): if set to `params !== nothing`, then the `i`th chain is sampled with keyword argument `params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.
- `init_params` (default: `nothing`): if set to `params !== nothing`, then the `i`th chain is sampled with keyword argument `init_params = params[i]`. If `init_params = Iterators.repeated(x)`, then the initial parameters `x` are used for every chain.
I changed the formulation.
#99 should be merged first since it is non-breaking.
Implements the suggestion in #92 (comment).
With this PR, samplers that support `init_params` for single chain sampling (such as AdvancedMH, EllipticalSliceSampling or Turing) get support for `init_params` with ensemble methods for free. E.g., with this PR and the latest release of EllipticalSliceSampling (no changes needed):
Samplers that do not support `init_params` are not affected by this PR.