Skip to content

Commit

Permalink
fix PMMH
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Mar 3, 2019
1 parent 7c911fd commit cedccdc
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 35 deletions.
86 changes: 54 additions & 32 deletions src/inference/pmmh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ mutable struct PMMH{space, A<:Tuple} <: InferenceAlgorithm
algs :: A # Proposals for state & parameters
gid :: Int # group ID
end
PMMH{space}(n, algs, gid) where space = PMMH{space, typeof(algs)}(n, algs, gid)
function PMMH(n_iters::Int, smc_alg::SMC, parameter_algs...)
algs = tuple(parameter_algs..., smc_alg)
return PMMH{buildspace(algs)}(n_iters, algs, 0)
Expand All @@ -38,22 +39,26 @@ function PIMH(n_iters::Int, smc_alg::SMC)
end

@inline function get_pmmh_samplers(subalgs, model, n, alg, alg_str)
if length(subalgs) == 0
return ()
else
subalg = subalgs[1]
if typeof(subalg) == MH && subalg.n_iters != 1
warn("[$alg_str] number of iterations greater than 1 is useless for MH since it is only used for its proposal")
end
if isa(subalg, Union{SMC, MH})
return (Sampler(typeof(subalg)(subalg, n + 1 - length(subalgs)), model), get_pmmh_samplers(Base.tail(subalgs), model, n, alg, alg_str)...)
else
error("[$alg_str] unsupport base sampling algorithm $alg")
end
end
end
if length(subalgs) == 0
return (), ()
else
subalg = subalgs[1]
if typeof(subalg) == MH && subalg.n_iters != 1
warn("[$alg_str] number of iterations greater than 1 is useless for MH since it is only used for its proposal")
end
if isa(subalg, Union{SMC, MH})
spl, vi = init_spl(model, typeof(subalg)(subalg, n + 1 - length(subalgs)))
spls_vis = get_pmmh_samplers(Base.tail(subalgs), model, n, alg, alg_str)
spls = (spl, spls_vis[1]...)
vis = (vi, spls_vis[2]...)
return spls, vis
else
error("[$alg_str] unsupport base sampling algorithm $alg")
end
end
end

struct PMMHInfo{Tsamplers}
mutable struct PMMHInfo{Tsamplers}
samplers::Tsamplers
violating_support::Bool
prior_prob::Float64
Expand All @@ -62,41 +67,58 @@ struct PMMHInfo{Tsamplers}
old_prior_prob::Float64
progress::ProgressMeter.Progress
end
function PMMHInfo(samplers, n = 0)
function PMMHInfo(samplers, alg::PMMH)
n = alg.n_iters
return PMMHInfo(samplers, false, 0.0, 0.0, -Inf, 0.0, ProgressMeter.Progress(n, 1, "[PMMH] Sampling...", 0))
end

function Sampler(alg::PMMH, model::Model)
alg_str = "PMMH"
n_samplers = length(alg.algs)
samplers = get_pmmh_samplers(alg.algs, model, n_samplers, alg, alg_str)
samplers, vis = get_pmmh_samplers(alg.algs, model, n_samplers, alg, alg_str)
verifyspace(alg.algs, model.pvars, alg_str)
info = PMMHInfo(samplers)
info = PMMHInfo(samplers, alg)
return Sampler(alg, info)
end

function step(model, spl::Sampler{<:PMMH}, vi::AbstractVarInfo, is_first::Bool)
function init_spl(model, alg::PMMH; kwargs...)
spl = Sampler(alg, model)
vi = VarInfo(model)
return spl, vi
end

@inline function _step(samplers::Tuple, model, vi, violating_support, new_prior_prob, proposal_ratio)
if length(samplers) == 1
return violating_support, new_prior_prob, proposal_ratio
end
local_spl = samplers[1]
propose(model, local_spl, vi)
Turing.DEBUG && @debug "$(typeof(local_spl)) proposing $(getspace(local_spl))..."
if local_spl.info.violating_support
violating_support = true
return violating_support, new_prior_prob, proposal_ratio
end
new_prior_prob += local_spl.info.prior_prob
proposal_ratio += local_spl.info.proposal_ratio

return _step(Base.tail(samplers), model, vi, violating_support, new_prior_prob, proposal_ratio)
end

function step(model, spl::Sampler{<:PMMH}, vi::AbstractVarInfo)
violating_support = false
proposal_ratio = 0.0
new_prior_prob = 0.0
new_likelihood_estimate = 0.0
old_θ = copy(vi[spl])

Turing.DEBUG && @debug "Propose new parameters from proposals..."
for local_spl in spl.info.samplers[1:end-1]
Turing.DEBUG && @debug "$(typeof(local_spl)) proposing $(getspace(local_spl))..."
propose(model, local_spl, vi)
if local_spl.info.violating_support
violating_support = true
break
end
new_prior_prob += local_spl.info.prior_prob
proposal_ratio += local_spl.info.proposal_ratio
end

violating_support, new_prior_prob, proposal_ratio =
_step(spl.info.samplers, model, vi, violating_support, new_prior_prob, proposal_ratio)

if !violating_support # do not run SMC if going to refuse anyway
Turing.DEBUG && @debug "Propose new state with SMC..."
vi = step(model, spl.info.samplers[end], vi)
vi, _ = step(model, spl.info.samplers[end], vi)
new_likelihood_estimate = spl.info.samplers[end].info.logevidence[end]

Turing.DEBUG && @debug "computing accept rate α..."
Expand Down Expand Up @@ -137,7 +159,7 @@ function _sample(vi, samples, spl, model, alg::PMMH;
PROGRESS[] && (spl.info.progress = ProgressMeter.Progress(n, 1, "[$alg_str] Sampling...", 0))
for i = 1:n
Turing.DEBUG && @debug "$alg_str stepping..."
time_elapsed = @elapsed vi, is_accept = step(model, spl, vi, i==1)
time_elapsed = @elapsed vi, is_accept = step(model, spl, vi)

if is_accept # accepted => store the new predcits
samples[i].value = Sample(vi, spl).value
Expand All @@ -148,7 +170,7 @@ function _sample(vi, samples, spl, model, alg::PMMH;
time_total += time_elapsed
push!(accept_his, is_accept)
if PROGRESS[]
haskey(spl.info, :progress) && ProgressMeter.update!(spl.info.progress, spl.info.progress.counter + 1)
isdefined(spl.info, :progress) && ProgressMeter.update!(spl.info.progress, spl.info.progress.counter + 1)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/inference/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ function Sample(vi::AbstractVarInfo, spl::Sampler)
return s
end

using InteractiveUtils
#using InteractiveUtils

function init_samples(alg, vi::AbstractArray{<:AbstractVarInfo}; kwargs...)
return init_samples(alg, first(vi); kwargs...)
Expand Down
10 changes: 8 additions & 2 deletions src/inference/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ function SMC(n_particles::Int, space...)
F = typeof(resample_systematic)
return SMC{space, F}(n_particles, resample_systematic, 0.5, 0)
end
SMC{Ts, Tf}(alg::SMC{Ts, Tf}, new_gid::Int) where {Ts, Tf} = SMC(alg, new_gid)
function SMC(alg::SMC{space, F}, new_gid::Int) where {space, F}
return SMC{space, F}(alg.n_particles, alg.resampler, alg.resampler_threshold, new_gid)
end
Expand All @@ -53,6 +54,12 @@ function Sampler(alg::SMC)
Sampler(alg, info)
end

function init_spl(model, alg::SMC)
vi = VarInfo(model)
spl = Sampler(alg)
return spl, vi
end

function step(model, spl::Sampler{<:SMC}, vi::AbstractVarInfo)
particles = ParticleContainer{Trace{typeof(vi)}}(model)
vi.num_produce = 0; # Reset num_produce before new sweep\.
Expand Down Expand Up @@ -80,8 +87,7 @@ VarInfo(model::Model) = TypedVarInfo(default_varinfo(model))

## wrapper for smc: run the sampler, collect results.
function sample(model::Model, alg::SMC)
spl = Sampler(alg)
vi = VarInfo(model)
spl, vi = init_spl(model, alg)
particles = ParticleContainer{Trace{typeof(vi)}}(model)
push!(particles, spl.alg.n_particles, spl, vi)

Expand Down

0 comments on commit cedccdc

Please sign in to comment.