diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2cdde32d0..84799298f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: version: - - '1.3' # minimum supported version + # - '1.3' # minimum supported version - '1' # current stable version os: - ubuntu-latest diff --git a/Project.toml b/Project.toml index 20985b109..844289af5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.16.2" +version = "0.17.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/docs/Project.toml b/docs/Project.toml index 83ce62d5e..aa1315f41 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,9 +1,11 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] Distributions = "0.25" Documenter = "0.27" +Setfield = "0.7.1, 0.8" StableRNGs = "1" diff --git a/docs/make.jl b/docs/make.jl index 7c4735ff2..3076dc4ff 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,7 +8,7 @@ makedocs(; sitename="DynamicPPL", format=Documenter.HTML(), modules=[DynamicPPL], - pages=["Home" => "index.md"], + pages=["Home" => "index.md", "TestUtils" => "test_utils.md"], strict=true, checkdocs=:exports, doctestfilters=[ diff --git a/docs/src/test_utils.md b/docs/src/test_utils.md new file mode 100644 index 000000000..912bd7c51 --- /dev/null +++ b/docs/src/test_utils.md @@ -0,0 +1,5 @@ +# DynamicPPL.TestUtils + +```@autodocs +Modules = [DynamicPPL.TestUtils] +``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 1e04d1439..82df0f008 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -10,6 +10,7 @@ using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules using BangBang: BangBang +using Setfield: Setfield using Setfield: Setfield using BangBang: BangBang @@ -31,15 +32,23 @@ import Base: keys, haskey +using BangBang: push!!, empty!!, setindex!! + # VarInfo export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, + SimpleVarInfo, + push!!, + empty!!, getlogp, setlogp!, acclogp!, resetlogp!, + setlogp!!, + acclogp!!, + resetlogp!!, get_num_produce, set_num_produce!, reset_num_produce!, @@ -139,13 +148,32 @@ include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") include("threadsafe.jl") +include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") - include("test_utils.jl") +# Deprecations +@deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution +) +@deprecate push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler +) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector +) +@deprecate push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector} +) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) + +@deprecate setlogp!(vi, logp) setlogp!!(vi, logp) +@deprecate acclogp!(vi, logp) acclogp!!(vi, logp) +@deprecate resetlogp!(vi) resetlogp!!(vi) + end # module diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 47a627506..edcac7874 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -1,5 +1,5 @@ # See https://github.com/TuringLang/Turing.jl/issues/1199 -ChainRulesCore.@non_differentiable push!( +ChainRulesCore.@non_differentiable push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) @@ -16,7 +16,7 @@ ZygoteRules.@adjoint function dot_observe( ) function dot_observe_fallback(spl, dists, value, vi) increment_num_produce!(vi) - return sum(map(Distributions.loglikelihood, dists, value)) + return sum(map(Distributions.loglikelihood, dists, value)), vi end return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi) end diff --git a/src/compiler.jl b/src/compiler.jl index 4cd4cc899..f973a3bb6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -355,10 +355,12 @@ end function generate_tilde_literal(left, right) # If the LHS is a literal, it is always an observation + @gensym value return quote - $(DynamicPPL.tilde_observe!)( + $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) + $value end end @@ -373,7 +375,7 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption + @gensym vn isassumption value # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact # that in DynamicPPL we the entire function body. Instead we should be @@ -389,32 +391,38 @@ function generate_tilde(left, right) $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) end - $(DynamicPPL.tilde_observe!)( + $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, __varinfo__, ) + $value end end end function generate_tilde_assume(left, right, vn) - expr = :( - $left = $(DynamicPPL.tilde_assume!)( + # HACK: Because the Setfield.jl macro does not support assignment + # with multiple arguments on the LHS, we need to capture the return-values + # and then update the LHS variables one by one. + @gensym value + expr = :($left = $value) + if left isa Expr + expr = AbstractPPL.drop_escape( + Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) + ) + end + + return quote + $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) - ) - - return if left isa Expr - AbstractPPL.drop_escape( - Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) - ) - else - return expr + $expr + $value end end @@ -428,7 +436,7 @@ function generate_dot_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption + @gensym vn isassumption value return quote $vn = $(AbstractPPL.drop_escape(varname(left))) $isassumption = $(DynamicPPL.isassumption(left)) @@ -440,13 +448,14 @@ function generate_dot_tilde(left, right) $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) end - $(DynamicPPL.dot_tilde_observe!)( + $value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, __varinfo__, ) + $value end end end @@ -455,15 +464,82 @@ function generate_dot_tilde_assume(left, right, vn) # We don't need to use `Setfield.@set` here since # `.=` is always going to be inplace + needs `left` to # be something that supports `.=`. - return :( - $left .= $(DynamicPPL.dot_tilde_assume!)( + @gensym value + return quote + $value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn )..., __varinfo__, ) - ) + $left .= $value + $value + end +end + +# Note that we cannot use `MacroTools.isdef` because +# of https://github.com/FluxML/MacroTools.jl/issues/154. +""" + isfuncdef(expr) + +Return `true` if `expr` is any form of function definition, and `false` otherwise. +""" +function isfuncdef(e::Expr) + return if Meta.isexpr(e, :function) + # Classic `function f(...)` + true + elseif Meta.isexpr(e, :->) + # Anonymous functions/lambdas, e.g. `do` blocks or `->` defs. + true + elseif Meta.isexpr(e, :(=)) && Meta.isexpr(e.args[1], :call) + # Short function defs, e.g. `f(args...) = ...`. + true + else + false + end +end + +""" + replace_returns(expr) + +Return `Expr` with all `return ...` statements replaced with +`return ..., DynamicPPL.return_values(__varinfo__)`. + +Note that this method will _not_ replace `return` statements within function +definitions. This is checked using [`isfuncdef`](@ref). +""" +replace_returns(e) = e +function replace_returns(e::Expr) + if isfuncdef(e) + return e + end + + if Meta.isexpr(e, :return) + # NOTE: `return` always has an argument. In the case of + # an empty `return`, the lowered expression will be `return nothing`. + # Hence we don't need any special handling for empty returns. + retval_expr = if length(e.args) > 1 + Expr(:tuple, e.args...) + else + e.args[1] + end + + return :(return ($retval_expr, __varinfo__)) + end + + return Expr(e.head, map(replace_returns, e.args)...) +end + +# If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. +make_returns_explicit!(body) = Expr(:return, body) +function make_returns_explicit!(body::Expr) + # If the last statement is a return-statement, we don't do anything. + # Otherwise we replace the last statement with a `return` statement. + if !Meta.isexpr(body.args[end], :return) + body.args[end] = Expr(:return, body.args[end]) + end + return body end const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} @@ -496,10 +572,14 @@ function build_output(modelinfo, linenumbernode) # Replace the user-provided function body with the version created by DynamicPPL. # We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure # that no new `LineNumberNode`s are added apart from the reference `linenumbernode` - # to the call site + # to the call site. + # NOTE: We need to replace statements of the form `return ...` with + # `return (..., __varinfo__)` to ensure that the second + # element in the returned value is always the most up-to-date `__varinfo__`. + # See the docstrings of `replace_returns` for more info. evaluatordef[:body] = MacroTools.@q begin $(linenumbernode) - $(modelinfo[:body]) + $(replace_returns(make_returns_explicit!(modelinfo[:body]))) end ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 19b5ce061..20c4af446 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -54,7 +54,7 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, get(context.vars, vn)) + vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) settrans!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, vi) @@ -63,7 +63,7 @@ function tilde_assume( rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, get(context.vars, vn)) + vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) settrans!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) @@ -71,7 +71,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, get(context.vars, vn)) + vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) settrans!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, vi) @@ -85,7 +85,7 @@ function tilde_assume( vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, get(context.vars, vn)) + vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) settrans!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) @@ -105,18 +105,17 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, vi) end """ - tilde_assume!(context, right, vn, vi) + tilde_assume!!(context, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value. +accumulate the log probability, and return the sampled value and updated `vi`. By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_assume!(context, right, vn, vi) - value, logp = tilde_assume(context, right, vn, vi) - acclogp!(vi, logp) - return value +function tilde_assume!!(context, right, vn, vi) + value, logp, vi = tilde_assume(context, right, vn, vi) + return value, acclogp!!(vi, logp) end # observe @@ -140,15 +139,17 @@ function tilde_observe(::IsParent, context::AbstractContext, args...) return tilde_observe(childcontext(context), args...) end -tilde_observe(::PriorContext, right, left, vi) = 0 -tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::PriorContext, right, left, vi) = 0, vi +tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, right, left, vi) - return context.loglike_scalar * tilde_observe(context.context, right, left, vi) + logp, vi = tilde_observe(context.context, right, left, vi) + return context.loglike_scalar * logp, vi end function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.context, sampler, right, left, vi) + logp, vi = tilde_observe(context.context, sampler, right, left, vi) + return context.loglike_scalar * logp, vi end # `PrefixContext` @@ -157,16 +158,16 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(context, right, left, vname, vi) + tilde_observe!!(context, right, left, vname, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value. +accumulate the log probability, and return the observed value and updated `vi`. -Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(context, right, left, vname, vi) - return tilde_observe!(context, right, left, vi) +function tilde_observe!!(context, right, left, vname, vi) + return tilde_observe!!(context, right, left, vi) end """ @@ -178,10 +179,9 @@ return the observed value. By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_observe!(context, right, left, vi) - logp = tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left +function tilde_observe!!(context, right, left, vi) + logp, vi = tilde_observe(context, right, left, vi) + return left, acclogp!!(vi, logp) end function assume(rng, spl::Sampler, dist) @@ -195,7 +195,7 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) r = vi[vn] - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end # SampleFromPrior and SampleFromUniform @@ -204,7 +204,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi, + vi::AbstractVarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. @@ -219,18 +219,18 @@ function assume( end else r = init(rng, dist, sampler) - push!(vi, vn, r, dist, sampler) + push!!(vi, vn, r, dist, sampler) settrans!(vi, false, vn) end - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end # default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi) function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(right, left) + return Distributions.loglikelihood(right, left), vi end # .~ functions @@ -364,22 +364,24 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, end """ - dot_tilde_assume!(context, right, left, vn, vi) + dot_tilde_assume!!(context, right, left, vn, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value. +model inputs), accumulate the log probability, and return the sampled value and updated `vi`. Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ -function dot_tilde_assume!(context, right, left, vn, vi) - value, logp = dot_tilde_assume(context, right, left, vn, vi) - acclogp!(vi, logp) - return value +function dot_tilde_assume!!(context, right, left, vn, vi) + value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) + return value, acclogp!!(vi, logp), vi end # `dot_assume` function dot_assume( - dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::AbstractVarInfo, ) @assert length(dist) == size(var, 1) # NOTE: We cannot work with `var` here because we might have a model of the form @@ -389,10 +391,10 @@ function dot_assume( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. r = vi[vns] - lp = sum(zip(vns, eachcol(r))) do vn, ri + lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end - return r, lp + return r, lp, vi end function dot_assume( @@ -401,12 +403,12 @@ function dot_assume( dist::MultivariateDistribution, vns::AbstractVector{<:VarName}, var::AbstractMatrix, - vi, + vi::AbstractVarInfo, ) @assert length(dist) == size(var, 1) r = get_and_set_val!(rng, vi, vns, dist, spl) lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume( @@ -423,7 +425,7 @@ function dot_assume( # in which case `var` will have `undef` elements, even if `m` is present in `vi`. r = reshape(vi[vec(vns)], size(vns)) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume( @@ -432,12 +434,12 @@ function dot_assume( dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, var::AbstractArray, - vi, + vi::AbstractVarInfo, ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) return error( @@ -447,7 +449,7 @@ end function get_and_set_val!( rng, - vi, + vi::AbstractVarInfo, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -471,7 +473,7 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - push!(vi, vn, r[:, i], dist, spl) + push!!(vi, vn, r[:, i], dist, spl) settrans!(vi, false, vn) end end @@ -480,7 +482,7 @@ end function get_and_set_val!( rng, - vi, + vi::AbstractVarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}}, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -504,14 +506,22 @@ function get_and_set_val!( else f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) - push!.(Ref(vi), vns, r, dists, Ref(spl)) + # TODO: This will inefficient since it will allocate an entire vector. + # We could either: + # 1. Figure out the broadcast size and use a `foreach`. + # 2. Define an anonymous function which returns `nothing`, which + # we then broadcast. This will allocate a vector of `nothing` though. + push!!.(Ref(vi), vns, r, dists, Ref(spl)) settrans!.(Ref(vi), false, vns) end return r end function set_val!( - vi, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, val::AbstractMatrix + vi::AbstractVarInfo, + vns::AbstractVector{<:VarName}, + dist::MultivariateDistribution, + val::AbstractMatrix, ) @assert size(val, 2) == length(vns) foreach(enumerate(vns)) do (i, vn) @@ -520,7 +530,7 @@ function set_val!( return val end function set_val!( - vi, + vi::AbstractVarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}}, val::AbstractArray, @@ -555,12 +565,13 @@ function dot_tilde_observe(::IsParent, context::AbstractContext, args...) return dot_tilde_observe(childcontext(context), args...) end -dot_tilde_observe(::PriorContext, right, left, vi) = 0 -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +dot_tilde_observe(::PriorContext, right, left, vi) = 0, vi +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, right, left, vi) - return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) + logp, vi = dot_tilde_observe(context.context, right, left, vi) + return context.loglike_scalar * logp, vi end # `PrefixContext` @@ -569,30 +580,29 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vname, vi) + dot_tilde_observe!!(context, right, left, vname, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value. +accumulate the log probability, and return the observed value and updated `vi`. -Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(context, right, left, vn, vi) - return dot_tilde_observe!(context, right, left, vi) +function dot_tilde_observe!!(context, right, left, vn, vi) + return dot_tilde_observe!!(context, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vi) + dot_tilde_observe!!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value. +probability, and return the observed value and updated `vi`. Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(context, right, left, vi) - logp = dot_tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left +function dot_tilde_observe!!(context, right, left, vi) + logp, vi = dot_tilde_observe(context, right, left, vi) + return left, acclogp!!(vi, logp) end # Falls back to non-sampler definition. @@ -601,13 +611,13 @@ function dot_observe(::AbstractSampler, dist, value, vi) end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(dist, value), vi end function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dists, value) + return Distributions.loglikelihood(dists, value), vi end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) - return sum(Distributions.loglikelihood.(dists, value)) + return sum(Distributions.loglikelihood.(dists, value)), vi end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index cd50811c1..daf05eedd 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -67,27 +67,26 @@ function Base.push!( return context.loglikelihoods[vn] = logp end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) +function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return tilde_observe!(context.context, right, left, vi) + return tilde_observe!!(context.context, right, left, vi) end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi) +function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. - logp = tilde_observe(context.context, right, left, vi) - acclogp!(vi, logp) + logp, vi = tilde_observe(context.context, right, left, vi) # Track loglikelihood value. push!(context, vn, logp) - return left + return left, acclogp!!(vi, logp) end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) +function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return dot_tilde_observe!(context.context, right, left, vi) + return dot_tilde_observe!!(context.context, right, left, vi) end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi) +function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. @@ -95,7 +94,6 @@ function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn # hence we need the `logp` for each of them. Broadcasting the univariate # `tilde_obseve` does exactly this. logps = _pointwise_tilde_observe(context.context, right, left, vi) - acclogp!(vi, sum(logps)) # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. _, _, vns = unwrap_right_left_vns(right, left, vn) @@ -104,19 +102,25 @@ function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn push!(context, vn, logp) end - return left + return left, acclogp!!(vi, sum(logps)) end # FIXME: This is really not a good approach since it needs to stay in sync with # the `dot_assume` implementations, but as things are _right now_ this is the best we can do. function _pointwise_tilde_observe(context, right, left, vi) - return tilde_observe.(Ref(context), right, left, Ref(vi)) + # We need to drop the `vi` returned. + return broadcast(right, left) do r, l + return first(tilde_observe(context, r, l, vi)) + end end function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi + context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo ) - return tilde_observe.(Ref(context), Ref(right), eachcol(left), Ref(vi)) + # We need to drop the `vi` returned. + return map(eachcol(left)) do l + return first(tilde_observe(context, right, l, vi)) + end end """ diff --git a/src/model.jl b/src/model.jl index bc149838b..702d76a17 100644 --- a/src/model.jl +++ b/src/model.jl @@ -195,7 +195,7 @@ julia> @model demo_inner() = m ~ Normal() demo_inner (generic function with 2 methods) julia> @model function demo_outer() - m = @submodel demo_inner() + @submodel m = demo_inner() return m end demo_outer (generic function with 2 methods) @@ -215,7 +215,7 @@ But one needs to be careful when prefixing variables in the nested models: ```jldoctest condition julia> @model function demo_outer_prefix() - m = @submodel inner demo_inner() + @submodel prefix="inner" m = demo_inner() return m end demo_outer_prefix (generic function with 2 methods) @@ -374,55 +374,71 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( +(model::Model)(args...) = first(evaluate!!(model, args...)) + +""" + evaluate!!(model::Model[, rng, varinfo, sampler, context]) + +Sample from the `model` using the `sampler` with random number generator `rng` and the +`context`, and store the sample and log joint probability in `varinfo`. + +Returns both the return-value of the original model, and the resulting varinfo. + +The method resets the log joint probability of `varinfo` and increases the evaluation +number of `sampler`. +""" +function evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) + if Threads.nthreads() == 1 + return evaluate_threadunsafe!!(model, varinfo, context) + else + return evaluate_threadsafe!!(model, varinfo, context) + end +end + +function evaluate!!( + model::Model, rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return model(varinfo, SamplingContext(rng, sampler, context)) + return evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)) end -(model::Model)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) - if Threads.nthreads() == 1 - return evaluate_threadunsafe(model, varinfo, context) - else - return evaluate_threadsafe(model, varinfo, context) - end -end +evaluate!!(model::Model, context::AbstractContext) = evaluate!!(model, VarInfo(), context) -function (model::Model)(args...) - return model(Random.GLOBAL_RNG, args...) +function evaluate!!(model::Model, args...) + return evaluate!!(model, Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) - return model(rng, VarInfo(), sampler, args...) +function evaluate!!( + model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args... +) + return evaluate!!(model, rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) - return model(rng, VarInfo(), SampleFromPrior(), context) +function evaluate!!(model::Model, rng::Random.AbstractRNG, context::AbstractContext) + return evaluate!!(model, rng, VarInfo(), SampleFromPrior(), context) end """ - evaluate_threadunsafe(model, varinfo, context) + evaluate_threadunsafe!!(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. If the `model` makes use of Julia's multithreading this will lead to undefined behaviour. This method is not exposed and supposed to be used only internally in DynamicPPL. -See also: [`evaluate_threadsafe`](@ref) +See also: [`evaluate_threadsafe!!`](@ref) """ -function evaluate_threadunsafe(model, varinfo, context) - resetlogp!(varinfo) - return _evaluate(model, varinfo, context) +function evaluate_threadunsafe!!(model, varinfo, context) + return _evaluate!!(model, resetlogp!!(varinfo), context) end """ - evaluate_threadsafe(model, varinfo, context) + evaluate_threadsafe!!(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -430,22 +446,20 @@ With the wrapper, Julia's multithreading can be used for observe statements in t but parallel sampling will lead to undefined behaviour. This method is not exposed and supposed to be used only internally in DynamicPPL. -See also: [`evaluate_threadunsafe`](@ref) +See also: [`evaluate_threadunsafe!!`](@ref) """ -function evaluate_threadsafe(model, varinfo, context) - resetlogp!(varinfo) - wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(model, wrapper, context) - setlogp!(varinfo, getlogp(wrapper)) - return result +function evaluate_threadsafe!!(model, varinfo, context) + wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) + result, wrapper_new = _evaluate!!(model, wrapper, context) + return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new)) end """ - _evaluate(model::Model, varinfo, context) + _evaluate!!(model::Model, varinfo, context) Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ -@generated function _evaluate( +@generated function _evaluate!!( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [ @@ -495,8 +509,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, DefaultContext()) - return getlogp(varinfo) + return getlogp(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -507,8 +520,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, PriorContext()) - return getlogp(varinfo) + return getlogp(last(evaluate!!(model, varinfo, PriorContext()))) end """ @@ -519,8 +531,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, LikelihoodContext()) - return getlogp(varinfo) + return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext()))) end """ diff --git a/src/prob_macro.jl b/src/prob_macro.jl index d761e9fdc..c87e365ea 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -146,8 +146,9 @@ function logprior( foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end - model(vi, SampleFromPrior(), PriorContext(left)) - return getlogp(vi) + return getlogp( + last(DynamicPPL.evaluate!!(model, vi, SampleFromPrior(), PriorContext(left))) + ) end @generated function make_prior_model( diff --git a/src/sampler.jl b/src/sampler.jl index 664031233..bb3ab7633 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -82,7 +82,7 @@ function AbstractMCMC.step( # Update the parameters if provided. if haskey(kwargs, :init_params) - initialize_parameters!(vi, kwargs[:init_params], spl) + vi = initialize_parameters!!(vi, kwargs[:init_params], spl) # Update joint log probability. # TODO: fix properly by using sampler and evaluation contexts @@ -116,7 +116,7 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler) +function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) @debug "Using passed-in initial variable values" init_params # Flatten parameters. @@ -126,7 +126,10 @@ function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler) # Get all values. linked = islinked(vi, spl) - linked && invlink!(vi, spl) + if linked + # TODO: Make work with immutable `vi`. + invlink!(vi, spl) + end theta = vi[spl] length(theta) == length(init_theta) || error("Provided initial value doesn't match the dimension of the model") @@ -140,10 +143,13 @@ function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler) end # Update in `vi`. - vi[spl] = theta - linked && link!(vi, spl) + vi = setindex!!(vi, theta, spl) + if linked + # TODO: Make work with immutable `vi`. + link!(vi, spl) + end - return nothing + return vi end """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl new file mode 100644 index 000000000..5cecda4b2 --- /dev/null +++ b/src/simple_varinfo.jl @@ -0,0 +1,520 @@ +""" + SimpleVarInfo{NT,T} <: AbstractVarInfo + +A simple wrapper of the parameters with a `logp` field for +accumulation of the logdensity. + +Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. + +# Notes +The major differences between this and `TypedVarInfo` are: +1. `SimpleVarInfo` does not require linearization. +2. `SimpleVarInfo` can use more efficient bijectors. +3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either + a) no indexing is used in tilde-statements, or + b) the values have been specified with the correct shapes. + +# Examples +## General usage +```jldoctest; setup=:(using Distributions) +julia> using StableRNGs + +julia> @model function demo() + m ~ Normal() + x = Vector{Float64}(undef, 2) + for i in eachindex(x) + x[i] ~ Normal() + end + return x + end +demo (generic function with 2 methods) + +julia> m = demo(); + +julia> rng = StableRNG(42); + +julia> ### Sampling ### + ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext()); + +julia> # In the `NamedTuple` version we need to provide the place-holder values for + # the variables which are using "containers", e.g. `Array`. + # In this case, this means that we need to specify `x` but not `m`. + _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo((x = ones(2), )), ctx); + +julia> # (✓) Vroom, vroom! FAST!!! + vi[@varname(x[1])] +0.4471218424633827 + +julia> # We can also access arbitrary varnames pointing to `x`, e.g. + vi[@varname(x)] +2-element Vector{Float64}: + 0.4471218424633827 + 1.3736306979834252 + +julia> vi[@varname(x[1:2])] +2-element Vector{Float64}: + 0.4471218424633827 + 1.3736306979834252 + +julia> # (×) If we don't provide the container... + _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); vi +ERROR: type NamedTuple has no field x +[...] + +julia> # If one does not know the varnames, we can use a `Dict` instead. + _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict()), ctx); + +julia> # (✓) Sort of fast, but only possible at runtime. + vi[@varname(x[1])] +-1.019202452456547 + +julia> # In addtion, we can only access varnames as they appear in the model! + vi[@varname(x)] +ERROR: KeyError: key x not found +[...] + +julia> vi[@varname(x[1:2])] +ERROR: KeyError: key x[1:2] not found +[...] +``` + +## Indexing +Using `NamedTuple` as underlying storage. + +```jldoctest +julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )); + +julia> svi_nt[@varname(m)] +(a = [1.0],) + +julia> svi_nt[@varname(m.a)] +1-element Vector{Float64}: + 1.0 + +julia> svi_nt[@varname(m.a[1])] +1.0 + +julia> svi_nt[@varname(m.a[2])] +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] + +julia> svi_nt[@varname(m.b)] +ERROR: type NamedTuple has no field b +[...] +``` + +Using `Dict` as underlying storage. +```jldoctest +julia> svi_dict = SimpleVarInfo(Dict(@varname(m) => (a = [1.0], ))); + +julia> svi_dict[@varname(m)] +(a = [1.0],) + +julia> svi_dict[@varname(m.a)] +1-element Vector{Float64}: + 1.0 + +julia> svi_dict[@varname(m.a[1])] +1.0 + +julia> svi_dict[@varname(m.a[2])] +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] + +julia> svi_dict[@varname(m.b)] +ERROR: type NamedTuple has no field b +[...] +``` +""" +struct SimpleVarInfo{NT,T} <: AbstractVarInfo + values::NT + logp::T +end + +SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) +SimpleVarInfo{T}(; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs)) +SimpleVarInfo(; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs)) +SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) + +# Constructor from `Model`. +SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) +function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} + return last(evaluate!!(model, SimpleVarInfo{T}(), args...)) +end + +# Constructor from `VarInfo`. +function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} + return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) +end +function SimpleVarInfo{T}( + vi::VarInfo{<:NamedTuple{names}}, ::Type{D} +) where {T<:Real,names,D} + values = values_as(vi, D) + return SimpleVarInfo(values, convert(T, getlogp(vi))) +end + +function BangBang.empty!!(vi::SimpleVarInfo) + Setfield.@set resetlogp!!(vi).values = empty!!(vi.values) +end + +getlogp(vi::SimpleVarInfo) = vi.logp +setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, logp) +acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp) + +""" + keys(vi::SimpleVarInfo) + +Return an iterator of keys present in `vi`. +""" +Base.keys(vi::SimpleVarInfo) = keys(vi.values) + +function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) + vi.logp[] = logp + return vi +end + +function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) + vi.logp[] += logp + return vi +end + +function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") +end + +# `NamedTuple` +Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) + +# `Dict` +function Base.getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) + if haskey(vi.values, vn) + return vi.values[vn] + end + + # Split the lens into the key / `parent` and the extraction lens / `child`. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(vi.values, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + # If we found a valid split, then we can extract the value. + if !issuccess + # At this point we just throw an error since the key could not be found. + throw(KeyError(vn)) + end + + # TODO: Should we also check that we `canview` the extracted `value` + # rather than just let it fail upon `get` call? + value = vi.values[VarName(vn, keylens)] + return get(value, child) +end + +# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than +# just `Vector`. +function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) + return map(Base.Fix1(getindex, vi), vns) +end +# HACK: Needed to disambiguiate. +Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) + +Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values +Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values +# TODO: Should we do better? +Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values + +Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) +function _haskey(nt::NamedTuple, vn::VarName) + # LHS: Ensure that `nt` indeed has the property we want. + # RHS: Ensure that the lens can view into `nt`. + sym = getsym(vn) + return haskey(nt, sym) && canview(getlens(vn), getproperty(nt, sym)) +end + +# For `dictlike` we need to check wether `vn` is "immediately" present, or +# if some ancestor of `vn` is present in `dictlike`. +function _haskey(dict::AbstractDict, vn::VarName) + # First we check if `vn` is present as is. + haskey(dict, vn) && return true + + # If `vn` is not present, we check any parent-varnames by attempting + # to split the lens into the key / `parent` and the extraction lens / `child`. + # If `issuccess` is `true`, we found such a split, and hence `vn` is present. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(dict, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + # Return early if no such split could be found. + issuccess || return false + + # At this point we just need to check that we `canview` the value. + value = dict[VarName(vn, keylens)] + + return canview(child, value) +end + +function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) + # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. + return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) +end + +# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with +# same symbol and same type of, say, `IndexLens`, for improved `.~` performance. +function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) + for (vn, val) in zip(vns, vals) + vi = BangBang.setindex!!(vi, val, vn) + end + return vi +end + +function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) + # For dictlike objects, we treat the entire `vn` as a _key_ to set. + dict = values_as(vi) + # Attempt to split into `parent` and `child` lenses. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(dict, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + dict_new = if !issuccess + # Split doesn't exist ⟹ we're working with a new key. + BangBang.setindex!!(dict, val, vn) + else + # Split exists ⟹ trying to set an existing key. + vn_key = VarName(vn, keylens) + BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) + end + return SimpleVarInfo(dict_new, vi.logp) +end + +# `NamedTuple` +function BangBang.push!!( + vi::SimpleVarInfo{<:NamedTuple}, + vn::VarName{sym,Setfield.IdentityLens}, + value, + dist::Distribution, + gidset::Set{Selector}, +) where {sym} + return Setfield.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) +end +function BangBang.push!!( + vi::SimpleVarInfo{<:NamedTuple}, + vn::VarName{sym}, + value, + dist::Distribution, + gidset::Set{Selector}, +) where {sym} + return Setfield.@set vi.values = set!!(vi.values, vn, value) +end + +# `Dict` +function BangBang.push!!( + vi::SimpleVarInfo{<:AbstractDict}, + vn::VarName, + r, + dist::Distribution, + gidset::Set{Selector}, +) + vi.values[vn] = r + return vi +end + +const SimpleOrThreadSafeSimple{T,V} = Union{ + SimpleVarInfo{T,V},ThreadSafeVarInfo{<:SimpleVarInfo{T,V}} +} + +# Necessary for `matchingvalue` to work properly. +function Base.eltype( + vi::SimpleOrThreadSafeSimple{<:Any,V}, spl::Union{AbstractSampler,SampleFromPrior} +) where {V} + return V +end + +# Context implementations +function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple) + left = vi[vn] + return left, Distributions.loglikelihood(dist, left), vi +end + +function assume( + rng::Random.AbstractRNG, + sampler::SampleFromPrior, + dist::Distribution, + vn::VarName, + vi::SimpleOrThreadSafeSimple, +) + value = init(rng, dist, sampler) + vi = BangBang.push!!(vi, vn, value, dist, sampler) + return value, Distributions.loglikelihood(dist, value), vi +end + +function dot_assume( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::SimpleOrThreadSafeSimple, +) + @assert length(dist) == size(var, 1) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + value = vi[vns] + lp = sum(zip(vns, eachcol(value))) do (vn, val) + return Distributions.logpdf(dist, val) + end + return value, lp, vi +end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + vi::SimpleOrThreadSafeSimple, +) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + value = vi[vns] + lp = sum(Distributions.logpdf.(dists, value)) + return value, lp, vi +end + +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, + vns::AbstractArray{<:VarName}, + var::AbstractArray, + vi::SimpleOrThreadSafeSimple, +) + f = (vn, dist) -> init(rng, dist, spl) + value = f.(vns, dists) + vi = BangBang.setindex!!(vi, value, vns) + lp = sum(Distributions.logpdf.(dists, value)) + return value, lp, vi +end + +# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. +increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing +settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing +istrans(::SimpleVarInfo, vn::VarName) = false + +""" + values_as(varinfo[, Type]) + +Return the values/realizations in `varinfo` as `Type`, if implemented. + +If no `Type` is provided, return values as stored in `varinfo`. +""" +values_as(vi::SimpleVarInfo) = vi.values +values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) +values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) +values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values + +""" + logjoint(model::Model, θ) + +Return the log joint probability of variables `θ` for the probabilistic `model`. + +See [`logjoint`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logjoint(demo([1.0]), (m = 100.0, )) +-9902.33787706641 + +julia> # Using a `Dict`. + logjoint(demo([1.0]), Dict(@varname(m) => 100.0)) +-9902.33787706641 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) +-9902.33787706641 +``` +""" +logjoint(model::Model, θ) = logjoint(model, SimpleVarInfo(θ)) + +""" + logprior(model::Model, θ) + +Return the log prior probability of variables `θ` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logprior(demo([1.0]), (m = 100.0, )) +-5000.918938533205 + +julia> # Using a `Dict`. + logprior(demo([1.0]), Dict(@varname(m) => 100.0)) +-5000.918938533205 + +julia> # Truth. + logpdf(Normal(), 100.0) +-5000.918938533205 +``` +""" +logprior(model::Model, θ) = logprior(model, SimpleVarInfo(θ)) + +""" + loglikelihood(model::Model, θ) + +Return the log likelihood of variables `θ` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`logprior`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + loglikelihood(demo([1.0]), (m = 100.0, )) +-4901.418938533205 + +julia> # Using a `Dict`. + loglikelihood(demo([1.0]), Dict(@varname(m) => 100.0)) +-4901.418938533205 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) +-4901.418938533205 +``` +""" +Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ)) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 03c6f2ad6..5ffed3c42 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,10 +1,9 @@ """ @submodel model + @submodel ... = model Run a Turing `model` nested inside of a Turing model. -The return value can be assigned to a variable. - # Examples ```jldoctest submodel; setup=:(using Distributions) @@ -14,7 +13,7 @@ julia> @model function demo1(x) end; julia> @model function demo2(x, y) - a = @submodel demo1(x) + @submodel a = demo1(x) return y ~ Uniform(0, a) end; ``` @@ -44,24 +43,27 @@ true ``` """ macro submodel(expr) - return quote - _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) - end + return submodel(:(prefix = false), expr) end """ - @submodel prefix model + @submodel prefix=... model + @submodel prefix=... ... = model Run a Turing `model` nested inside of a Turing model and add "`prefix`." as a prefix to all random variables inside of the `model`. +Valid expressions for `prefix=...` are: +- `prefix=false`: no prefix is used. +- `prefix=true`: _attempt_ to automatically determine the prefix from the left-hand side + `... = model` by first converting into a `VarName`, and then calling `Symbol` on this. +- `prefix=expression`: results in the prefix `Symbol(expression)`. + The prefix makes it possible to run the same Turing model multiple times while keeping track of all random variables correctly. -The return value can be assigned to a variable. - # Examples - +## Example models ```jldoctest submodelprefix; setup=:(using Distributions) julia> @model function demo1(x) x ~ Normal() @@ -69,8 +71,8 @@ julia> @model function demo1(x) end; julia> @model function demo2(x, y, z) - a = @submodel sub1 demo1(x) - b = @submodel sub2 demo1(y) + @submodel prefix="sub1" a = demo1(x) + @submodel prefix="sub2" b = demo1(y) return z ~ Uniform(-a, b) end; ``` @@ -111,13 +113,131 @@ julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); julia> getlogp(vi) ≈ logprior + loglikelihood true ``` + +## Different ways of setting the prefix +```jldoctest submodel-prefix-alternatives; setup=:(using DynamicPPL, Distributions) +julia> @model inner() = x ~ Normal() +inner (generic function with 2 methods) + +julia> # When `prefix` is unspecified, no prefix is used. + @model outer() = @submodel a = inner() +outer (generic function with 2 methods) + +julia> @varname(x) in keys(VarInfo(outer())) +true + +julia> # Explicitely don't use any prefix. + @model outer() = @submodel prefix=false a = inner() +outer (generic function with 2 methods) + +julia> @varname(x) in keys(VarInfo(outer())) +true + +julia> # Automatically determined from `a`. + @model outer() = @submodel prefix=true a = inner() +outer (generic function with 2 methods) + +julia> @varname(var"a.x") in keys(VarInfo(outer())) +true + +julia> # Using a static string. + @model outer() = @submodel prefix="my prefix" a = inner() +outer (generic function with 2 methods) + +julia> @varname(var"my prefix.x") in keys(VarInfo(outer())) +true + +julia> # Using string interpolation. + @model outer() = @submodel prefix="\$(inner().name)" a = inner() +outer (generic function with 2 methods) + +julia> @varname(var"inner.x") in keys(VarInfo(outer())) +true + +julia> # Or using some arbitrary expression. + @model outer() = @submodel prefix=1 + 2 a = inner() +outer (generic function with 2 methods) + +julia> @varname(var"3.x") in keys(VarInfo(outer())) +true + +julia> # (×) Automatic prefixing without a left-hand side expression does not work! + @model outer() = @submodel prefix=true inner() +ERROR: LoadError: cannot automatically prefix with no left-hand side +[...] +``` + +# Notes +- The choice `prefix=expression` means that the prefixing will incur a runtime cost. + This is also the case for `prefix=true`, depending on whether the expression on the + the right-hand side of `... = model` requires runtime-information or not, e.g. + `x = model` will result in the _static_ prefix `x`, while `x[i] = model` will be + resolved at runtime. """ -macro submodel(prefix, expr) - return quote - _evaluate( - $(esc(expr)), - $(esc(:__varinfo__)), - PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), - ) +macro submodel(prefix_expr, expr) + return submodel(prefix_expr, expr, esc(:__context__)) +end + +# Automatic prefixing. +function prefix_submodel_context(prefix::Bool, left::Symbol, ctx) + return prefix ? prefix_submodel_context(left, ctx) : ctx +end + +function prefix_submodel_context(prefix::Bool, left::Expr, ctx) + return prefix ? prefix_submodel_context(varname(left), ctx) : ctx +end + +# Manual prefixing. +prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) +function prefix_submodel_context(prefix, ctx) + # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. + return :($(DynamicPPL.PrefixContext){$(Symbol)($(esc(prefix)))}($ctx)) +end + +function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) + # E.g. `prefix="asd"`. + return :($(DynamicPPL.PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx)) +end + +function prefix_submodel_context(prefix::Bool, ctx) + if prefix + error("cannot automatically prefix with no left-hand side") + end + + return ctx +end + +function submodel(prefix_expr, expr, ctx=esc(:__context__)) + prefix_left, prefix = getargs_assignment(prefix_expr) + if prefix_left !== :prefix + error("$(prefix_left) is not a valid kwarg") + end + # `prefix=false` => don't prefix, i.e. do nothing to `ctx`. + # `prefix=true` => automatically determine prefix. + # `prefix=...` => use it. + args_assign = getargs_assignment(expr) + return if args_assign === nothing + ctx = prefix_submodel_context(prefix, ctx) + # In this case we only want to get the `__varinfo__`. + quote + $(esc(:__varinfo__)) = last( + $(DynamicPPL._evaluate!!)($(esc(expr)), $(esc(:__varinfo__)), $(ctx)) + ) + end + else + L, R = args_assign + # Now that we have `L` and `R`, we can prefix automagically. + try + ctx = prefix_submodel_context(prefix, L, ctx) + catch e + error( + "failed to determine prefix from $(L); please specify prefix using the `@submodel prefix=\"your prefix\" ...` syntax", + ) + end + quote + $(esc(L)), $(esc(:__varinfo__)) = $(DynamicPPL._evaluate!!)( + $(esc(R)), $(esc(:__varinfo__)), $(ctx) + ) + end end end diff --git a/src/test_utils.jl b/src/test_utils.jl index d4b5c7206..e9b0e6a7d 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,6 +6,49 @@ using LinearAlgebra using Distributions using Test +""" + logprior_true(model, θ) + +Return the `logprior` of `model` for `θ`. + +This should generally be implemented by hand for every specific `model`. + +See also: [`logjoint_true`](@ref), [`loglikelihood_true`](@ref). +""" +function logprior_true end + +""" + loglikelihood_true(model, θ) + +Return the `loglikelihood` of `model` for `θ`. + +This should generally be implemented by hand for every specific `model`. + +See also: [`logjoint_true`](@ref), [`logprior_true`](@ref). +""" +function loglikelihood_true end + +""" + logjoint_true(model, θ) + +Return the `logjoint` of `model` for `θ`. + +Defaults to `logprior_true(model, θ) + loglikelihood_true(model, θ)`. + +This should generally be implemented by hand for every specific `model` +so that the returned value can be used as a ground-truth for testing things like: + +1. Validity of evaluation of `model` using a particular implementation of `AbstractVarInfo`. +2. Validity of a sampler when combined with DynamicPPL by running the sampler twice: once targeting ground-truth functions, e.g. `logjoint_true`, and once targeting `model`. + +And more. + +See also: [`logprior_true`](@ref), [`loglikelihood_true`](@ref). +""" +function logjoint_true(model::Model, args...) + return logprior_true(model, args...) + loglikelihood_true(model, args...) +end + # A collection of models for which the mean-of-means for the posterior should # be same. @model function demo_dot_assume_dot_observe( @@ -17,6 +60,12 @@ using Test x ~ MvNormal(m, 0.25 * I) return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) + return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +end @model function demo_assume_index_observe( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -30,14 +79,26 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_index_observe)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) + return logpdf(MvNormal(m, 0.25 * I), model.args.x) +end -@model function demo_assume_multivariate_observe_index(x=[10.0, 10.0]) +@model function demo_assume_multivariate_observe(x=[10.0, 10.0]) # Multivariate `assume` and `observe` m ~ MvNormal(zero(x), I) x ~ MvNormal(m, 0.25 * I) return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) + return logpdf(MvNormal(zero(model.args.x), I), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) + return logpdf(MvNormal(m, 0.25 * I), model.args.x) +end @model function demo_dot_assume_observe_index( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -51,6 +112,12 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) + return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. @@ -61,6 +128,12 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, m) + return logpdf(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) + return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -69,6 +142,12 @@ end return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, m) + return logpdf(MvNormal(zeros(2), I), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) + return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) +end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -80,6 +159,12 @@ end return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) + return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) +end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -88,6 +173,12 @@ end return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) + return logpdf(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) + return logpdf(Normal(m, 0.5), 10.0) +end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) @@ -98,13 +189,21 @@ end @model function demo_assume_submodel_observe_index_literal() # Submodel prior - m = @submodel _prior_dot_assume() + @submodel m = _prior_dot_assume() for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m +) + return sum(logpdf.(Normal.(m, 0.5), 10.0)) +end @model function _likelihood_dot_observe(m, x) return x ~ MvNormal(m, 0.25 * I) @@ -121,6 +220,12 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) + return logpdf(MvNormal(m, 0.25 * I), model.args.x) +end @model function demo_dot_assume_dot_observe_matrix( x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} @@ -133,11 +238,17 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) + return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +end const DEMO_MODELS = ( demo_dot_assume_dot_observe(), demo_assume_index_observe(), - demo_assume_multivariate_observe_index(), + demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), demo_assume_dot_observe(), demo_assume_observe_literal(), diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c940f9e3f..6f020a352 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -15,7 +15,7 @@ ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi # Instead of updating the log probability of the underlying variables we # just update the array of log probabilities. -function acclogp!(vi::ThreadSafeVarInfo, logp) +function acclogp!!(vi::ThreadSafeVarInfo, logp) vi.logps[Threads.threadid()][] += logp return vi end @@ -26,17 +26,17 @@ getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps) # TODO: Make remaining methods thread-safe. -function resetlogp!(vi::ThreadSafeVarInfo) +function resetlogp!!(vi::ThreadSafeVarInfo) for x in vi.logps x[] = zero(x[]) end - return resetlogp!(vi.varinfo) + return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps) end -function setlogp!(vi::ThreadSafeVarInfo, logp) +function setlogp!!(vi::ThreadSafeVarInfo, logp) for x in vi.logps x[] = zero(x[]) end - return setlogp!(vi.varinfo, logp) + return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) @@ -65,14 +65,21 @@ getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, s getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) -function setindex!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) - return setindex!(vi.varinfo, val, spl) +function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end -function setindex!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) - return setindex!(vi.varinfo, val, spl) +function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end -function setindex!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) - return setindex!(vi.varinfo, val, spl) +function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) +end + +function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) +end +function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:VarName}) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) end function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) @@ -80,16 +87,14 @@ function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) -function empty!(vi::ThreadSafeVarInfo) - empty!(vi.varinfo) - fill!(vi.logps, zero(getlogp(vi))) - return vi +function BangBang.empty!!(vi::ThreadSafeVarInfo) + return resetlogp!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo))) end -function push!( +function BangBang.push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - return push!(vi.varinfo, vn, r, dist, gidset) + return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) end function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) diff --git a/src/utils.jl b/src/utils.jl index db7faabbd..ffdc21070 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,7 @@ Add the result of the evaluation of `ex` to the joint log probability. """ macro addlogprob!(ex) return quote - acclogp!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) end end @@ -44,6 +44,20 @@ function getargs_tilde(expr::Expr) end end +""" + getargs_assignment(x) + +Return the arguments `L` and `R`, if `x` is an expression of the form `L = R`, or `nothing` +otherwise. +""" +getargs_assignment(x) = nothing +function getargs_assignment(expr::Expr) + return MacroTools.@match expr begin + (L_ = R_) => (L, R) + x_ => nothing + end +end + function to_namedtuple_expr(syms, vals=syms) length(syms) == 0 && return :(NamedTuple()) @@ -66,11 +80,10 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) # otherwise we will have error for MatrixDistribution. # Note this is not the case for MultivariateDistribution so I guess this might be lack of # support for some types related to matrices (like PDMat). -reconstruct(d::UnivariateDistribution, val::AbstractVector) = val[1] -reconstruct(d::MultivariateDistribution, val::AbstractVector) = copy(val) -function reconstruct(d::MatrixDistribution, val::AbstractVector) - return reshape(copy(val), size(d)) -end +reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) +reconstruct(::Tuple{}, val::AbstractVector) = val[1] +reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) +reconstruct(s::NTuple{2}, val::AbstractVector) = reshape(copy(val), s) function reconstruct!(r, d::Distribution, val::AbstractVector) return reconstruct!(r, d, val) end @@ -79,17 +92,17 @@ function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector) return r end function reconstruct(d::Distribution, val::AbstractVector, n::Int) - return reconstruct(d, val, n) + return reconstruct(size(d), val, n) end -function reconstruct(d::UnivariateDistribution, val::AbstractVector, n::Int) +function reconstruct(::Tuple{}, val::AbstractVector, n::Int) return copy(val) end -function reconstruct(d::MultivariateDistribution, val::AbstractVector, n::Int) - return copy(reshape(val, size(d)[1], n)) +function reconstruct(s::NTuple{1}, val::AbstractVector, n::Int) + return copy(reshape(val, s[1], n)) end -function reconstruct(d::MatrixDistribution, val::AbstractVector, n::Int) - tmp = reshape(val, size(d)[1], size(d)[2], n) - orig = [tmp[:, :, i] for i in 1:size(tmp, 3)] +function reconstruct(s::NTuple{2}, val::AbstractVector, n::Int) + tmp = reshape(val, s..., n) + orig = [tmp[:, :, i] for i in 1:n] return orig end function reconstruct!(r, d::Distribution, val::AbstractVector, n::Int) @@ -142,3 +155,162 @@ end ####################### collectmaybe(x) = x collectmaybe(x::Base.AbstractSet) = collect(x) + +####################### +# BangBang.jl related # +####################### +function set!!(obj, lens::Setfield.Lens, value) + lensmut = BangBang.prefermutation(lens) + return Setfield.set(obj, lensmut, value) +end +function set!!(obj, vn::VarName{sym}, value) where {sym} + lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) + return Setfield.set(obj, lens, value) +end + +############################# +# AbstractPPL.jl extensions # +############################# +# This is preferable to `haskey` because the order of arguments is different, and +# we're more likely to specialize on the key in these settings rather than the container. +# TODO: I'm not sure about this name. +""" + canview(lens, container) + +Return `true` if `lens` can be used to view `container`, and `false` otherwise. + +# Examples +```jldoctest; setup=:(using Setfield; using DynamicPPL: canview) +julia> canview(@lens(_.a), (a = 1.0, )) +true + +julia> canview(@lens(_.a), (b = 1.0, )) # property `a` does not exist +false + +julia> canview(@lens(_.a[1]), (a = [1.0, 2.0], )) +true + +julia> canview(@lens(_.a[3]), (a = [1.0, 2.0], )) # out of bounds +false +``` +""" +canview(lens, container) = false +canview(::Setfield.IdentityLens, _) = true +function canview(lens::Setfield.PropertyLens{field}, x) where {field} + return hasproperty(x, field) +end + +# `IndexLens`: only relevant if `x` supports indexing. +canview(lens::Setfield.IndexLens, x) = false +canview(lens::Setfield.IndexLens, x::AbstractArray) = checkbounds(Bool, x, lens.indices...) + +# `ComposedLens`: check that we can view `.outer` and `.inner`, but using +# value extracted using `.outer`. +function canview(lens::Setfield.ComposedLens, x) + return canview(lens.outer, x) && canview(lens.inner, get(x, lens.outer)) +end + +""" + parent(vn::VarName) + +Return the parent `VarName`. + +# Examples +```julia-repl; setup=:(using DynamicPPL: parent) +julia> parent(@varname(x.a[1])) +x.a + +julia> (parent ∘ parent)(@varname(x.a[1])) +x + +julia> (parent ∘ parent ∘ parent)(@varname(x.a[1])) +x +``` +""" +function parent(vn::VarName) + p = parent(getlens(vn)) + return p === nothing ? VarName(vn, Setfield.IdentityLens()) : VarName(vn, p) +end + +""" + parent(lens::Setfield.Lens) + +Return the parent lens. If `lens` doesn't have a parent, +`nothing` is returned. + +See also: [`parent_and_child`]. + +# Examples +```jldoctest; setup=:(using Setfield; using DynamicPPL: parent) +julia> parent(@lens(_.a[1])) +(@lens _.a) + +julia> # Parent of lens without parents results in `nothing`. + (parent ∘ parent)(@lens(_.a[1])) === nothing +true +``` +""" +parent(lens::Setfield.Lens) = first(parent_and_child(lens)) + +""" + parent_and_child(lens::Setfield.Lens) + +Return a 2-tuple of lenses `(parent, child)` where `parent` is the +parent lens of `lens` and `child` is the child lens of `lens`. + +If `lens` does not have a parent, we return `(nothing, lens)`. + +See also: [`parent`]. + +# Examples +```jldoctest; setup=:(using Setfield; using DynamicPPL: parent_and_child) +julia> parent_and_child(@lens(_.a[1])) +((@lens _.a), (@lens _[1])) + +julia> parent_and_child(@lens(_.a)) +(nothing, (@lens _.a)) +``` +""" +parent_and_child(lens::Setfield.Lens) = (nothing, lens) +function parent_and_child(lens::Setfield.ComposedLens) + p, child = parent_and_child(lens.inner) + parent = p === nothing ? lens.outer : lens.outer ∘ p + return parent, child +end + +""" + splitlens(condition, lens) + +Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, +`parent` is a lens such that `condition(parent)` is `true` and `parent ∘ child == lens`. + +If `issuccess` is `false`, then no such split could be found. + +# Examples +```jldoctest; setup=:(using Setfield; using DynamicPPL: splitlens) +julia> p, c, issucesss = splitlens(@lens(_.a[1])) do parent + # Succeeds! + parent == @lens(_.a) + end +((@lens _.a), (@lens _[1]), true) + +julia> p ∘ c +(@lens _.a[1]) + +julia> splitlens(@lens(_.a[1])) do parent + # Fails! + parent == @lens(_.b) + end +(nothing, (@lens _.a[1]), false) +``` +""" +function splitlens(condition, lens) + current_parent, current_child = parent_and_child(lens) + # We stop if either a) `condition` is satisfied, or b) we reached the root. + while !condition(current_parent) && current_parent !== nothing + current_parent, c = parent_and_child(current_parent) + current_child = c ∘ current_child + end + + return current_parent, current_child, condition(current_parent) +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 451060b0b..9ce0414d6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -588,16 +588,16 @@ end TypedVarInfo(vi::TypedVarInfo) = vi """ - empty!(vi::VarInfo) + empty!!(vi::VarInfo) Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to zeros. This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. """ -function empty!(vi::VarInfo) +function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!(vi) + resetlogp!!(vi) reset_num_produce!(vi) return vi end @@ -655,34 +655,34 @@ Return the log of the joint probability of the observed data and parameters samp getlogp(vi::AbstractVarInfo) = vi.logp[] """ - setlogp!(vi::VarInfo, logp) + setlogp!!(vi::VarInfo, logp) Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`. +`vi` to `logp`, mutating if it makes sense. """ -function setlogp!(vi::VarInfo, logp) +function setlogp!!(vi::VarInfo, logp) vi.logp[] = logp return vi end """ - acclogp!(vi::VarInfo, logp) + acclogp!!(vi::VarInfo, logp) Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`. +parameters sampled in `vi`, mutating if it makes sense. """ -function acclogp!(vi::VarInfo, logp) +function acclogp!!(vi::VarInfo, logp) vi.logp[] += logp return vi end """ - resetlogp!(vi::AbstractVarInfo) + resetlogp!!(vi::AbstractVarInfo) Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0. +sampled in `vi` to 0, mutating if it makes sense. """ -resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, zero(getlogp(vi))) +resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) """ get_num_produce(vi::VarInfo) @@ -955,7 +955,10 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, vn::VarName) = setval!(vi, val, vn) +setindex!(vi::AbstractVarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) +function BangBang.setindex!!(vi::AbstractVarInfo, val, vn::VarName) + return (setindex!(vi, val, vn); return vi) +end """ setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) @@ -970,8 +973,14 @@ function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` ranges = _getranges(vi, spl) _setindex!(vi.metadata, val, ranges) - return val + return nothing end + +function BangBang.setindex!!(vi::AbstractVarInfo, val, spl::AbstractSampler) + setindex!(vi, val, spl) + return vi +end + # Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. @generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} expr = Expr(:block) @@ -1088,46 +1097,52 @@ function Base.show(io::IO, vi::UntypedVarInfo) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`. +the `VarInfo` `vi`, mutating if it makes sense. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return push!(vi, vn, r, dist, Set{Selector}([])) +function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + return BangBang.push!!(vi, vn, r, dist, Set{Selector}([])) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`. +from a distribution `dist` to `VarInfo` `vi`, if it makes sense. The sampler is passed here to invalidate its cache where defined. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) - return push!(vi, vn, r, dist, spl.selector) +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler +) + return BangBang.push!!(vi, vn, r, dist, spl.selector) end -function push!( +function BangBang.push!!( vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler ) - return push!(vi, vn, r, dist) + return BangBang.push!!(vi, vn, r, dist) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) Push a new random variable `vn` with a sampled value `r` sampled with a sampler of selector `gid` from a distribution `dist` to `VarInfo` `vi`. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - return push!(vi, vn, r, dist, Set([gid])) +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector +) + return BangBang.push!!(vi, vn, r, dist, Set([gid])) end -function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function BangBang.push!!( + vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} +) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" end val = vectorize(dist, r) @@ -1181,7 +1196,8 @@ end Set `vn`'s value for `flag` to `false` in `vi`. """ function unset_flag!(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false + getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false + return vi end """ @@ -1390,7 +1406,7 @@ function setval!( return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) +function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1471,7 +1487,7 @@ function setval_and_resample!( return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) +function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1485,3 +1501,37 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, return indices end + +""" + values_as(vi::AbstractVarInfo) +""" +values_as(vi::VarInfo) = vi.metadata + +""" + values_as(vi::AbstractVarInfo, ::Type{NamedTuple}) + values_as(vi::AbstractVarInfo, ::Type{Dict}) + +Return values in `vi` as the specified type. +""" +function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) + iter = values_from_metadata(vi.metadata) + return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) +end +values_as(vi::UntypedVarInfo, ::Type{Dict}) = Dict(values_from_metadata(vi.metadata)) + +function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} + iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) + return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) +end + +function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names} + iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) + return Dict(iter) +end + +function values_from_metadata(md::Metadata) + return ( + vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for + vn in md.vns + ) +end diff --git a/test/compat/ad.jl b/test/compat/ad.jl index 3a8058ca9..f76ce6f6e 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -30,9 +30,10 @@ # https://github.com/TuringLang/Turing.jl/issues/1595 @testset "dot_observe" begin function f_dot_observe(x) - return DynamicPPL.dot_observe( + logp, _ = DynamicPPL.dot_observe( SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo() ) + return logp end function f_dot_observe_manual(x) return logpdf(Normal(), x[1]) + logpdf(Normal(-1.0, 2.0), x[2]) diff --git a/test/compiler.jl b/test/compiler.jl index 25927c581..4c76cf1ab 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -420,8 +420,8 @@ end end @model function demo_useval(x, y) - x1 = @submodel sub1 demo_return(x) - x2 = @submodel sub2 demo_return(y) + @submodel prefix = "sub1" x1 = demo_return(x) + @submodel prefix = "sub2" x2 = demo_return(y) return z ~ Normal(x1 + x2 + 100, 1.0) end @@ -455,8 +455,8 @@ end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ) - y[i] ~ MvNormal(x, 0.01 * I) + @submodel prefix = "ar1_$i" x = AR1(num_steps, α, μ, σ) + y[i] ~ MvNormal(x, 0.1) end end @@ -544,4 +544,31 @@ end f(::Model{typeof(demo),()}) = false @test !f(demo()) end + + @testset "return value" begin + # Even if the return-value is `AbstractVarInfo`, we should return + # a `Tuple` with `AbstractVarInfo` in the second component too. + @model demo() = return __varinfo__ + retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + @test svi == SimpleVarInfo() + if Threads.nthreads() > 1 + @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} + @test retval.varinfo == svi + else + @test retval == svi + end + + # We should not be altering return-values other than at top-level. + @model function demo() + # If we also replaced this `return` inside of `f`, then the + # final `return` would be include `__varinfo__`. + f(x) = return x^2 + return f(1.0) + end + retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + @test retval isa Float64 + + @model demo() = x ~ Normal() + retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + end end diff --git a/test/model.jl b/test/model.jl index 2cdeae5fa..466a7d1f4 100644 --- a/test/model.jl +++ b/test/model.jl @@ -65,4 +65,20 @@ @test nameof(test1(rand())) == :test1 @test nameof(test2(rand())) == :test2 end + + @testset "Internal methods" begin + model = gdemo_default + + # sample from model and extract variables + vi = VarInfo(model) + + # Second component of return-value of `evaluate!!` should + # be a `DynamicPPL.AbstractVarInfo`. + evaluate_retval = DynamicPPL.evaluate!!(model, vi, DefaultContext()) + @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo + + # Should not return `AbstractVarInfo` when we call the model. + call_retval = model() + @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) + end end diff --git a/test/runtests.jl b/test/runtests.jl index bb2ae579c..3372b6371 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using MacroTools using MCMCChains using Tracker using Zygote +using Setfield using Distributed using LinearAlgebra @@ -34,6 +35,7 @@ include("test_util.jl") include("utils.jl") include("compiler.jl") include("varinfo.jl") + include("simple_varinfo.jl") include("model.jl") include("sampler.jl") include("prob_macro.jl") diff --git a/test/sampler.jl b/test/sampler.jl index b32ba29d6..ffd0182e8 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -18,7 +18,7 @@ @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.1 + @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 chains = sample(model, SampleFromUniform(), N; progress=false) @test chains isa Vector{<:VarInfo} diff --git a/test/serialization.jl b/test/serialization.jl index 2f2bf2a2b..7ea81e410 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -10,7 +10,7 @@ samples_s = first.(samples) samples_m = last.(samples) - @test mean(samples_s) ≈ 3 atol = 0.1 + @test mean(samples_s) ≈ 3 atol = 0.2 @test mean(samples_m) ≈ 0 atol = 0.1 end @testset "pmap" begin diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl new file mode 100644 index 000000000..9e596af5f --- /dev/null +++ b/test/simple_varinfo.jl @@ -0,0 +1,128 @@ +@testset "simple_varinfo.jl" begin + @testset "constructor & indexing" begin + @testset "NamedTuple" begin + svi = SimpleVarInfo(; m=1.0) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test !haskey(svi, @varname(m[1])) + + svi = SimpleVarInfo(; m=[1.0]) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m[1])) + @test !haskey(svi, @varname(m[2])) + @test svi[@varname(m)][1] == svi[@varname(m[1])] + + svi = SimpleVarInfo(; m=(a=[1.0],)) + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) + + svi = SimpleVarInfo{Float32}(; m=1.0) + @test getlogp(svi) isa Float32 + + svi = SimpleVarInfo((m=1.0,), 1.0) + @test getlogp(svi) == 1.0 + end + + @testset "Dict" begin + svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test !haskey(svi, @varname(m[1])) + + svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m[1])) + @test !haskey(svi, @varname(m[2])) + @test svi[@varname(m)][1] == svi[@varname(m[1])] + + svi = SimpleVarInfo(Dict(@varname(m) => (a=[1.0],))) + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) + + svi = SimpleVarInfo(Dict(@varname(m.a) => [1.0])) + # Now we only have a variable `m.a` which is subsumed by `m`, + # but we can't guarantee that we have the "entire" `m`. + @test !haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) + end + end + + @testset "SimpleVarInfo on $(model.name)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # We might need to pre-allocate for the variable `m`, so we need + # to see whether this is the case. + m = model().m + svi_nt = if m isa AbstractArray + SimpleVarInfo((m=similar(m),)) + else + SimpleVarInfo() + end + svi_dict = SimpleVarInfo(VarInfo(model), Dict) + + @testset "$(nameof(typeof(svi.values)))" for svi in (svi_nt, svi_dict) + # Random seed is set in each `@testset`, so we need to sample + # a new realization for `m` here. + m = model().m + + ### Sampling ### + # Sample a new varinfo! + _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) + + # If the `m[1]` varname doesn't exist, this is a univariate model. + # TODO: Find a better way of dealing with this that is not dependent + # on knowledge of internals of `model`. + isunivariate = !haskey(svi_new, @varname(m[1])) + + # Realization for `m` should be different wp. 1. + if isunivariate + @test svi_new[@varname(m)] != m + else + @test svi_new[@varname(m[1])] != m[1] + @test svi_new[@varname(m[2])] != m[2] + end + # Logjoint should be non-zero wp. 1. + @test getlogp(svi_new) != 0 + + ### Evaluation ### + # Sample some random testing values. + m_eval = if m isa AbstractArray + randn!(similar(m)) + else + randn(eltype(m)) + end + + # Update the realizations in `svi_new`. + svi_eval = if isunivariate + DynamicPPL.setindex!!(svi_new, m_eval, @varname(m)) + else + DynamicPPL.setindex!!(svi_new, m_eval, [@varname(m[1]), @varname(m[2])]) + end + # Reset the logp field. + svi_eval = DynamicPPL.resetlogp!!(svi_eval) + + # Compute `logjoint` using the varinfo. + logπ = logjoint(model, svi_eval) + # Extract the parameters from `svi_eval`. + m_vi = if isunivariate + svi_eval[@varname(m)] + else + svi_eval[[@varname(m[1]), @varname(m[2])]] + end + # These should not have changed. + @test m_vi == m_eval + # Compute the true `logjoint` and compare. + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_vi) + @test logπ ≈ logπ_true + end + end +end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 83c53ccd6..460d68ca3 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -17,17 +17,17 @@ lp = getlogp(vi) @test getlogp(threadsafe_vi) == lp - acclogp!(threadsafe_vi, 42) + acclogp!!(threadsafe_vi, 42) @test threadsafe_vi.logps[Threads.threadid()][] == 42 @test getlogp(vi) == lp @test getlogp(threadsafe_vi) == lp + 42 - resetlogp!(threadsafe_vi) + resetlogp!!(threadsafe_vi) @test iszero(getlogp(vi)) @test iszero(getlogp(threadsafe_vi)) @test all(iszero(x[]) for x in threadsafe_vi.logps) - setlogp!(threadsafe_vi, 42) + setlogp!!(threadsafe_vi, 42) @test getlogp(vi) == 42 @test getlogp(threadsafe_vi) == 42 @test all(iszero(x[]) for x in threadsafe_vi.logps) @@ -60,7 +60,7 @@ @time wthreads(x)(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - DynamicPPL.evaluate_threadsafe( + DynamicPPL.evaluate_threadsafe!!( wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), @@ -68,8 +68,8 @@ @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo - println(" evaluate_threadsafe:") - @time DynamicPPL.evaluate_threadsafe( + println(" evaluate_threadsafe!!:") + @time DynamicPPL.evaluate_threadsafe!!( wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), @@ -99,7 +99,7 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe( + DynamicPPL.evaluate_threadunsafe!!( wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), @@ -107,8 +107,8 @@ @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo - println(" evaluate_threadunsafe:") - @time DynamicPPL.evaluate_threadunsafe( + println(" evaluate_threadunsafe!!:") + @time DynamicPPL.evaluate_threadunsafe!!( wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), diff --git a/test/varinfo.jl b/test/varinfo.jl index ff05d4dd8..2f7816024 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -33,7 +33,7 @@ end @testset "Base" begin # Test Base functions: - # string, Symbol, ==, hash, in, keys, haskey, isempty, push!, empty!, + # string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!, # getindex, setindex!, getproperty, setproperty! csym = gensym() vn1 = @varname x[1][2] @@ -46,7 +46,7 @@ @test inspace(vn1, (:x,)) function test_base!(vi) - empty!(vi) + empty!!(vi) @test getlogp(vi) == 0 @test get_num_produce(vi) == 0 @@ -58,7 +58,7 @@ @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - push!(vi, vn, r, dist, gid) + push!!(vi, vn, r, dist, gid) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -75,9 +75,9 @@ @test vi[vn] == 3 * r @test vi[SampleFromPrior()][1] == 3 * r - empty!(vi) + empty!!(vi) @test isempty(vi) - push!(vi, vn, r, dist, gid) + push!!(vi, vn, r, dist, gid) function test_inspace() space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) @@ -98,7 +98,7 @@ end vi = VarInfo() test_base!(vi) - test_base!(empty!(TypedVarInfo(vi))) + test_base!(empty!!(TypedVarInfo(vi))) end @testset "flags" begin # Test flag setting: @@ -109,7 +109,7 @@ r = rand(dist) gid = Selector() - push!(vi, vn_x, r, dist, gid) + push!!(vi, vn_x, r, dist, gid) # del is set by default @test !is_flagged(vi, vn_x, "del") @@ -122,7 +122,7 @@ end vi = VarInfo() test_varinfo!(vi) - test_varinfo!(empty!(TypedVarInfo(vi))) + test_varinfo!(empty!!(TypedVarInfo(vi))) end @testset "setgid!" begin vi = VarInfo() @@ -133,14 +133,14 @@ gid1 = Selector() gid2 = Selector(2, :HMC) - push!(vi, vn, r, dist, gid1) + push!!(vi, vn, r, dist, gid1) @test meta.gids[meta.idcs[vn]] == Set([gid1]) setgid!(vi, gid2, vn) @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) - vi = empty!(TypedVarInfo(vi)) + vi = empty!!(TypedVarInfo(vi)) meta = vi.metadata - push!(vi, vn, r, dist, gid1) + push!!(vi, vn, r, dist, gid1) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2])