-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow specification of initial state for sample
#119
Changes from all commits
34eecab
fbd24ca
e9bee36
207c77d
b67e51d
e63c46e
5083d94
18ad4a5
a1c1881
ca4f4b9
e442880
fb4a1f6
c6ec64e
74465b4
56456c9
5d83ab4
3ed5314
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -103,6 +103,7 @@ | |
discard_initial=0, | ||
thinning=1, | ||
chain_type::Type=Any, | ||
initial_state=nothing, | ||
kwargs..., | ||
) | ||
# Check the number of requested samples. | ||
|
@@ -122,7 +123,11 @@ | |
end | ||
|
||
# Obtain the initial sample and state. | ||
sample, state = step(rng, model, sampler; kwargs...) | ||
sample, state = if initial_state === nothing | ||
step(rng, model, sampler; kwargs...) | ||
else | ||
step(rng, model, sampler, initial_state; kwargs...) | ||
end | ||
|
||
# Discard initial samples. | ||
for i in 1:discard_initial | ||
|
@@ -211,6 +216,7 @@ | |
callback=nothing, | ||
discard_initial=0, | ||
thinning=1, | ||
initial_state=nothing, | ||
kwargs..., | ||
) | ||
|
||
|
@@ -220,7 +226,11 @@ | |
|
||
@ifwithprogresslogger progress name = progressname begin | ||
# Obtain the initial sample and state. | ||
sample, state = step(rng, model, sampler; kwargs...) | ||
sample, state = if initial_state === nothing | ||
step(rng, model, sampler; kwargs...) | ||
else | ||
step(rng, model, sampler, state; kwargs...) | ||
end | ||
|
||
# Discard initial samples. | ||
for _ in 1:discard_initial | ||
|
@@ -288,7 +298,8 @@ | |
nchains::Integer; | ||
progress=PROGRESS[], | ||
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", | ||
init_params=nothing, | ||
initial_params=nothing, | ||
initial_state=nothing, | ||
kwargs..., | ||
) | ||
# Check if actually multiple threads are used. | ||
|
@@ -312,8 +323,9 @@ | |
# Create a seed for each chain using the provided random number generator. | ||
seeds = rand(rng, UInt, nchains) | ||
|
||
# Ensure that initial parameters are `nothing` or of the correct length | ||
check_initial_params(init_params, nchains) | ||
# Ensure that initial parameters and states are `nothing` or of the correct length | ||
check_initial_params(initial_params, nchains) | ||
check_initial_state(initial_state, nchains) | ||
|
||
# Set up a chains vector. | ||
chains = Vector{Any}(undef, nchains) | ||
|
@@ -364,10 +376,15 @@ | |
_sampler, | ||
N; | ||
progress=false, | ||
init_params=if init_params === nothing | ||
initial_params=if initial_params === nothing | ||
nothing | ||
else | ||
initial_params[chainidx] | ||
end, | ||
initial_state=if initial_state === nothing | ||
nothing | ||
else | ||
init_params[chainidx] | ||
initial_state[chainidx] | ||
end, | ||
kwargs..., | ||
) | ||
|
@@ -397,7 +414,8 @@ | |
nchains::Integer; | ||
progress=PROGRESS[], | ||
progressname="Sampling ($(Distributed.nworkers()) processes)", | ||
init_params=nothing, | ||
initial_params=nothing, | ||
initial_state=nothing, | ||
kwargs..., | ||
) | ||
# Check if actually multiple processes are used. | ||
|
@@ -410,8 +428,14 @@ | |
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" | ||
end | ||
|
||
# Ensure that initial parameters are `nothing` or of the correct length | ||
check_initial_params(init_params, nchains) | ||
# Ensure that initial parameters and states are `nothing` or of the correct length | ||
check_initial_params(initial_params, nchains) | ||
check_initial_state(initial_state, nchains) | ||
|
||
_initial_params = | ||
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params | ||
_initial_state = | ||
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state | ||
Comment on lines
+435
to
+438
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be nice to avoid these
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is adding in callable structs here really an improvement? 😕 I agree it's more efficient, but it seems like this will be quite a bit more complex + the efficiency doesn't really matter here, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A callable struct should generally be better for the compiler than a closure, shouldn't it? Regardless of whether we change or add arguments as in this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep! But is this performance critical code? And it seems to be me that we'll need a callable struct for each scenario? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🤷 Are you sure? To me it seems one struct is sufficient - both the multithreaded and the multicore version seem to use the same inner structure, and in the serial case we could set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
To clarify, I don't question whether we can have a single callable struct with different call implementations; I meant more that it seems it won't be as simple as just doing struct SampleFunc
# ...
end
function (f::SampleFunc)(args...)
# ...
end
But if we put Or am I misunderstanding what you mean here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just bumping this discussion:) Would be nice to get a version of this PR merged. |
||
|
||
# Create a seed for each chain using the provided random number generator. | ||
seeds = rand(rng, UInt, nchains) | ||
|
@@ -448,7 +472,7 @@ | |
|
||
Distributed.@async begin | ||
try | ||
function sample_chain(seed, init_params=nothing) | ||
function sample_chain(seed, initial_params, initial_state) | ||
# Seed a new random number generator with the pre-made seed. | ||
Random.seed!(rng, seed) | ||
|
||
|
@@ -459,7 +483,8 @@ | |
sampler, | ||
N; | ||
progress=false, | ||
init_params=init_params, | ||
initial_params=initial_params, | ||
initial_state=initial_state, | ||
kwargs..., | ||
) | ||
|
||
|
@@ -469,11 +494,9 @@ | |
# Return the new chain. | ||
return chain | ||
end | ||
chains = if init_params === nothing | ||
Distributed.pmap(sample_chain, pool, seeds) | ||
else | ||
Distributed.pmap(sample_chain, pool, seeds, init_params) | ||
end | ||
chains = Distributed.pmap( | ||
sample_chain, pool, seeds, _initial_params, _initial_state | ||
) | ||
finally | ||
# Stop updating the progress bar. | ||
progress && put!(channel, false) | ||
|
@@ -494,22 +517,29 @@ | |
N::Integer, | ||
nchains::Integer; | ||
progressname="Sampling", | ||
init_params=nothing, | ||
initial_params=nothing, | ||
initial_state=nothing, | ||
kwargs..., | ||
) | ||
# Check if the number of chains is larger than the number of samples | ||
if nchains > N | ||
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" | ||
end | ||
|
||
# Ensure that initial parameters are `nothing` or of the correct length | ||
check_initial_params(init_params, nchains) | ||
# Ensure that initial parameters and states are `nothing` or of the correct length | ||
check_initial_params(initial_params, nchains) | ||
check_initial_state(initial_state, nchains) | ||
|
||
_initial_params = | ||
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params | ||
_initial_state = | ||
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state | ||
|
||
# Create a seed for each chain using the provided random number generator. | ||
seeds = rand(rng, UInt, nchains) | ||
|
||
# Sample the chains. | ||
function sample_chain(i, seed, init_params=nothing) | ||
function sample_chain(i, seed, initial_params, initial_state) | ||
# Seed a new random number generator with the pre-made seed. | ||
Random.seed!(rng, seed) | ||
|
||
|
@@ -520,16 +550,13 @@ | |
sampler, | ||
N; | ||
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), | ||
init_params=init_params, | ||
initial_params=initial_params, | ||
initial_state=initial_state, | ||
kwargs..., | ||
) | ||
end | ||
|
||
chains = if init_params === nothing | ||
map(sample_chain, 1:nchains, seeds) | ||
else | ||
map(sample_chain, 1:nchains, seeds, init_params) | ||
end | ||
chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state) | ||
|
||
# Concatenate the chains together. | ||
return chainsstack(tighten_eltype(chains)) | ||
|
@@ -543,7 +570,6 @@ | |
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", | ||
), | ||
) | ||
|
||
check_initial_params(::Nothing, n) = nothing | ||
function check_initial_params(x::AbstractArray, n) | ||
if length(x) != n | ||
|
@@ -556,3 +582,21 @@ | |
|
||
return nothing | ||
end | ||
|
||
@nospecialize check_initial_state(x, n) = throw( | ||
ArgumentError( | ||
"initial states must be specified as a vector of length equal to the number of chains or `nothing`", | ||
), | ||
) | ||
check_initial_state(::Nothing, n) = nothing | ||
function check_initial_state(x::AbstractArray, n) | ||
if length(x) != n | ||
throw( | ||
ArgumentError( | ||
"incorrect number of initial states (expected $n, received $(length(x))" | ||
), | ||
) | ||
end | ||
|
||
return nothing | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This means it could be removed from the [extras] section below. But do we actually have to depend on FillArrays? I think we should just forward the user-input or
nothing
, but not build any arrays explicitly?