Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
torfjelde and github-actions[bot] authored Oct 2, 2023
1 parent 5d83ab4 commit 3ed5314
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 18 deletions.
12 changes: 8 additions & 4 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,10 @@ function mcmcsample(
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
Expand Down Expand Up @@ -528,8 +530,10 @@ function mcmcsample(
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params = initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state = initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state
_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
Expand Down
49 changes: 35 additions & 14 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,15 @@
@test all(length(x) == N for x in chains)

Random.seed!(1234)
chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false)
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
N,
1000;
chain_type=MyChain,
progress=false,
)

# Test output type and size.
@test chains isa Vector{<:MyChain}
Expand All @@ -382,7 +390,15 @@

# Test reproducibility.
Random.seed!(1234)
chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain, progress=false)
chains2 = sample(
MyModel(),
MySampler(),
MCMCSerial(),
N,
1000;
chain_type=MyChain,
progress=false,
)
@test all(ismissing(c.as[1]) for c in chains2)
@test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N)
@test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
Expand Down Expand Up @@ -660,8 +676,10 @@
end

@testset "Providing initial state" begin
function record_state(rng, model, sampler, sample, state, i; states_channel, kwargs...)
put!(states_channel, state)
function record_state(
rng, model, sampler, sample, state, i; states_channel, kwargs...
)
return put!(states_channel, state)
end

initial_state = 10
Expand All @@ -670,10 +688,12 @@
n = 10
states_channel = Channel{Int}(n)
chain = sample(
MyModel(), MySampler(), n;
MyModel(),
MySampler(),
n;
initial_state=initial_state,
callback=record_state,
states_channel=states_channel
states_channel=states_channel,
)

# Extract the states.
Expand All @@ -684,11 +704,8 @@
end
end

@testset "sample with $mode" for mode in [
MCMCSerial(),
MCMCThreads(),
MCMCDistributed(),
]
@testset "sample with $mode" for mode in
[MCMCSerial(), MCMCThreads(), MCMCDistributed()]
nchains = 4
initial_state = 10
states_channel = if mode === MCMCDistributed()
Expand All @@ -698,16 +715,20 @@
Channel{Int}(nchains)
end
chain = sample(
MyModel(), MySampler(), mode, 1, nchains;
MyModel(),
MySampler(),
mode,
1,
nchains;
initial_state=FillArrays.Fill(initial_state, nchains),
callback=record_state,
states_channel=states_channel
states_channel=states_channel,
)

# Extract the states.
states = [take!(states_channel) for _ in 1:nchains]
@test length(states) == nchains
for i = 1:nchains
for i in 1:nchains
@test states[i] == initial_state + 1
end
end
Expand Down

0 comments on commit 3ed5314

Please sign in to comment.