diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 0d934a09a..df4628f6f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -116,9 +116,6 @@ export AbstractVarInfo, # Pseudo distributions NamedDist, NoDist, - # Prob macros - @prob_str, - @logprob_str, # Convenience functions logprior, logjoint, @@ -172,7 +169,6 @@ include("varinfo.jl") include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") -include("prob_macro.jl") include("loglikelihoods.jl") include("submodel_macro.jl") include("test_utils.jl") diff --git a/src/prob_macro.jl b/src/prob_macro.jl deleted file mode 100644 index 5e8194c3c..000000000 --- a/src/prob_macro.jl +++ /dev/null @@ -1,244 +0,0 @@ -macro logprob_str(str) - expr1, expr2 = get_exprs(str) - return :(logprob($(esc(expr1)), $(esc(expr2)))) -end -macro prob_str(str) - expr1, expr2 = get_exprs(str) - return :(exp.(logprob($(esc(expr1)), $(esc(expr2))))) -end - -function get_exprs(str::String) - substrings = split(str, '|'; limit=2) - length(substrings) == 2 || error("Invalid expression.") - str1, str2 = substrings - - expr1 = Meta.parse("($str1,)") - expr1 = Expr(:tuple, expr1.args...) - - expr2 = Meta.parse("($str2,)") - expr2 = Expr(:tuple, expr2.args...) - - return expr1, expr2 -end - -function logprob(ex1, ex2) - ptype, model, vi = probtype(ex1, ex2) - if ptype isa Val{:prior} - return logprior(ex1, ex2, model, vi) - elseif ptype isa Val{:likelihood} - return loglikelihood(ex1, ex2, model, vi) - end -end - -function probtype(ntl::NamedTuple{namesl}, ntr::NamedTuple{namesr}) where {namesl,namesr} - if :chain in namesr - if isdefined(ntr.chain.info, :model) - model = ntr.chain.info.model - elseif isdefined(ntr, :model) - model = ntr.model - else - throw( - "The model is not defined. Please make sure the model is either saved in the chain or passed on the RHS of |.", - ) - end - @assert model isa Model - if isdefined(ntr.chain.info, :vi) - _vi = ntr.chain.info.vi - @assert _vi isa VarInfo - vi = TypedVarInfo(_vi) - elseif isdefined(ntr, :varinfo) - _vi = ntr.varinfo - @assert _vi isa VarInfo - vi = TypedVarInfo(_vi) - else - vi = nothing - end - defaults = model.defaults - @assert all(getargnames(model)) do arg - isdefined(ntl, arg) || - isdefined(ntr, arg) || - isdefined(defaults, arg) && getfield(defaults, arg) !== missing - end - return Val(:likelihood), model, vi - else - @assert isdefined(ntr, :model) - model = ntr.model - @assert model isa Model - if isdefined(ntr, :varinfo) - _vi = ntr.varinfo - @assert _vi isa VarInfo - vi = TypedVarInfo(_vi) - else - vi = nothing - end - return probtype(ntl, ntr, model), model, vi - end -end - -function probtype( - left::NamedTuple{leftnames}, - right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} - defaults = model.defaults - prior_rhs = all( - n -> n in (:model, :varinfo) || n in argnames && getfield(right, n) !== missing, - rightnames, - ) - function get_arg(arg) - if arg in leftnames - return getfield(left, arg) - elseif arg in rightnames - return getfield(right, arg) - elseif arg in defaultnames - return getfield(defaults, arg) - elseif arg in argnames - return getfield(model.args, arg) - else - return nothing - end - end - function valid_arg(arg) - a = get_arg(arg) - return a !== nothing && a !== missing - end - valid_args = all(valid_arg, argnames) - - # Uses the default values for model arguments not provided. - # If no default value exists, use `nothing`. - if prior_rhs - return Val(:prior) - # Uses the default values for model arguments not provided. - # If no default value exists or the default value is missing, then error. - elseif valid_args - return Val(:likelihood) - else - for argname in argnames - if !valid_arg(argname) - throw(ArgumentError(missing_arg_error_msg(argname, get_arg(argname)))) - end - end - end -end - -function missing_arg_error_msg(arg, ::Missing) - return """Variable $arg has a value of `missing`, or is not defined and its default value is `missing`. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`.""" -end -function missing_arg_error_msg(arg, ::Nothing) - return """Variable $arg is not defined and has no default value. Please make sure all the variables are either defined with a value other than `missing` or have a default value other than `missing`.""" -end - -function logprior( - left::NamedTuple, right::NamedTuple, _model::Model, _vi::Union{Nothing,VarInfo} -) - # For model args on the LHS of |, use their passed value but add the symbol to - # model.missings. This will lead to an `assume`/`dot_assume` call for those variables. - # Let `p::PriorContext`. If `p.vars` is `nothing`, `assume` and `dot_assume` will use - # the values of the random variables in the `VarInfo`. If `p.vars` is a `NamedTuple` - # or a `Chain`, the value in `p.vars` is input into the `VarInfo` and used instead. - - # For model args not on the LHS of |, if they have a default value, use that, - # otherwise use `nothing`. This will lead to an `observe`/`dot_observe`call for - # those variables. - # All `observe` and `dot_observe` calls are no-op in the PriorContext - - # When all of model args are on the lhs of |, this is also equal to the logjoint. - model = make_prior_model(left, right, _model) - vi = _vi === nothing ? VarInfo(deepcopy(model), PriorContext()) : _vi - foreach(keys(vi.metadata)) do n - @assert n in keys(left) "Variable $n is not defined." - end - return getlogp(last(evaluate!!(model, vi, SampleFromPrior(), PriorContext(left)))) -end - -@generated function make_prior_model( - left::NamedTuple{leftnames}, - right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} - argvals = [] - missings = [] - warnings = [] - - for argname in argnames - if argname in leftnames - push!(argvals, :(deepcopy(left.$argname))) - push!(missings, argname) - elseif argname in rightnames - push!(argvals, :(right.$argname)) - elseif argname in defaultnames - push!(argvals, :(model.defaults.$argname)) - else - push!(warnings, :(@warn($(warn_msg(argname))))) - push!(argvals, :(model.args.$argname)) - end - end - - # `args` is inserted as properly typed NamedTuple expression; - # `missings` is splatted into a tuple at compile time and inserted as literal - return quote - $(warnings...) - Model{$(Tuple(missings))}( - model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults - ) - end -end - -warn_msg(arg) = "Argument $arg is not defined. Using the value from the model." - -function Distributions.loglikelihood( - left::NamedTuple, right::NamedTuple, _model::Model, _vi::Union{Nothing,VarInfo} -) - model = make_likelihood_model(left, right, _model) - vi = _vi === nothing ? VarInfo(deepcopy(model)) : _vi - if isdefined(right, :chain) - # Element-wise likelihood for each value in chain - chain = right.chain - ctx = LikelihoodContext(right) - iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - logps = map(iters) do (sample_idx, chain_idx) - setval!(vi, chain, sample_idx, chain_idx) - model(vi, SampleFromPrior(), ctx) - return getlogp(vi) - end - return reshape(logps, size(chain, 1), size(chain, 3)) - else - # Likelihood without chain - # Rhs values are used in the context - ctx = LikelihoodContext(right) - model(vi, SampleFromPrior(), ctx) - return getlogp(vi) - end -end - -@generated function make_likelihood_model( - left::NamedTuple{leftnames}, - right::NamedTuple{rightnames}, - model::Model{_F,argnames,defaultnames}, -) where {leftnames,rightnames,argnames,defaultnames,_F} - argvals = [] - missings = [] - - for argname in argnames - if argname in leftnames - push!(argvals, :(left.$argname)) - elseif argname in rightnames - push!(argvals, :(right.$argname)) - push!(missings, argname) - elseif argname in defaultnames - push!(argvals, :(model.defaults.$argname)) - elseif argname in argnames - push!(argvals, :(model.args.$argname)) - else - throw( - "This point should not be reached. Please open an issue in the DynamicPPL.jl repository.", - ) - end - end - - # `args` is inserted as properly typed NamedTuple expression; - # `missings` is splatted into a tuple at compile time and inserted as literal - return :(Model{$(Tuple(missings))}( - model.f, $(to_namedtuple_expr(argnames, argvals)), model.defaults - )) -end diff --git a/test/prob_macro.jl b/test/prob_macro.jl deleted file mode 100644 index 254141293..000000000 --- a/test/prob_macro.jl +++ /dev/null @@ -1,72 +0,0 @@ -@testset "prob_macro.jl" begin - @testset "scalar" begin - @model function demo(x) - m ~ Normal() - return x ~ Normal(m, 1) - end - - mval = 3 - xval = 2 - iters = 1000 - - logprior = logpdf(Normal(), mval) - loglike = logpdf(Normal(mval, 1), xval) - logjoint = logprior + loglike - - model = demo(xval) - @test logprob"m = mval | model = model" == logprior - @test logprob"m = mval | x = xval, model = model" == logprior - @test logprob"x = xval | m = mval, model = model" == loglike - @test logprob"x = xval, m = mval | model = model" == logjoint - - varinfo = VarInfo(demo(missing)) - @test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint - - varinfo = VarInfo(demo(xval)) - @test logprob"m = mval | model = model, varinfo = varinfo" == logprior - @test logprob"m = mval | x = xval, model = model, varinfo = varinfo" == logprior - @test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike - end - @testset "vector" begin - n = 5 - @model function demo(x, n) - m ~ MvNormal(zeros(n), I) - return x ~ MvNormal(m, I) - end - mval = rand(n) - xval = rand(n) - iters = 1000 - - logprior = logpdf(MvNormal(zeros(n), I), mval) - loglike = logpdf(MvNormal(mval, I), xval) - logjoint = logprior + loglike - - model = demo(xval, n) - @test logprob"m = mval | model = model" == logprior - @test logprob"x = xval | m = mval, model = model" == loglike - @test logprob"x = xval, m = mval | model = model" == logjoint - - varinfo = VarInfo(demo(xval, n)) - @test logprob"m = mval | model = model, varinfo = varinfo" == logprior - @test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike - # Currently, we cannot easily pre-allocate `VarInfo` for vector data - end - @testset "issue190" begin - @model function gdemo(x, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - x ~ filldist(Normal(m, sqrt(s)), length(y)) - for i in 1:length(y) - y[i] ~ Normal(x[i], sqrt(s)) - end - end - c = Chains(rand(10, 2), [:m, :s]) - model_gdemo = gdemo([1.0, 0.0], [1.5, 0.0]) - r1 = prob"y = [1.5] | chain=c, model = model_gdemo, x = [1.0]" - r2 = map(c[:s]) do s - # exp(logpdf(..)) not pdf because this is exactly what the prob"" macro does, so we test r1 == r2 - exp(logpdf(Normal(1, sqrt(s)), 1.5)) - end - @test r1 == r2 - end -end diff --git a/test/runtests.jl b/test/runtests.jl index f18167d08..6bc12f294 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,7 +42,6 @@ include("test_util.jl") include("simple_varinfo.jl") include("model.jl") include("sampler.jl") - include("prob_macro.jl") include("independence.jl") include("distribution_wrappers.jl") include("contexts.jl") diff --git a/test/turing/prob_macro.jl b/test/turing/prob_macro.jl deleted file mode 100644 index a86bfed2a..000000000 --- a/test/turing/prob_macro.jl +++ /dev/null @@ -1,101 +0,0 @@ -@testset "prob_macro.jl" begin - @testset "scalar" begin - @model function demo(x) - m ~ Normal() - return x ~ Normal(m, 1) - end - - mval = 3 - xval = 2 - iters = 1000 - - model = demo(xval) - varinfo = VarInfo(model) - chain = MCMCChains.get_sections( - sample(model, IS(), iters; save_state=true), :parameters - ) - chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) - lps = logpdf.(Normal.(chain["m"], 1), xval) - @test logprob"x = xval | chain = chain" == lps - @test logprob"x = xval | chain = chain2, model = model" == lps - @test logprob"x = xval | chain = chain, varinfo = varinfo" == lps - @test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps - - # multiple chains - pchain = chainscat(chain, chain) - pchain2 = chainscat(chain2, chain2) - plps = repeat(lps, 1, 2) - @test logprob"x = xval | chain = pchain" == plps - @test logprob"x = xval | chain = pchain2, model = model" == plps - @test logprob"x = xval | chain = pchain, varinfo = varinfo" == plps - @test logprob"x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps - end - @testset "vector" begin - n = 5 - @model function demo(x; n) - m ~ MvNormal(zeros(n), I) - return x ~ MvNormal(m, I) - end - mval = rand(n) - xval = rand(n) - iters = 1000 - - model = demo(xval; n) - varinfo = VarInfo(model) - chain = MCMCChains.get_sections( - sample(model, HMC(0.5, 1), iters; save_state=true), :parameters - ) - chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple()) - - names = namesingroup(chain, "m") - lps = [ - logpdf(MvNormal(chain.value[i, names, j], I), xval) for i in 1:size(chain, 1), - j in 1:size(chain, 3) - ] - @test logprob"x = xval | chain = chain" == lps - @test logprob"x = xval | chain = chain2, model = model" == lps - @test logprob"x = xval | chain = chain, varinfo = varinfo" == lps - @test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps - - # multiple chains - pchain = chainscat(chain, chain) - pchain2 = chainscat(chain2, chain2) - plps = repeat(lps, 1, 2) - @test logprob"x = xval | chain = pchain" == plps - @test logprob"x = xval | chain = pchain2, model = model" == plps - @test logprob"x = xval | chain = pchain, varinfo = varinfo" == plps - @test logprob"x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps - end - @testset "issue#137" begin - @model function model1(y, group, n_groups) - σ ~ truncated(Cauchy(0, 1), 0, Inf) - α ~ filldist(Normal(0, 10), n_groups) - μ = α[group] - return y ~ MvNormal(μ, σ^2 * I) - end - - y = randn(100) - group = rand(1:4, 100) - n_groups = 4 - - chain1 = MCMCChains.get_sections( - sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true), - :parameters, - ) - logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain1" - - @model function model2(y, group, n_groups) - σ ~ truncated(Cauchy(0, 1), 0, Inf) - α ~ filldist(Normal(0, 10), n_groups) - for i in 1:length(y) - y[i] ~ Normal(α[group[i]], σ) - end - end - - chain2 = MCMCChains.get_sections( - sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true), - :parameters, - ) - logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2" - end -end diff --git a/test/turing/runtests.jl b/test/turing/runtests.jl index faadd1257..229dd72f9 100644 --- a/test/turing/runtests.jl +++ b/test/turing/runtests.jl @@ -18,6 +18,5 @@ include(joinpath(pathof(Turing), "..", "..", "test", "test_utils", "numerical_te include("compiler.jl") include("loglikelihoods.jl") include("model.jl") - include("prob_macro.jl") include("varinfo.jl") end