Skip to content

Commit

Permalink
Add a special handler for existing spl=nothing uses.
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai committed Sep 13, 2018
1 parent 20f6ecb commit 2881f04
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
4 changes: 4 additions & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ Robust initialization method for model parameters in Hamiltonian samplers.
struct HamiltonianRobustInit <: AbstractSampler end
struct SampleFromPrior <: AbstractSampler end

# This can be removed when all `spl=nothing` is replaced with
# `spl=SampleFromPrior`
const AnySampler = Union{Nothing, AbstractSampler}

include("utilities/helper.jl")
include("utilities/transform.jl")
include("core/varinfo.jl") # core internal variable container
Expand Down
4 changes: 2 additions & 2 deletions src/core/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ macro model(fexpr)
:kwargs => [],
:args => [
:(vi::Turing.VarInfo),
:(sampler::Turing.AbstractSampler)
:(sampler::Turing.AnySampler)
],
:body => body
)
Expand All @@ -272,7 +272,7 @@ macro model(fexpr)
alias2 = MacroTools.combinedef(
Dict(
:name => compiler[:closure_name],
:args => [:(sampler::Turing.AbstractSampler)],
:args => [:(sampler::Turing.AnySampler)],
:kwargs => [],
:body => :(return $(compiler[:closure_name])(Turing.VarInfo(), Turing.SampleFromPrior()))

Expand Down
11 changes: 7 additions & 4 deletions src/samplers/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ observe(spl::Sampler, weight::Float64) =
error("Turing.observe: unmanaged inference algorithm: $(typeof(spl))")

## Default definitions for assume, observe, when sampler = nothing.
## Note: `A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}` can be
## simplified into `A<:Union{SampleFromPrior, HamiltonianRobustInit}` when
## all `spl=nothing` is replaced with `spl=SampleFromPrior`.
function assume(spl::A,
dist::Distribution,
vn::VarName,
vi::VarInfo) where {A<:Union{HamiltonianRobustInit, SampleFromPrior, Nothing}}
vi::VarInfo) where {A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}}

if haskey(vi, vn)
r = vi[vn]
Expand All @@ -57,7 +60,7 @@ function assume(spl::A,
dists::Vector{T},
vn::VarName,
var::Any,
vi::VarInfo) where {T<:Distribution, A<:Union{HamiltonianRobustInit, SampleFromPrior, Nothing}}
vi::VarInfo) where {T<:Distribution, A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}}

@assert length(dists) == 1 "Turing.assume only support vectorizing i.i.d distribution"
dist = dists[1]
Expand Down Expand Up @@ -103,7 +106,7 @@ end
function observe(spl::A,
dist::Distribution,
value::Any,
vi::VarInfo) where {A<:Union{HamiltonianRobustInit, SampleFromPrior, Nothing}}
vi::VarInfo) where {A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}}

vi.num_produce += 1
@debug "dist = $dist"
Expand All @@ -117,7 +120,7 @@ end
function observe(spl::A,
dists::Vector{T},
value::Any,
vi::VarInfo) where {T<:Distribution, A<:Union{HamiltonianRobustInit, SampleFromPrior, Nothing}}
vi::VarInfo) where {T<:Distribution, A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}}

@assert length(dists) == 1 "Turing.observe only support vectorizing i.i.d distribution"
dist = dists[1]
Expand Down
8 changes: 4 additions & 4 deletions test/compiler.jl/model_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,23 @@ alias1 = Dict(
:name => :testmodel_comp_model,
:args => [:(vi::Turing.VarInfo)],
:kwargs => [],
:body => :(return testmodel_comp_model(vi, nothing))
:body => :(return testmodel_comp_model(vi, Turing.SampleFromPrior()))
)
@test c[:alias1] == MacroTools.combinedef(alias1)

alias2 = Dict(
:name => :testmodel_comp_model,
:args => [:(sampler::Turing.AbstractSampler)],
:args => [:(sampler::Turing.AnySampler)],
:kwargs => [],
:body => :(return testmodel_comp_model(Turing.VarInfo(), nothing))
:body => :(return testmodel_comp_model(Turing.VarInfo(), Turing.SampleFromPrior()))
)
@test c[:alias2] == MacroTools.combinedef(alias2)

alias3 = Dict(
:name => :testmodel_comp_model,
:args => [],
:kwargs => [],
:body => :(return testmodel_comp_model(Turing.VarInfo(), nothing))
:body => :(return testmodel_comp_model(Turing.VarInfo(), Turing.SampleFromPrior()))
)
@test c[:alias3] == MacroTools.combinedef(alias3)
@test length(c[:closure].args[2].args[2].args) == 6
Expand Down

0 comments on commit 2881f04

Please sign in to comment.