diff --git a/src/Turing.jl b/src/Turing.jl index c172b864af..d6cea7b684 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -98,6 +98,10 @@ Robust initialization method for model parameters in Hamiltonian samplers. struct HamiltonianRobustInit <: AbstractSampler end struct SampleFromPrior <: AbstractSampler end +# This can be removed when all `spl=nothing` is replaced with +# `spl=SampleFromPrior` +const AnySampler = Union{Nothing, AbstractSampler} + include("utilities/helper.jl") include("utilities/transform.jl") include("core/varinfo.jl") # core internal variable container diff --git a/src/core/compiler.jl b/src/core/compiler.jl index 2f7211542c..20036cdf44 100644 --- a/src/core/compiler.jl +++ b/src/core/compiler.jl @@ -252,7 +252,7 @@ macro model(fexpr) :kwargs => [], :args => [ :(vi::Turing.VarInfo), - :(sampler::Turing.AbstractSampler) + :(sampler::Turing.AnySampler) ], :body => body ) @@ -272,7 +272,7 @@ macro model(fexpr) alias2 = MacroTools.combinedef( Dict( :name => compiler[:closure_name], - :args => [:(sampler::Turing.AbstractSampler)], + :args => [:(sampler::Turing.AnySampler)], :kwargs => [], :body => :(return $(compiler[:closure_name])(Turing.VarInfo(), Turing.SampleFromPrior())) diff --git a/src/samplers/sampler.jl b/src/samplers/sampler.jl index 0d0801c1d8..51f398f255 100644 --- a/src/samplers/sampler.jl +++ b/src/samplers/sampler.jl @@ -34,10 +34,13 @@ observe(spl::Sampler, weight::Float64) = error("Turing.observe: unmanaged inference algorithm: $(typeof(spl))") ## Default definitions for assume, observe, when sampler = nothing. +## Note: `A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}` can be +## simplified into `A<:Union{SampleFromPrior, HamiltonianRobustInit}` when +## all `spl=nothing` is replaced with `spl=SampleFromPrior`. function assume(spl::A, dist::Distribution, vn::VarName, - vi::VarInfo) where {A<:Union{HamiltonianRobustInit, SampleFromPrior, Nothing}} + vi::VarInfo) where {A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}} if haskey(vi, vn) r = vi[vn] @@ -57,7 +60,7 @@ function assume(spl::A, dists::Vector{T}, vn::VarName, var::Any, - vi::VarInfo) where {T<:Distribution, A<:Union{HamiltonianRobustInit, SampleFromPrior, Nothing}} + vi::VarInfo) where {T<:Distribution, A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}} @assert length(dists) == 1 "Turing.assume only support vectorizing i.i.d distribution" dist = dists[1] @@ -103,7 +106,7 @@ end function observe(spl::A, dist::Distribution, value::Any, - vi::VarInfo) where {A<:Union{HamiltonianRobustInit, SampleFromPrior, Nothing}} + vi::VarInfo) where {A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}} vi.num_produce += 1 @debug "dist = $dist" @@ -117,7 +120,7 @@ end function observe(spl::A, dists::Vector{T}, value::Any, - vi::VarInfo) where {T<:Distribution, A<:Union{HamiltonianRobustInit, SampleFromPrior, Nothing}} + vi::VarInfo) where {T<:Distribution, A<:Union{Nothing, SampleFromPrior, HamiltonianRobustInit}} @assert length(dists) == 1 "Turing.observe only support vectorizing i.i.d distribution" dist = dists[1] diff --git a/test/compiler.jl/model_macro.jl b/test/compiler.jl/model_macro.jl index c39789f807..a44d2ac887 100644 --- a/test/compiler.jl/model_macro.jl +++ b/test/compiler.jl/model_macro.jl @@ -27,15 +27,15 @@ alias1 = Dict( :name => :testmodel_comp_model, :args => [:(vi::Turing.VarInfo)], :kwargs => [], - :body => :(return testmodel_comp_model(vi, nothing)) + :body => :(return testmodel_comp_model(vi, Turing.SampleFromPrior())) ) @test c[:alias1] == MacroTools.combinedef(alias1) alias2 = Dict( :name => :testmodel_comp_model, - :args => [:(sampler::Turing.AbstractSampler)], + :args => [:(sampler::Turing.AnySampler)], :kwargs => [], - :body => :(return testmodel_comp_model(Turing.VarInfo(), nothing)) + :body => :(return testmodel_comp_model(Turing.VarInfo(), Turing.SampleFromPrior())) ) @test c[:alias2] == MacroTools.combinedef(alias2) @@ -43,7 +43,7 @@ alias3 = Dict( :name => :testmodel_comp_model, :args => [], :kwargs => [], - :body => :(return testmodel_comp_model(Turing.VarInfo(), nothing)) + :body => :(return testmodel_comp_model(Turing.VarInfo(), Turing.SampleFromPrior())) ) @test c[:alias3] == MacroTools.combinedef(alias3) @test length(c[:closure].args[2].args[2].args) == 6