From 3ab2beed11a07bbb2b7009e2956252dbf1998463 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 20:10:31 +0100 Subject: [PATCH 001/216] unwrap distributions and varnames at model-level --- src/compiler.jl | 50 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bef7d11c2..0be9d4d44 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -52,6 +52,45 @@ end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x +""" + unwrap_right_vn(right, vn) +Return the unwrapped distribution on the right-hand side and variable name on the left-hand +side of a `~` expression such as `x ~ Normal()`. +This is used mainly to unwrap `NamedDist` distributions. +""" +unwrap_right_vn(right, vn) = right, vn +unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) + +""" + unwrap_right_left_vns(context, right, left, vns) +Return the unwrapped distributions on the right-hand side and values and variable names on the +left-hand side of a `.~` expression such as `x .~ Normal()`. +This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the +variables. +""" +unwrap_right_left_vns(right, left, vns) = right, left, vns +function unwrap_right_left_vns(right::NamedDist, left, vns) + return unwrap_right_left_vns(right.dist, left, right.name) +end +function unwrap_right_left_vns( + right::MultivariateDistribution, left::AbstractMatrix, vn::VarName +) + vns = map(axes(left, 2)) do i + return VarName(vn, (vn.indexing..., Tuple(i))) + end + return unwrap_right_left_vns(right, left, vns) +end +function unwrap_right_left_vns( + right::Union{Distribution,AbstractArray{<:Distribution}}, + left::AbstractArray, + vn::VarName, +) + vns = map(CartesianIndices(left)) do i + return VarName(vn, (vn.indexing..., Tuple(i))) + end + return unwrap_right_left_vns(right, left, vns) +end + ################# # Main Compiler # ################# @@ -264,8 +303,9 @@ function generate_tilde(left, right) __rng__, __context__, __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $vn, + $(DynamicPPL.unwrap_right_vn)( + $(DynamicPPL.check_tilde_rhs)($right), $vn + )..., $inds, __varinfo__, ) @@ -314,9 +354,9 @@ function generate_dot_tilde(left, right) __rng__, __context__, __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - $vn, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $left, $vn + )..., $inds, __varinfo__, ) From a549d1fc7b5b149b0fdf96754f366e825bcf31e9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 20:10:54 +0100 Subject: [PATCH 002/216] removed _tilde and renamed tilde_assume and others --- src/context_implementations.jl | 175 +++++++++++---------------------- 1 file changed, 60 insertions(+), 115 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index afc5e4da3..f1977fe80 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,79 +18,72 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return _tilde(rng, sampler, right, vn, vi) +function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) + return assume(rng, sampler, right, vn, vi) end -function tilde(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) +function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) if ctx.vars !== nothing vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, sampler, right, vn, vi) + return assume(rng, sampler, right, vn, vi) end -function tilde(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) +function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return _tilde(rng, sampler, NoDist(right), vn, vi) + return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, left, inds, vi) +function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) + return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) end -function tilde(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) +function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) + return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) end """ - tilde_assume(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. """ -function tilde_assume(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) acclogp!(vi, logp) return value end -function _tilde(rng, sampler, right, vn::VarName, vi) - return assume(rng, sampler, right, vn, vi) -end -function _tilde(rng, sampler, right::NamedDist, vn::VarName, vi) - return _tilde(rng, sampler, right.dist, right.name, vi) -end - # observe -function tilde(ctx::DefaultContext, sampler, right, left, vi) - return _tilde(sampler, right, left, vi) +function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) + return observe(sampler, right, left, vi) end -function tilde(ctx::PriorContext, sampler, right, left, vi) +function tilde_observe(ctx::PriorContext, sampler, right, left, vi) return 0 end -function tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return _tilde(sampler, right, left, vi) +function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return observe(sampler, right, left, vi) end -function tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) +function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) end -function tilde(ctx::PrefixContext, sampler, right, left, vi) - return tilde(ctx.ctx, sampler, right, left, vi) +function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) + return tilde_observe(ctx.ctx, sampler, right, left, vi) end """ - tilde_observe(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end @@ -103,7 +96,7 @@ return the observed value. Falls back to `tilde(ctx, sampler, right, left, vi)`. """ -function tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, sampler, right, left, vi) logp = tilde(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left @@ -151,80 +144,44 @@ end # .~ functions # assume -function dot_tilde(rng, ctx::DefaultContext, sampler, right, left, vn::VarName, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) - return _dot_tilde(rng, sampler, dist, left, vns, vi) +function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) + return dot_assume(rng, sampler, right, left, vns, vi) end -function dot_tilde(rng, ctx::LikelihoodContext, sampler, right, left, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) +function dot_tilde_assume(rng, ctx::LikelihoodContext, sampler, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi) where {sym} + if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) + var = _getindex(getfield(ctx.vars, sym), inds) + set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) - else - vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde(rng, sampler, NoDist.(dist), left, vns, vi) + return dot_assume(rng, sampler, NoDist.(right), left, vns, vi) end -function dot_tilde(rng, ctx::MiniBatchContext, sampler, right, left, vn::VarName, inds, vi) - return dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) + return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) end -function dot_tilde(rng, ctx::PriorContext, sampler, right, left, vn::VarName, inds, vi) +function dot_tilde_assume(rng, ctx::PriorContext, sampler, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi) where {sym} if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, getsym(vn)), inds) - vns, dist = get_vns_and_dist(right, var, vn) - set_val!(vi, vns, dist, var) + var = _getindex(getfield(ctx.vars, sym), inds) + set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) - else - vns, dist = get_vns_and_dist(right, left, vn) end - return _dot_tilde(rng, sampler, dist, left, vns, vi) + return dot_assume(rng, sampler, right, left, vns, vi) end """ - dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. """ -function dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -function get_vns_and_dist(dist::NamedDist, var, vn::VarName) - return get_vns_and_dist(dist.dist, var, dist.name) -end -function get_vns_and_dist(dist::MultivariateDistribution, var::AbstractMatrix, vn::VarName) - getvn = i -> VarName(vn, (vn.indexing..., (Colon(), i))) - return getvn.(1:size(var, 2)), dist -end -function get_vns_and_dist( - dist::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vn::VarName -) - getvn = ind -> VarName(vn, (vn.indexing..., Tuple(ind))) - return getvn.(CartesianIndices(var)), dist -end - -function _dot_tilde(rng, sampler, right, left, vns::AbstractArray{<:VarName}, vi) - return dot_assume(rng, sampler, right, vns, left, vi) -end - # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function _dot_tilde( - rng, - sampler::AbstractSampler, - right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, - left::AbstractMatrix{>:AbstractVector}, - vn::AbstractVector{<:VarName}, - vi, -) - return throw(DimensionMismatch(AMBIGUITY_MSG)) -end - function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -348,61 +305,49 @@ function set_val!( end # observe -function dot_tilde(ctx::DefaultContext, sampler, right, left, vi) - return _dot_tilde(sampler, right, left, vi) +function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) end -function dot_tilde(ctx::PriorContext, sampler, right, left, vi) +function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) return 0 end -function dot_tilde(ctx::LikelihoodContext, sampler, right, left, vi) - return _dot_tilde(sampler, right, left, vi) +function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) end -function dot_tilde(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end """ - dot_tilde_observe(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, sampler, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. """ -function dot_tilde_observe(ctx, sampler, right, left, vi) - logp = dot_tilde(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, sampler, right, left, vi) + logp = dot_tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end -function _dot_tilde(sampler, right, left::AbstractArray, vi) - return dot_observe(sampler, right, left, vi) -end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function _dot_tilde( - sampler::AbstractSampler, - right::Union{MultivariateDistribution,AbstractVector{<:MultivariateDistribution}}, - left::AbstractMatrix{>:AbstractVector}, - vi, -) - return throw(DimensionMismatch(AMBIGUITY_MSG)) -end - function dot_observe( spl::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, From e0f77bc67c4a51f1eab941323e30634a0e8c5c02 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 20:11:33 +0100 Subject: [PATCH 003/216] formatting --- src/context_implementations.jl | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f1977fe80..9f78a10ec 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -147,7 +147,16 @@ end function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) return dot_assume(rng, sampler, right, left, vns, vi) end -function dot_tilde_assume(rng, ctx::LikelihoodContext, sampler, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi) where {sym} +function dot_tilde_assume( + rng, + ctx::LikelihoodContext, + sampler, + right, + left, + vns::AbstractArray{<:VarName{sym}}, + inds, + vi, +) where {sym} if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) var = _getindex(getfield(ctx.vars, sym), inds) set_val!(vi, vns, right, var) @@ -158,7 +167,16 @@ end function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) end -function dot_tilde_assume(rng, ctx::PriorContext, sampler, right, left, vns::AbstractArray{<:VarName{sym}}, inds, vi) where {sym} +function dot_tilde_assume( + rng, + ctx::PriorContext, + sampler, + right, + left, + vns::AbstractArray{<:VarName{sym}}, + inds, + vi, +) where {sym} if ctx.vars !== nothing var = _getindex(getfield(ctx.vars, sym), inds) set_val!(vi, vns, right, var) From 8e4fa91db88448d2d0a73fbe3fc86b7644b19223 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 20:39:48 +0100 Subject: [PATCH 004/216] updated compiler for new tilde-methods --- src/compiler.jl | 14 +++++++------- src/context_implementations.jl | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0be9d4d44..20d8bf8ef 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -62,7 +62,7 @@ unwrap_right_vn(right, vn) = right, vn unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) """ - unwrap_right_left_vns(context, right, left, vns) + unwrap_right_left_vns(right, left, vns) Return the unwrapped distributions on the right-hand side and values and variable names on the left-hand side of a `.~` expression such as `x .~ Normal()`. This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the @@ -281,7 +281,7 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if !(left isa Symbol || left isa Expr) return quote - $(DynamicPPL.tilde_observe)( + $(DynamicPPL.tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -299,7 +299,7 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume)( + $left = $(DynamicPPL.tilde_assume!)( __rng__, __context__, __sampler__, @@ -310,7 +310,7 @@ function generate_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.tilde_observe)( + $(DynamicPPL.tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -332,7 +332,7 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if !(left isa Symbol || left isa Expr) return quote - $(DynamicPPL.dot_tilde_observe)( + $(DynamicPPL.dot_tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), @@ -350,7 +350,7 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume)( + $left .= $(DynamicPPL.dot_tilde_assume!)( __rng__, __context__, __sampler__, @@ -361,7 +361,7 @@ function generate_dot_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe)( + $(DynamicPPL.dot_tilde_observe!)( __context__, __sampler__, $(DynamicPPL.check_tilde_rhs)($right), diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9f78a10ec..8d4c5c2e2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -97,7 +97,7 @@ return the observed value. Falls back to `tilde(ctx, sampler, right, left, vi)`. """ function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde(ctx, sampler, right, left, vi) + logp = tilde_observe(ctx, sampler, right, left, vi) acclogp!(vi, logp) return left end From 1c9a2d58e1e4e4965d8b098f8c43027f867887c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 31 May 2021 21:02:45 +0100 Subject: [PATCH 005/216] fixed calls to dot_assume --- src/context_implementations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8d4c5c2e2..0698b6cdf 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -145,7 +145,7 @@ end # assume function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) - return dot_assume(rng, sampler, right, left, vns, vi) + return dot_assume(rng, sampler, right, vns, left, vi) end function dot_tilde_assume( rng, @@ -162,7 +162,7 @@ function dot_tilde_assume( set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) end - return dot_assume(rng, sampler, NoDist.(right), left, vns, vi) + return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) @@ -182,7 +182,7 @@ function dot_tilde_assume( set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) end - return dot_assume(rng, sampler, right, left, vns, vi) + return dot_assume(rng, sampler, right, vns, left, vi) end """ From d70e1be46058e912121b41dd0e2f0724c57474c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:21:20 +0100 Subject: [PATCH 006/216] added sampling context and unwrap_childcontext --- src/contexts.jl | 63 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 4d4f30bdc..1ee43f2b2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,39 @@ +""" + unwrap_childcontext(context::AbstractContext) + +Return a tuple of the child context of a `context`, or `nothing` if the context does +not wrap any other context, and a function `f(c::AbstractContext)` that constructs +an instance of `context` in which the child context is replaced with `c`. + +Falls back to `(nothing, _ -> context)`. +""" +function unwrap_childcontext(context::AbstractContext) + reconstruct_context(@nospecialize(x)) = context + return nothing, reconstruct_context +end + +""" + SamplingContext(rng, sampler, context) + +Create a context that allows you to sample parameters with the `sampler` when running the model. +The `context` determines how the returned log density is computed when running the model. + +See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +""" +struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext + rng::R + sampler::S + context::C +end + +function unwrap_childcontext(context::SamplingContext) + child = context.context + function reconstruct_samplingcontext(c::AbstractContext) + return SamplingContext(context.rng, context.sampler, c) + end + return child, reconstruct_samplingcontext +end + """ struct DefaultContext <: AbstractContext end @@ -53,6 +89,25 @@ function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) return MiniBatchContext(ctx, npoints / batch_size) end +function unwrap_childcontext(context::MiniBatchContext) + child = context.context + function reconstruct_minibatchcontext(c::AbstractContext) + return MiniBatchContext(c, context.loglike_scalar) + end + return child, reconstruct_minibatchcontext +end + +""" + PrefixContext{Prefix}(context) + +Create a context that allows you to use the wrapped `context` when running the model and +adds the `Prefix` to all parameters. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`@submodel`](@ref) +""" struct PrefixContext{Prefix,C} <: AbstractContext ctx::C end @@ -81,3 +136,11 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +function unwrap_childcontext(context::PrefixContext{P}) where {P} + child = context.context + function reconstruct_prefixcontext(c::AbstractContext) + return PrefixContext{P}(c) + end + return child, reconstruct_prefixcontext +end From f74399031eca9b5eaab824c16416a825c489f59c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:23:31 +0100 Subject: [PATCH 007/216] updated tilde methods --- src/context_implementations.jl | 451 +++++++++++++++++++++++++-------- 1 file changed, 352 insertions(+), 99 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0698b6cdf..8aa8ddfca 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,28 +18,103 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return assume(rng, sampler, right, vn, vi) +""" + tilde_assume(context::SamplingContext, right, vn, inds, vi) + +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), +accumulate the log probability, and return the sampled value with a context associated +with a sampler. + +Falls back to +```julia +tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) +``` +if the context `context.ctx` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. +""" +function tilde_assume(context::SamplingContext, right, vn, inds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return if child_of_c === nothing + tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) + else + tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) + end +end + +# Leaf contexts +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(PriorContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) +function tilde_assume(::PriorContext, right, vn, inds, vi) + return assume(right, vn, inds, vi) +end +function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, right, vn, vi) + return tilde_assume(LikelihoodContext(), right, vn, inds, vi) end -function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) +function tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, NoDist(right), vn, vi) + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::LikelihoodContext, right, vn, inds, vi) + return assume(NoDist(right), vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi +) + return assume(rng, sampler, NoDist(right), vn, inds, vi) end -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) + +function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) + return tilde_assume(context.ctx, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) + +function tilde_assume(context::PrefixContext, right, vn, inds, vi) + return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) end """ @@ -50,27 +125,76 @@ accumulate the log probability, and return the sampled value. Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, sampler, right, vn, inds, vi) + value, logp = tilde_assume(ctx, sampler, right, vn, inds, vi) acclogp!(vi, logp) return value end # observe -function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +""" + tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + +Handle observed variables with a `context` associated with a sampler. +Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.ctx` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. +""" +function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + fallback_context = if child_of_c !== nothing + reconstruct_c(reconstruct_context(child_of_c)) + else + c + end + return tilde_observe(fallback_context, right, left, vname, vinds, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +""" + tilde_observe(context::SamplingContext, right, left, vi) + +Handle observed constants with a `context` associated with a sampler. +Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring +the information about the sampler if the context `context.ctx` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. +""" +function tilde_observe(context::SamplingContext, right, left, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + fallback_context = if child_of_c !== nothing + reconstruct_c(reconstruct_context(child_of_c)) + else + c + end + return tilde_observe(fallback_context, right, left, vi) +end + +# Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) + +# `MiniBatchContext` +function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return context.loglike_scalar * tilde_observe(context.ctx, right, left, vname, vinds, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `PrefixContext` +function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return tilde_observe( + context.ctx, right, left, prefix(context, vname), vinds, vi + ) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(ctx.ctx, sampler, right, left, vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.ctx, right, left, vi) end """ @@ -112,77 +236,179 @@ function observe(spl::Sampler, weight) return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end +# fallback without sampler +function assume(dist::Distribution, vn::VarName, inds, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) +end + +# SampleFromPrior and SampleFromUniform function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + vn::VarName, + inds, + vi, ) + # Always overwrite the parameters with new ones. + r = init(rng, dist, sampler) if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - r = vi[vn] - end + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) else - r = init(rng, dist, spl) - push!(vi, vn, r, dist, spl) - settrans!(vi, false, vn) + push!(vi, vn, r, dist, sampler) end + settrans!(vi, false, vn) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end -function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi -) +# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) +function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(right, left) end # .~ functions # assume -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) +""" + dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the +model inputs), accumulate the log probability, and return the sampled value for a context +associated with a sampler. + +Falls back to +```julia +dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) +``` +if the context `context.ctx` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. +""" +function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return if child_of_c === nothing + dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) + else + dot_tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi) + end +end + +# `DefaultContext` +function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) + return dot_assume(right, vns, left, vi) +end + +function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end + +# `LikelihoodContext` function dot_tilde_assume( - rng, - ctx::LikelihoodContext, + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) - return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) + value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) + acclogp!(vi, logp) + return value end function dot_tilde_assume( - rng, - ctx::PriorContext, + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi +) + value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) + acclogp!(vi, logp) + return value +end + +# `PriorContext` +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, right, vns, left, vi) +end +function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) + value, logp = dot_assume(right, left, vn, inds, vi) + acclogp!(vi, logp) + return value +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi +) + value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) + acclogp!(vi, logp) + return value +end + +# `MiniBatchContext` +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) +end + +# `PrefixContext` +function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) end """ @@ -193,13 +419,26 @@ model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, sampler, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +# `dot_assume` +function dot_assume( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + inds, + vi, +) + @assert length(dist) == size(var, 1) + lp = sum(zip(vns, eachcol(var))) do vn, ri + return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) + end + return var, lp +end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -214,6 +453,19 @@ function dot_assume( var .= r return var, lp end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + inds, + vi, +) + # Make sure `var` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) + return var, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -323,18 +575,38 @@ function set_val!( end # observe -function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) -end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 -end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +""" + dot_tilde_observe(context::SamplingContext, right, left, vi) + +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log +probability, and return the observed value for a context associated with a sampler. + +Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. +""" +function dot_tilde_observe(context::SamplingContext, right, left, vi) + return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) end + +# Leaf contexts +dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) = dot_observe(right, left, vi) + +# `MiniBatchContext` function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) +end + +# `PrefixContext` +function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) +end +function dot_tilde_observe(context::PrefixContext, right, left, vi) + return dot_tilde_observe(context.ctx, right, left, vi) +end """ dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) @@ -366,41 +638,22 @@ function dot_tilde_observe!(ctx, sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - value::AbstractMatrix, - vi, -) +function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Distribution, - value::AbstractArray, - vi, -) +function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::AbstractArray{<:Distribution}, - value::AbstractArray, - vi, -) +function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" - ) -end + From 3d2e7e2b4dfb2462e6732eb3f0b3ac5494d1696f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:23:48 +0100 Subject: [PATCH 008/216] updated model call signature --- src/model.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/model.jl b/src/model.jl index 7189b590e..250b89721 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,12 +88,18 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) + return model(SamplingContext(rng, sampler, context), varinfo) +end + +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, sampler, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, sampler, context) end end + function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end @@ -109,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, sampler, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, sampler, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, sampler, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, sampler, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -134,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, sampler, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, sampler, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, sampler, context) Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, sampler, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, sampler, context, $(unwrap_args...))) end """ @@ -183,7 +189,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, DefaultContext()) return getlogp(varinfo) end @@ -195,7 +201,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, PriorContext()) return getlogp(varinfo) end @@ -207,7 +213,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, LikelihoodContext()) return getlogp(varinfo) end From 4f1d39694083bae41ac7e94cbb19640639512879 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:24:04 +0100 Subject: [PATCH 009/216] updated compiler --- src/compiler.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 20d8bf8ef..bc906f58c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -394,10 +394,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -407,7 +405,15 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = quote + # in case someone accessed these + if __context__ isa $(DynamicPPL.SamplingContext) + __rng__ = __context__.rng + __sampler__ = __context__.sampler + end + + $(modelinfo[:body]) + end ## Build the model function. From b187d74efcfb5b9482f39022252477b5a0bc2cb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:27:37 +0100 Subject: [PATCH 010/216] formatting --- src/context_implementations.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8aa8ddfca..a0bf0381a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -45,7 +45,9 @@ end # Leaf contexts tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi) +function tilde_assume( + rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, right, vn, inds, vi) end @@ -184,14 +186,13 @@ function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vname, vinds, vi) + return context.loglike_scalar * + tilde_observe(context.ctx, right, left, vname, vinds, vi) end # `PrefixContext` function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe( - context.ctx, right, left, prefix(context, vname), vinds, vi - ) + return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) end function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.ctx, right, left, vi) @@ -296,7 +297,9 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) return if child_of_c === nothing dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) else - dot_tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi) + dot_tilde_assume( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi + ) end end @@ -590,14 +593,17 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) = dot_observe(right, left, vi) +function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return dot_observe(right, left, vi) +end # `MiniBatchContext` function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * + dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) end # `PrefixContext` @@ -656,4 +662,3 @@ function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end - From ee99f8ce5676c5cb571417e6dd4b3d9570922bf5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:30:54 +0100 Subject: [PATCH 011/216] added getsym for vectors --- src/varname.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/varname.jl b/src/varname.jl index bb936a4ce..40c5c25e9 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -39,3 +39,7 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + + +# HACK: Type-piracy. Is this really the way to go? +AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From c4845d08b34b58bc30762b4be944c8924c9794e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:35:12 +0100 Subject: [PATCH 012/216] Update src/varname.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varname.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/varname.jl b/src/varname.jl index 40c5c25e9..343bb0da8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -40,6 +40,5 @@ Possibly existing indices of `varname` are neglected. return s in missings end - # HACK: Type-piracy. Is this really the way to go? AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From a0c05f39315c93d9ed43cefe227255310172f339 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:43:42 +0100 Subject: [PATCH 013/216] fixed some signatures for Model --- src/model.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/model.jl b/src/model.jl index 250b89721..8d353d2de 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,7 +88,7 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return model(SamplingContext(rng, sampler, context), varinfo) + return model(varinfo, SamplingContext(rng, sampler, context)) end (model::Model)(context::AbstractContext) = model(VarInfo(), context) @@ -115,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -124,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -140,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( - model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ From 307cd7e1f3a1a02bd79284ccfc640848961e2bd9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:49:44 +0100 Subject: [PATCH 014/216] fixed a method call --- src/model.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 8d353d2de..3a01f9bf3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,9 +94,9 @@ end (model::Model)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end @@ -151,7 +151,7 @@ end """ _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context From 597277119922a24d0783b78134a8f541eb05dc0e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 01:00:49 +0100 Subject: [PATCH 015/216] fixed method signatures --- src/compiler.jl | 8 ------ src/context_implementations.jl | 48 +++++++++++++++++----------------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bc906f58c..8201a82f4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -283,7 +283,6 @@ function generate_tilde(left, right) return quote $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__, @@ -300,9 +299,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -312,7 +309,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -334,7 +330,6 @@ function generate_dot_tilde(left, right) return quote $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__, @@ -351,9 +346,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -363,7 +356,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a0bf0381a..d6ff3b5bd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -120,15 +120,15 @@ function tilde_assume(context::PrefixContext, right, vn, inds, vi) end """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(ctx, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde_assume!(ctx, right, vn, inds, vi)`. """ -function tilde_assume!(ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, right, vn, inds, vi) + value, logp = tilde_assume(ctx, right, vn, inds, vi) acclogp!(vi, logp) return value end @@ -199,30 +199,30 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(ctx, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(ctx, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +Falls back to `tilde(ctx, right, left, vi)`. """ -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end @@ -415,15 +415,15 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) end """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(ctx, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, right, left, vn, inds, vi) acclogp!(vi, logp) return value end @@ -615,30 +615,30 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, right, left, vi)`. """ -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end From c4ecd0e676ab58aa13ef9995289d4368868d212c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 01:08:56 +0100 Subject: [PATCH 016/216] sort of fixed the matchingvalue functionality for model --- src/model.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 3a01f9bf3..2d74949c1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -157,7 +157,10 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + return quote + sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() + model.f(model, varinfo, context, $(unwrap_args...)) + end end """ From a34b51cd60fffe8ff45903153550ff3486680f91 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 03:36:55 +0100 Subject: [PATCH 017/216] formatting --- src/compiler.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8201a82f4..dc70ae267 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -282,10 +282,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -329,10 +326,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end From b89ff7e3a3b6c1140b155ceb6997a2cabe5479cc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 04:09:53 +0100 Subject: [PATCH 018/216] removed redundant _tilde method --- src/context_implementations.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0698b6cdf..b5a0fd923 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -102,8 +102,6 @@ function tilde_observe!(ctx, sampler, right, left, vi) return left end -_tilde(sampler, right, left, vi) = observe(sampler, right, left, vi) - function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end From e4a2cf81154e44b23af16e473d1361cf11225fc4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:02:41 +0100 Subject: [PATCH 019/216] removed left-over acclogp! that should not be here anymore --- src/context_implementations.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 42a336479..b088577f5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -345,16 +345,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(NoDist.(right), left, vn, inds, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) end # `PriorContext` @@ -390,16 +386,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - value, logp = dot_assume(right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(right, left, vn, inds, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(rng, sampler, right, left, vn, inds, vi) end # `MiniBatchContext` From 7605785fff5407d558dd920720180efc0e41d885 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:04:29 +0100 Subject: [PATCH 020/216] export SamplingContext --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..3ad30972c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,6 +76,7 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts + SamplingContext, DefaultContext, LikelihoodContext, PriorContext, From 354ac52b0d2115b2df7d437407b98140a208ba5d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:38:25 +0100 Subject: [PATCH 021/216] use context instead of ctx to refer to contexts --- src/context_implementations.jl | 64 +++++++++++++++++----------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b088577f5..f859b4619 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -120,15 +120,15 @@ function tilde_assume(context::PrefixContext, right, vn, inds, vi) end """ - tilde_assume!(ctx, right, vn, inds, vi) + tilde_assume!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(ctx, right, vn, inds, vi)`. +Falls back to `tilde_assume!(context, right, vn, inds, vi)`. """ -function tilde_assume!(ctx, right, vn, inds, vi) - value, logp = tilde_assume(ctx, right, vn, inds, vi) +function tilde_assume!(context, right, vn, inds, vi) + value, logp = tilde_assume(context, right, vn, inds, vi) acclogp!(vi, logp) return value end @@ -199,30 +199,30 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(ctx, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, right, left, vi) +function tilde_observe!(context, right, left, vname, vinds, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, right, left, vi) + tilde_observe(context, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, right, left, vi)`. +Falls back to `tilde(context, right, left, vi)`. """ -function tilde_observe!(ctx, right, left, vi) - logp = tilde_observe(ctx, right, left, vi) +function tilde_observe!(context, right, left, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end @@ -302,11 +302,11 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) end # `DefaultContext` -function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(right, vns, left, vi) end -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end @@ -405,15 +405,15 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) end """ - dot_tilde_assume!(ctx, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(ctx, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(ctx, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) acclogp!(vi, logp) return value end @@ -583,17 +583,17 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) return dot_observe(right, left, vi) end # `MiniBatchContext` -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * - dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return context.loglike_scalar * + dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) end # `PrefixContext` @@ -605,30 +605,30 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(ctx, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, right, left, vi) +function dot_tilde_observe!(context, right, left, vn, inds, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, right, left, vi) + dot_tilde_observe!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, right, left, vi)`. +Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(ctx, right, left, vi) - logp = dot_tilde_observe(ctx, right, left, vi) +function dot_tilde_observe!(context, right, left, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end From b7a2b3b5b5483eb11741a9a2c7b3135abaecbcad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:38:46 +0100 Subject: [PATCH 022/216] formatting --- src/context_implementations.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f859b4619..a8f279804 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -591,7 +591,9 @@ end function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) end -function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe( + context::MiniBatchContext, sampler, right, left, vname, vinds, vi +) return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) end From 9e0fc9a9eecb6f74d54e0b4a1fad9cf94c0b41eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:39:41 +0100 Subject: [PATCH 023/216] use context instead of ctx for variables --- src/contexts.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 1ee43f2b2..8598fb633 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -71,7 +71,7 @@ LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end @@ -82,11 +82,11 @@ This is useful in batch-based stochastic gradient descent algorithms to be optim `log(prior) + log(likelihood of all the data points)` in the expectation. """ struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) +function MiniBatchContext(context=DefaultContext(); batch_size, npoints) + return MiniBatchContext(context, npoints / batch_size) end function unwrap_childcontext(context::MiniBatchContext) @@ -109,23 +109,23 @@ unique. See also: [`@submodel`](@ref) """ struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C + context::C end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(context)}(context) end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} + context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( - ctx.ctx + context.context )) else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) end end From 7a4a1a38ca6895c401e366e9b2707117b5ce36e5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:40:18 +0100 Subject: [PATCH 024/216] use context instead of ctx to refer to contexts --- src/context_implementations.jl | 47 ++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a8f279804..e66501aee 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -27,9 +27,9 @@ with a sampler. Falls back to ```julia -tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) +tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) ``` -if the context `context.ctx` does not call any other context, as indicated by +if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. """ @@ -112,11 +112,11 @@ function tilde_assume( end function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) + return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end """ @@ -138,8 +138,8 @@ end tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other +Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. @@ -159,8 +159,8 @@ end tilde_observe(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other +Falls back to `tilde_observe(context.context, right, left, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_observe(c, right, left, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. @@ -183,19 +183,19 @@ tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) return context.loglike_scalar * - tilde_observe(context.ctx, right, left, vname, vinds, vi) + tilde_observe(context.context, right, left, vname, vinds, vi) end # `PrefixContext` function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) end function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.ctx, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ @@ -283,9 +283,9 @@ associated with a sampler. Falls back to ```julia -dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) ``` -if the context `context.ctx` does not call any other context, as indicated by +if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. """ @@ -396,12 +396,12 @@ end # `MiniBatchContext` function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) end """ @@ -574,10 +574,10 @@ end Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value for a context associated with a sampler. -Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. +Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) + return dot_tilde_observe(context.context, right, left, vname, vinds, vi) end # Leaf contexts @@ -589,21 +589,24 @@ end # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) + return context.loglike_scalar * + dot_tilde_observe(context.context, sampler, right, left, vi) end function dot_tilde_observe( context::MiniBatchContext, sampler, right, left, vname, vinds, vi ) return context.loglike_scalar * - dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) end # `PrefixContext` function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) + return dot_tilde_observe( + context.context, right, left, prefix(context, vname), vinds, vi + ) end function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end """ From 7899473512e85089f0a6497823e804c0a2a7e12c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:20:27 +0100 Subject: [PATCH 025/216] Update src/compiler.jl Co-authored-by: David Widmann --- src/compiler.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dc70ae267..8734b72ed 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -392,12 +392,6 @@ function build_output(modelinfo, linenumbernode) # Replace the user-provided function body with the version created by DynamicPPL. evaluatordef[:body] = quote - # in case someone accessed these - if __context__ isa $(DynamicPPL.SamplingContext) - __rng__ = __context__.rng - __sampler__ = __context__.sampler - end - $(modelinfo[:body]) end From 1630476c742f82e3f45096923e39a4b2da6150a2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:20:41 +0100 Subject: [PATCH 026/216] Update src/context_implementations.jl Co-authored-by: David Widmann --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e66501aee..4fd787c86 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -138,6 +138,7 @@ end tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) Handle observed variables with a `context` associated with a sampler. + Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls From 6892d2b1276ef92f6765c3272673ac7a58682465 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:37:22 +0100 Subject: [PATCH 027/216] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4fd787c86..5647cd5fc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -160,6 +160,7 @@ end tilde_observe(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. + Falls back to `tilde_observe(context.context, right, left, vi)` ignoring the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls From 13da1b473654cb06cf621a99bc7a1904e3865102 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 02:00:10 +0100 Subject: [PATCH 028/216] added some whitespace to some docstrings --- src/compiler.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 20d8bf8ef..2e368d32b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -54,8 +54,10 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x """ unwrap_right_vn(right, vn) + Return the unwrapped distribution on the right-hand side and variable name on the left-hand side of a `~` expression such as `x ~ Normal()`. + This is used mainly to unwrap `NamedDist` distributions. """ unwrap_right_vn(right, vn) = right, vn @@ -63,8 +65,10 @@ unwrap_right_vn(right::NamedDist, vn) = unwrap_right_vn(right.dist, right.name) """ unwrap_right_left_vns(right, left, vns) + Return the unwrapped distributions on the right-hand side and values and variable names on the left-hand side of a `.~` expression such as `x .~ Normal()`. + This is used mainly to unwrap `NamedDist` distributions and adjust the indices of the variables. """ From d76e5b3d188596008d51093b2dec9b6cf3d725c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 02:23:12 +0100 Subject: [PATCH 029/216] deprecated tilde and dot_tilde plus exported new versions --- src/DynamicPPL.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..319798780 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -83,10 +83,12 @@ export AbstractVarInfo, PrefixContext, assume, dot_assume, - observer, + observe, dot_observe, - tilde, - dot_tilde, + tilde_assume, + tilde_observe, + dot_tilde_assume, + dot_tilde_observe, # Pseudo distributions NamedDist, NoDist, @@ -128,4 +130,11 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +# Deprecations. +@deprecate tilde(rng, ctx, sampler, right, vn, inds, vi) tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +@deprecate tilde(ctx, sampler, right, left, vi) tilde_observe(ctx, sampler, right, left, vi) + +@deprecate dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +@deprecate dot_tilde(ctx, sampler, right, left, vi) dot_tilde_observe(ctx, sampler, right, left, vi) + end # module From 805966979870ece251bc5304f20815a18a685ae1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 02:29:09 +0100 Subject: [PATCH 030/216] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/DynamicPPL.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 319798780..0eda04b28 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -131,10 +131,16 @@ include("loglikelihoods.jl") include("submodel_macro.jl") # Deprecations. -@deprecate tilde(rng, ctx, sampler, right, vn, inds, vi) tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +@deprecate tilde(rng, ctx, sampler, right, vn, inds, vi) tilde_assume( + rng, ctx, sampler, right, vn, inds, vi +) @deprecate tilde(ctx, sampler, right, left, vi) tilde_observe(ctx, sampler, right, left, vi) -@deprecate dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) -@deprecate dot_tilde(ctx, sampler, right, left, vi) dot_tilde_observe(ctx, sampler, right, left, vi) +@deprecate dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) dot_tilde_assume( + rng, ctx, sampler, right, left, vn, inds, vi +) +@deprecate dot_tilde(ctx, sampler, right, left, vi) dot_tilde_observe( + ctx, sampler, right, left, vi +) end # module From 43ef8d1a659af36312d61e96e622f4cab39cf21e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 23:42:16 +0100 Subject: [PATCH 031/216] minor version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 96c60a14f..76b94cb49 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.11.0" +version = "0.11.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 1015f0e3a248aacb3039dd4adee670504de4412f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 13:56:32 +0100 Subject: [PATCH 032/216] added impl of matchingvalue for contexts --- src/compiler.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index e647df99c..352d46418 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -435,8 +435,12 @@ end """ matchingvalue(sampler, vi, value) + matchingvalue(context::AbstractContext, vi, value) -Convert the `value` to the correct type for the `sampler` and the `vi` object. +Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. + +For a `context` that is _not_ a `SamplingContext`, we fall back to +`matchingvalue(SampleFromPrior(), vi, value)`. """ function matchingvalue(sampler, vi, value) T = typeof(value) @@ -453,6 +457,13 @@ function matchingvalue(sampler, vi, value) end matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(SampleFromPrior(), vi, value) +end +function matchingvalue(context::SamplingContext, vi, value) + return matchingvalue(context.sampler, vi, value) +end + """ get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} From 23c86a76d27248f3b2c71eb6f72ea85439dd62b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 14:32:06 +0100 Subject: [PATCH 033/216] reverted the change that makes assume always resample --- src/context_implementations.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d1fd7b0ba..e5e89ca55 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -255,15 +255,23 @@ function assume( inds, vi, ) - # Always overwrite the parameters with new ones. - r = init(rng, dist, sampler) if haskey(vi, vn) - vi[vn] = vectorize(dist, r) - setorder!(vi, vn, get_num_produce(vi)) + # Always overwrite the parameters with new ones for `SampleFromUniform`. + if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") + unset_flag!(vi, vn, "del") + r = init(rng, dist, sampler) + vi[vn] = vectorize(dist, r) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) + else + r = vi[vn] + end else + r = init(rng, dist, sampler) push!(vi, vn, r, dist, sampler) + settrans!(vi, false, vn) end - settrans!(vi, false, vn) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end From 17f5abe88edefca1a1307987300197d19210da34 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 14:39:20 +0100 Subject: [PATCH 034/216] removed the inds arguments from assume and dot_assume to stay non-breaking --- src/context_implementations.jl | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e5e89ca55..50e3a8e4b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -44,11 +44,11 @@ function tilde_assume(context::SamplingContext, right, vn, inds, vi) end # Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) function tilde_assume( rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi ) - return assume(rng, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) @@ -74,10 +74,10 @@ function tilde_assume( return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, inds, vi) + return assume(right, vn, vi) end function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) @@ -103,12 +103,12 @@ function tilde_assume( return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) end function tilde_assume(::LikelihoodContext, right, vn, inds, vi) - return assume(NoDist(right), vn, inds, vi) + return assume(NoDist(right), vn, vi) end function tilde_assume( rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi ) - return assume(rng, sampler, NoDist(right), vn, inds, vi) + return assume(rng, sampler, NoDist(right), vn, vi) end function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) @@ -238,7 +238,7 @@ function observe(spl::Sampler, weight) end # fallback without sampler -function assume(dist::Distribution, vn::VarName, inds, vi) +function assume(dist::Distribution, vn::VarName, vi) if !haskey(vi, vn) error("variable $vn does not exist") end @@ -252,7 +252,6 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - inds, vi, ) if haskey(vi, vn) @@ -355,12 +354,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - return dot_assume(NoDist.(right), left, vn, inds, vi) + return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) + return dot_assume(rng, sampler, NoDist.(right), left, vn, vi) end # `PriorContext` @@ -396,12 +395,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - return dot_assume(right, left, vn, inds, vi) + return dot_assume(right, left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, right, left, vn, inds, vi) + return dot_assume(rng, sampler, right, left, vn, vi) end # `MiniBatchContext` @@ -433,7 +432,6 @@ function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, - inds, vi, ) @assert length(dist) == size(var, 1) @@ -460,7 +458,6 @@ function dot_assume( dists::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, - inds, vi, ) # Make sure `var` is not a matrix for multivariate distributions From dbd61f04d1e5719633e36c525fc04b9190659b81 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 7 Jun 2021 15:32:41 +0100 Subject: [PATCH 035/216] Update src/context_implementations.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 50e3a8e4b..16704a814 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -429,10 +429,7 @@ end # `dot_assume` function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi, + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) lp = sum(zip(vns, eachcol(var))) do vn, ri From b10ba3f17f9f8f493893f8728e5f9aaf160dfd73 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:04:28 +0100 Subject: [PATCH 036/216] added missing sampler arg to tilde_observe --- src/context_implementations.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 50e3a8e4b..a65336670 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -170,18 +170,19 @@ which the order of the sampling context and its child are swapped. function tilde_observe(context::SamplingContext, right, left, vi) c, reconstruct_context = unwrap_childcontext(context) child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) + return if child_of_c === nothing + tilde_observe(c, context.sampler, right, left, vi) else - c + tilde_observe( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vi + ) end - return tilde_observe(fallback_context, right, left, vi) end # Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) -tilde_observe(::PriorContext, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) From bc5029f45c20a8dfa777b8e370c7af16797c890c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:06:28 +0100 Subject: [PATCH 037/216] added missing sampler argument in dot_tilde_observe --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a65336670..04ddac2c5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -583,7 +583,7 @@ probability, and return the observed value for a context associated with a sampl Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vname, vinds, vi) + return dot_tilde_observe(context.context, context.sampler, right, left, vi) end # Leaf contexts From 7eac33dc6072023a0ac4ab59cb26f6e8b52efb32 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:06:49 +0100 Subject: [PATCH 038/216] fixed order of arguments in some dot_assume calls --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 04ddac2c5..77a7475a3 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -360,7 +360,7 @@ end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, NoDist.(right), left, vn, vi) + return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) end # `PriorContext` @@ -401,7 +401,7 @@ end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, right, left, vn, vi) + return dot_assume(rng, sampler, right, vn, left, vi) end # `MiniBatchContext` From 85994813e350ea3e132e6f2dc0fc3d4896b4e641 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:01 +0100 Subject: [PATCH 039/216] formatting --- src/context_implementations.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 77a7475a3..d7d51c72a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -430,10 +430,7 @@ end # `dot_assume` function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi, + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) lp = sum(zip(vns, eachcol(var))) do vn, ri From 90a8c4562e10153ef9323200c59d417411ec009b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:22 +0100 Subject: [PATCH 040/216] formatting --- src/context_implementations.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d7d51c72a..431af5a5e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -173,9 +173,7 @@ function tilde_observe(context::SamplingContext, right, left, vi) return if child_of_c === nothing tilde_observe(c, context.sampler, right, left, vi) else - tilde_observe( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vi - ) + tilde_observe(reconstruct_c(reconstruct_context(child_of_c)), right, left, vi) end end From f9d4ff83d2ed6b4e67426e636c7c1006a66cafa5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:36 +0100 Subject: [PATCH 041/216] added missing sampler argument in tilde_observe for SamplingContext --- src/context_implementations.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 431af5a5e..5dce075a2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -148,12 +148,13 @@ which the order of the sampling context and its child are swapped. function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) c, reconstruct_context = unwrap_childcontext(context) child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) + return if child_of_c === nothing + tilde_observe(c, context.sampler, right, left, vname, vinds, vi) else - c + tilde_observe( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vname, vinds, vi + ) end - return tilde_observe(fallback_context, right, left, vname, vinds, vi) end """ From e424fe7b45821382b7ceb35d5294689730081334 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:08:15 +0100 Subject: [PATCH 042/216] added missing word in a docstring --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5dce075a2..b67833d62 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -614,7 +614,7 @@ end """ dot_tilde_observe!(context, right, left, vname, vinds, vi) -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), +Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable From 70957d27c66a3a1ce7b3ac751e68709531b5e548 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 08:03:50 +0100 Subject: [PATCH 043/216] updated submodel macro --- src/submodel_macro.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 92584ae8b..070a5aa4c 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,10 +1,8 @@ macro submodel(expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), $(esc(:__context__)), ) end @@ -13,10 +11,8 @@ end macro submodel(prefix, expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), ) end From d00cdcfb0ceed552a42abaadac2a8d4413e4bd4d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:58:46 +0100 Subject: [PATCH 044/216] removed unwrap_childcontext and related since its not needed for this PR --- src/context_implementations.jl | 130 +++++++++++++++------------------ src/contexts.jl | 38 ---------- 2 files changed, 58 insertions(+), 110 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b67833d62..ae4d631ae 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -29,18 +29,9 @@ Falls back to ```julia tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) ``` -if the context `context.context` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. """ function tilde_assume(context::SamplingContext, right, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) - else - tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) - end + return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) end # Leaf contexts @@ -115,10 +106,18 @@ function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) return tilde_assume(context.context, right, vn, inds, vi) end +function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) +end + function tilde_assume(context::PrefixContext, right, vn, inds, vi) return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) +end + """ tilde_assume!(context, right, vn, inds, vi) @@ -139,22 +138,12 @@ end Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.context` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. +Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_observe(c, context.sampler, right, left, vname, vinds, vi) - else - tilde_observe( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vname, vinds, vi - ) - end + return tilde_observe( + context.rng, context.context, context.sampler, right, left, vname, vinds, vi + ) end """ @@ -162,39 +151,31 @@ end Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vi)` ignoring -the information about the sampler if the context `context.context` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. +Falls back to `tilde_observe(context.context, right, left, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_observe(c, context.sampler, right, left, vi) - else - tilde_observe(reconstruct_c(reconstruct_context(child_of_c)), right, left, vi) - end + return tilde_observe(context.context, context.sampler, right, left, vi) end # Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * - tilde_observe(context.context, right, left, vname, vinds, vi) +function tilde_observe(context::MiniBatchContext, right, left, vname, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) +function tilde_observe(context::PrefixContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vi) end function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.context, right, left, vi) @@ -294,25 +275,16 @@ Falls back to ```julia dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) ``` -if the context `context.context` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. """ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) - else - dot_tilde_assume( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi - ) - end + return dot_tilde_assume( + context.rng, context.context, context.sampler, right, left, vn, inds, vi + ) end # `DefaultContext` -function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(right, vns, left, vi) +function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) + return dot_assume(right, left, vns, vi) end function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) @@ -408,11 +380,23 @@ function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, left, vn, inds, vi) end +function dot_tilde_assume( + rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi +) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) +end + # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) end +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) + return dot_tilde_assume( + rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + ) +end + """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -583,30 +567,23 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(sampler, right, left, vi) +dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, right, left, vi) return dot_observe(right, left, vi) end +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) +end # `MiniBatchContext` -function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vi) -end -function dot_tilde_observe( - context::MiniBatchContext, sampler, right, left, vname, vinds, vi -) - return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) end # `PrefixContext` -function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return dot_tilde_observe( - context.context, right, left, prefix(context, vname), vinds, vi - ) -end function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end @@ -641,18 +618,27 @@ function dot_tilde_observe!(context, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi) + return dot_observe(dist, value, vi) +end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi) + return dot_observe(dists, value, vi) +end function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) + return dot_observe(dists, value, vi) +end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" diff --git a/src/contexts.jl b/src/contexts.jl index 6daa18776..8093c88f3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,17 +1,3 @@ -""" - unwrap_childcontext(context::AbstractContext) - -Return a tuple of the child context of a `context`, or `nothing` if the context does -not wrap any other context, and a function `f(c::AbstractContext)` that constructs -an instance of `context` in which the child context is replaced with `c`. - -Falls back to `(nothing, _ -> context)`. -""" -function unwrap_childcontext(context::AbstractContext) - reconstruct_context(@nospecialize(x)) = context - return nothing, reconstruct_context -end - """ SamplingContext(rng, sampler, context) @@ -26,14 +12,6 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte context::C end -function unwrap_childcontext(context::SamplingContext) - child = context.context - function reconstruct_samplingcontext(c::AbstractContext) - return SamplingContext(context.rng, context.sampler, c) - end - return child, reconstruct_samplingcontext -end - """ struct DefaultContext <: AbstractContext end @@ -89,14 +67,6 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end -function unwrap_childcontext(context::MiniBatchContext) - child = context.context - function reconstruct_minibatchcontext(c::AbstractContext) - return MiniBatchContext(c, context.loglike_scalar) - end - return child, reconstruct_minibatchcontext -end - """ PrefixContext{Prefix}(context) @@ -136,11 +106,3 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end - -function unwrap_childcontext(context::PrefixContext{P}) where {P} - child = context.context - function reconstruct_prefixcontext(c::AbstractContext) - return PrefixContext{P}(c) - end - return child, reconstruct_prefixcontext -end From 639fd6ebe36bee13fc2372915e2480858012612c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:05 +0100 Subject: [PATCH 045/216] updated submodel macro --- src/submodel_macro.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 070a5aa4c..1d574e286 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,10 +1,6 @@ macro submodel(expr) return quote - _evaluate( - $(esc(expr)), - $(esc(:__varinfo__)), - $(esc(:__context__)), - ) + _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) end end From c9a06fb46c8a9d4f184e2b64cfd9c119d68c5cc3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:15 +0100 Subject: [PATCH 046/216] fixed evaluation implementations of dot_assume --- src/context_implementations.jl | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ae4d631ae..6befa1f6d 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -416,10 +416,17 @@ function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) - lp = sum(zip(vns, eachcol(var))) do vn, ri + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior()) + lp = sum(zip(vns, eachcol(r))) do vn, ri return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end - return var, lp + return r, lp end function dot_assume( rng, @@ -441,9 +448,15 @@ function dot_assume( vns::AbstractArray{<:VarName}, vi, ) - # Make sure `var` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) - return var, lp + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior()) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return r, lp end function dot_assume( From 2fe5f4016cb66ff5da3af57415dd096843363a1b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:45 +0100 Subject: [PATCH 047/216] updated pointwise_loglikelihoods and related --- src/loglikelihoods.jl | 99 ++++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 39 deletions(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 89672127a..6fca717c6 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,80 +1,102 @@ # Context version struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext loglikelihoods::A - ctx::Ctx + context::Ctx end function PointwiseLikelihoodContext( - likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() + likelihoods=Dict{VarName,Vector{Float64}}(), + context::AbstractContext=LikelihoodContext(), ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) + return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, string(vn), Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[string(vn)] = logp + return context.loglikelihoods[string(vn)] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::String, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) +function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function dot_tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!(context.context, right, left, vi) +end +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp = tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - return value + + # Track loglikelihood value. + push!(context, vn, logp) + + return left end -function tilde_observe( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi -) - # This is slightly unfortunate since it is not completely generic... - # Ideally we would call `tilde_observe` recursively but then we don't get the - # loglikelihood value. - logp = tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe(context.context, right, left, vi) +end +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + logp = dot_tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - # track loglikelihood value - push!(ctx, vname, logp) + # Track loglikelihood value. + push!(context, vn, logp) return left end @@ -150,30 +172,29 @@ Dict{VarName,Array{Float64,2}} with 4 entries: """ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once - spl = SampleFromPrior() vi = VarInfo(model) - ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) + context = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval_and_resample!(vi, chain, sample_idx, chain_idx) + setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, spl, ctx) + model(vi, context) end niters = size(chain, 1) nchains = size(chain, 3) loglikelihoods = Dict( varname => reshape(logliks, niters, nchains) for - (varname, logliks) in ctx.loglikelihoods + (varname, logliks) in context.loglikelihoods ) return loglikelihoods end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) - model(varinfo, SampleFromPrior(), ctx) - return ctx.loglikelihoods + context = PointwiseLikelihoodContext(Dict{VarName,Vector{Float64}}()) + model(varinfo, context) + return context.loglikelihoods end From b532ca690c254b12732dcdb5f7c9a67577a763cf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:00:08 +0100 Subject: [PATCH 048/216] added proper tests for pointwise_loglikelihoods --- test/loglikelihoods.jl | 122 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + 2 files changed, 124 insertions(+) create mode 100644 test/loglikelihoods.jl diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl new file mode 100644 index 000000000..5c1fdc082 --- /dev/null +++ b/test/loglikelihoods.jl @@ -0,0 +1,122 @@ +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x = 10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. +@model function gdemo5(x = 10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + x .~ Normal(m, 0.5) +end + +# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `observe` +# m ~ MvNormal(length(x), 1.0) +# [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +# end + +@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `dot_observe` +# m ~ Normal() +# [10.0, ] .~ Normal(m, 0.5) +# end + +@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + x ~ MvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) + + +@testset "loglikelihoods.jl" begin + for m in mean_of_mean_models + vi = VarInfo(m) + + vns = vi.metadata.m.vns + if length(vns) == 1 && length(vi[vns[1]]) == 1 + # Only have one latent variable. + DynamicPPL.setval!(vi, [1.0, ], ["m", ]) + else + DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + end + + lls = pointwise_loglikelihoods(m, vi) + + if isempty(lls) + # One of the models with literal observations, so we just skip. + continue + end + + loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 + # Only have one observation, so we need to double it + # for comparison with other models. + 2 * sum(lls[first(keys(lls))]) + else + sum(sum, values(lls)) + end + + @test loglikelihood ≈ -324.45158270528947 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2b3d5d55c..d83be0eea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,8 @@ include("test_util.jl") include("threadsafe.jl") include("serialization.jl") + + include("loglikelihoods.jl") end @testset "compat" begin From 4e2274e7abffc1058f370d32f4cc93e21df1d1f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:00:24 +0100 Subject: [PATCH 049/216] updated DPPL tests to reflect recent changes --- test/compiler.jl | 11 +++++------ test/threadsafe.jl | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 78b472563..d219f91ea 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -172,10 +172,10 @@ end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end @@ -184,18 +184,17 @@ end @test getlogp(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model - @test sampler_ === SampleFromPrior() - @test context_ === DefaultContext() + @test context_ isa SamplingContext @test rng_ isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end false diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 746d6a5f8..7a2bdd039 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,14 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @model function wothreads(x) @@ -96,14 +96,14 @@ # Ensure that we use `VarInfo`. DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) end end From ef6da4377024f68bbd41683857e37fade88498f0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:02:00 +0100 Subject: [PATCH 050/216] bump minor version since this will be breaking --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f7a5ba10d..db9f26b04 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.11.2" +version = "0.12.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 10899f370c5335a53b0212146cddaf69ad43e62c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:03:03 +0100 Subject: [PATCH 051/216] formatting --- src/context_implementations.jl | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6befa1f6d..77cbc0fb2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -581,7 +581,9 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(sampler, right, left, vi) +function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) +end dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 function dot_tilde_observe(context::LikelihoodContext, right, left, vi) @@ -631,7 +633,12 @@ function dot_tilde_observe!(context, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + value::AbstractMatrix, + vi, +) return dot_observe(dist, value, vi) end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) @@ -640,7 +647,12 @@ function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) @debug "value = $value" return Distributions.loglikelihood(dist, value) end -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dists::Distribution, + value::AbstractArray, + vi, +) return dot_observe(dists, value, vi) end function dot_observe(dists::Distribution, value::AbstractArray, vi) @@ -649,7 +661,12 @@ function dot_observe(dists::Distribution, value::AbstractArray, vi) @debug "value = $value" return Distributions.loglikelihood(dists, value) end -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dists::AbstractArray{<:Distribution}, + value::AbstractArray, + vi, +) return dot_observe(dists, value, vi) end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) From 1f21ce4158f5ed256b0435eafdc08acf37b8aece Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:03:47 +0100 Subject: [PATCH 052/216] formatting --- test/loglikelihoods.jl | 33 +++++++++++++++++---------------- test/threadsafe.jl | 16 ++++++++++++---- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 5c1fdc082..4cc7325b8 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,28 +1,28 @@ # A collection of models for which the mean-of-means for the posterior should # be same. -@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `assume` with indexing and `observe` m = TV(undef, length(x)) for i in eachindex(m) m[i] ~ Normal() end - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo3(x = 10 * ones(2)) +@model function gdemo3(x=10 * ones(2)) # Multivariate `assume` and `observe` m ~ MvNormal(length(x), 1.0) - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` with indexing m = TV(undef, length(x)) m .~ Normal() @@ -33,10 +33,10 @@ end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. -@model function gdemo5(x = 10 * ones(1)) +@model function gdemo5(x=10 * ones(1)) # `assume` and `dot_observe` m ~ Normal() - x .~ Normal(m, 0.5) + return x .~ Normal(m, 0.5) end # @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} @@ -45,7 +45,7 @@ end # [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) # end -@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing m = TV(undef, 2) m .~ Normal() @@ -60,7 +60,7 @@ end # [10.0, ] .~ Normal(m, 0.5) # end -@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} +@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) m .~ Normal() @@ -76,10 +76,10 @@ end end @model function _likelihood_dot_observe(m, x) - x ~ MvNormal(m, 0.5 * ones(length(m))) + return x ~ MvNormal(m, 0.5 * ones(length(m))) end -@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, length(x)) m .~ Normal() @@ -87,8 +87,9 @@ end @submodel _likelihood_dot_observe(m, x) end -const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) - +const mean_of_mean_models = ( + gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() +) @testset "loglikelihoods.jl" begin for m in mean_of_mean_models @@ -97,7 +98,7 @@ const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), g vns = vi.metadata.m.vns if length(vns) == 1 && length(vi[vns[1]]) == 1 # Only have one latent variable. - DynamicPPL.setval!(vi, [1.0, ], ["m", ]) + DynamicPPL.setval!(vi, [1.0], ["m"]) else DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 7a2bdd039..83c53ccd6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,18 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @model function wothreads(x) @@ -96,14 +100,18 @@ # Ensure that we use `VarInfo`. DynamicPPL.evaluate_threadunsafe( - wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) end end From 70045061b0730bac5e76ed3c6b981511734b21f0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:25:56 +0100 Subject: [PATCH 053/216] renamed mean_of_mean_models used in tests --- test/loglikelihoods.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 4cc7325b8..74fb88d70 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -87,12 +87,12 @@ end @submodel _likelihood_dot_observe(m, x) end -const mean_of_mean_models = ( +const gdemo_models = ( gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() ) @testset "loglikelihoods.jl" begin - for m in mean_of_mean_models + for m in gdemo_models vi = VarInfo(m) vns = vi.metadata.m.vns From fa6c4d6aed0312b69d054c3b46f7f438b69810c3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 08:05:29 +0100 Subject: [PATCH 054/216] bumped dppl version in integration tests --- test/turing/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/turing/Project.toml b/test/turing/Project.toml index a4f68621d..67b8d5645 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.11" +DynamicPPL = "0.12" Turing = "0.15, 0.16" julia = "1.3" From 684d829b437e546a02d588d2204082f0be7df0ae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 09:23:07 +0100 Subject: [PATCH 055/216] Apply suggestions from code review Co-authored-by: David Widmann --- src/compiler.jl | 4 +--- src/context_implementations.jl | 13 +++++++++---- src/contexts.jl | 2 +- src/loglikelihoods.jl | 2 +- src/model.jl | 7 ++----- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 352d46418..2fa94bcd9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -395,9 +395,7 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = quote - $(modelinfo[:body]) - end + evaluatordef[:body] = modelinfo[:body] ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 77cbc0fb2..259a7c1c3 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -124,7 +124,8 @@ end Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(context, right, vn, inds, vi)`. +By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log +probability of `vi` with the returned value. """ function tilde_assume!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) @@ -138,7 +139,10 @@ end Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)`. +Falls back to +```julia +tilde_observe(context.rng, context.context, context.sampler, right, left, vname, vinds, vi) +``` """ function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) return tilde_observe( @@ -151,7 +155,7 @@ end Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vi)`. +Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vi) return tilde_observe(context.context, context.sampler, right, left, vi) @@ -202,7 +206,8 @@ end Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(context, right, left, vi)`. +By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log +probability of `vi` with the returned value. """ function tilde_observe!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 8093c88f3..05ad8df0d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -4,7 +4,7 @@ Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. -See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) """ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext rng::R diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6fca717c6..6c66e4ec4 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -87,7 +87,7 @@ end function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return dot_tilde_observe(context.context, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) end function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. diff --git a/src/model.jl b/src/model.jl index 2d74949c1..9ec047a44 100644 --- a/src/model.jl +++ b/src/model.jl @@ -156,11 +156,8 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf @generated function _evaluate( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return quote - sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() - model.f(model, varinfo, context, $(unwrap_args...)) - end + unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ From 07bb28416adccab4f4783d0d708c9645f49be327 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:29:54 +0100 Subject: [PATCH 056/216] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 259a7c1c3..6833a7856 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -191,13 +191,11 @@ end Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ function tilde_observe!(context, right, left, vname, vinds, vi) - logp = tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return tilde_observe!(context, right, left, vi) end """ @@ -578,7 +576,7 @@ end Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value for a context associated with a sampler. -Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. +Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) return dot_tilde_observe(context.context, context.sampler, right, left, vi) @@ -614,13 +612,11 @@ end Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ function dot_tilde_observe!(context, right, left, vn, inds, vi) - logp = dot_tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return dot_tilde_observe!(context, right, left, vi) end """ From c7c6a3c066e1c1d229edbe18fd8a1f7a4e8a9fc6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:56:57 +0100 Subject: [PATCH 057/216] fixed ambiguity error --- src/compiler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 352d46418..3924eae95 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -455,7 +455,9 @@ function matchingvalue(sampler, vi, value) return value end end -matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) + return get_matching_type(sampler, vi, value) +end function matchingvalue(context::AbstractContext, vi, value) return matchingvalue(SampleFromPrior(), vi, value) From 06d319c539a083bb2671c7aa146659cdcf638b16 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:36:52 +0000 Subject: [PATCH 058/216] Introduction of `SamplingContext`: keeping it simple (#259) This is #253 but the only motivation here is to get `SamplingContext` in, nothing relating to interactions with other contexts, etc. Co-authored-by: Hong Ge --- src/DynamicPPL.jl | 1 + src/compiler.jl | 37 ++- src/context_implementations.jl | 466 ++++++++++++++++++++++++++------- src/contexts.jl | 45 +++- src/loglikelihoods.jl | 99 ++++--- src/model.jl | 38 +-- src/submodel_macro.jl | 10 +- src/varname.jl | 3 + test/compiler.jl | 11 +- test/loglikelihoods.jl | 123 +++++++++ test/runtests.jl | 2 + test/threadsafe.jl | 16 +- test/turing/Project.toml | 2 +- 13 files changed, 653 insertions(+), 200 deletions(-) create mode 100644 test/loglikelihoods.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9eb4d9675..914c0e12b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -75,6 +75,7 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts + SamplingContext, DefaultContext, LikelihoodContext, PriorContext, diff --git a/src/compiler.jl b/src/compiler.jl index 2e368d32b..7c812fb54 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -286,11 +286,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -304,9 +300,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -316,7 +310,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -337,11 +330,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - __sampler__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -355,9 +344,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -367,7 +354,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -398,10 +384,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -449,8 +433,12 @@ end """ matchingvalue(sampler, vi, value) + matchingvalue(context::AbstractContext, vi, value) + +Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. -Convert the `value` to the correct type for the `sampler` and the `vi` object. +For a `context` that is _not_ a `SamplingContext`, we fall back to +`matchingvalue(SampleFromPrior(), vi, value)`. """ function matchingvalue(sampler, vi, value) T = typeof(value) @@ -465,7 +453,16 @@ function matchingvalue(sampler, vi, value) return value end end -matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) + return get_matching_type(sampler, vi, value) +end + +function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(SampleFromPrior(), vi, value) +end +function matchingvalue(context::SamplingContext, vi, value) + return matchingvalue(context.sampler, vi, value) +end """ get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 60df298b5..6833a7856 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,86 +18,197 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) +""" + tilde_assume(context::SamplingContext, right, vn, inds, vi) + +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), +accumulate the log probability, and return the sampled value with a context associated +with a sampler. + +Falls back to +```julia +tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +``` +""" +function tilde_assume(context::SamplingContext, right, vn, inds, vi) + return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +end + +# Leaf contexts +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) +function tilde_assume( + rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end + return tilde_assume(PriorContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::PriorContext, right, vn, inds, vi) + return assume(right, vn, vi) +end +function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) return assume(rng, sampler, right, vn, vi) end -function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) + +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(LikelihoodContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::LikelihoodContext, right, vn, inds, vi) + return assume(NoDist(right), vn, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) + +function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) + +function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::PrefixContext, right, vn, inds, vi) + return tilde_assume(context.context, right, prefix(context, vn), inds, vi) +end + +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. +By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log +probability of `vi` with the returned value. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(context, right, vn, inds, vi) + value, logp = tilde_assume(context, right, vn, inds, vi) acclogp!(vi, logp) return value end # observe -function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +""" + tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + +Handle observed variables with a `context` associated with a sampler. + +Falls back to +```julia +tilde_observe(context.rng, context.context, context.sampler, right, left, vname, vinds, vi) +``` +""" +function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + return tilde_observe( + context.rng, context.context, context.sampler, right, left, vname, vinds, vi + ) +end + +""" + tilde_observe(context::SamplingContext, right, left, vi) + +Handle observed constants with a `context` associated with a sampler. + +Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. +""" +function tilde_observe(context::SamplingContext, right, left, vi) + return tilde_observe(context.context, context.sampler, right, left, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +# Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 +tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) + +# `MiniBatchContext` +function tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vname, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `PrefixContext` +function tilde_observe(context::PrefixContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(ctx.ctx, sampler, right, left, vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ - tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) - acclogp!(vi, logp) - return left +function tilde_observe!(context, right, left, vname, vinds, vi) + return tilde_observe!(context, right, left, vi) end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(context, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log +probability of `vi` with the returned value. """ -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(context, right, left, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end @@ -110,14 +221,28 @@ function observe(spl::Sampler, weight) return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end +# fallback without sampler +function assume(dist::Distribution, vn::VarName, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) +end + +# SampleFromPrior and SampleFromUniform function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + vn::VarName, + vi, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") + if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) + r = init(rng, dist, sampler) vi[vn] = vectorize(dist, r) settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) @@ -125,79 +250,187 @@ function assume( r = vi[vn] end else - r = init(rng, dist, spl) - push!(vi, vn, r, dist, spl) + r = init(rng, dist, sampler) + push!(vi, vn, r, dist, sampler) settrans!(vi, false, vn) end + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end -function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi -) +# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) +function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(right, left) end # .~ functions # assume -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) +""" + dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the +model inputs), accumulate the log probability, and return the sampled value for a context +associated with a sampler. + +Falls back to +```julia +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) +``` +""" +function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + return dot_tilde_assume( + context.rng, context.context, context.sampler, right, left, vn, inds, vi + ) +end + +# `DefaultContext` +function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) + return dot_assume(right, left, vns, vi) +end + +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end + +# `LikelihoodContext` function dot_tilde_assume( - rng, - ctx::LikelihoodContext, + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) - return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) + return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( - rng, - ctx::PriorContext, + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi +) + return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) +end + +# `PriorContext` +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, right, vns, left, vi) +end +function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) + return dot_assume(right, left, vn, vi) +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi +) + return dot_assume(rng, sampler, right, vn, left, vi) +end + +# `MiniBatchContext` +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function dot_tilde_assume( + rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi +) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) +end + +# `PrefixContext` +function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) +end + +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) + return dot_tilde_assume( + rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + ) end """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +# `dot_assume` +function dot_assume( + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi +) + @assert length(dist) == size(var, 1) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior()) + lp = sum(zip(vns, eachcol(r))) do vn, ri + return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) + end + return r, lp +end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -211,6 +444,24 @@ function dot_assume( lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) return r, lp end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + vi, +) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior()) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return r, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -319,84 +570,109 @@ function set_val!( end # observe -function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) +""" + dot_tilde_observe(context::SamplingContext, right, left, vi) + +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log +probability, and return the observed value for a context associated with a sampler. + +Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. +""" +function dot_tilde_observe(context::SamplingContext, right, left, vi) + return dot_tilde_observe(context.context, context.sampler, right, left, vi) +end + +# Leaf contexts +dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) +function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) return dot_observe(sampler, right, left, vi) end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 +dot_tilde_observe(::PriorContext, right, left, vi) = 0 +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +function dot_tilde_observe(context::LikelihoodContext, right, left, vi) + return dot_observe(right, left, vi) end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) return dot_observe(sampler, right, left, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `MiniBatchContext` +function dot_tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) +end + +# `PrefixContext` +function dot_tilde_observe(context::PrefixContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vinds, vi) -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), +Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) - acclogp!(vi, logp) - return left +function dot_tilde_observe!(context, right, left, vn, inds, vi) + return dot_tilde_observe!(context, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(context, right, left, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi, ) + return dot_observe(dist, value, vi) +end +function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi, ) + return dot_observe(dists, value, vi) +end +function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, + ::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi, ) + return dot_observe(dists, value, vi) +end +function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" - ) -end diff --git a/src/contexts.jl b/src/contexts.jl index 2c23531c6..05ad8df0d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,17 @@ +""" + SamplingContext(rng, sampler, context) + +Create a context that allows you to sample parameters with the `sampler` when running the model. +The `context` determines how the returned log density is computed when running the model. + +See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) +""" +struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext + rng::R + sampler::S + context::C +end + """ struct DefaultContext <: AbstractContext end @@ -35,7 +49,7 @@ LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end @@ -46,31 +60,42 @@ This is useful in batch-based stochastic gradient descent algorithms to be optim `log(prior) + log(likelihood of all the data points)` in the expectation. """ struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) +function MiniBatchContext(context=DefaultContext(); batch_size, npoints) + return MiniBatchContext(context, npoints / batch_size) end +""" + PrefixContext{Prefix}(context) + +Create a context that allows you to use the wrapped `context` when running the model and +adds the `Prefix` to all parameters. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`@submodel`](@ref) +""" struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C + context::C end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(context)}(context) end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} + context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( - ctx.ctx + context.context )) else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) end end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 89672127a..6c66e4ec4 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,80 +1,102 @@ # Context version struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext loglikelihoods::A - ctx::Ctx + context::Ctx end function PointwiseLikelihoodContext( - likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() + likelihoods=Dict{VarName,Vector{Float64}}(), + context::AbstractContext=LikelihoodContext(), ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) + return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, string(vn), Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[string(vn)] = logp + return context.loglikelihoods[string(vn)] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::String, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) +function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function dot_tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!(context.context, right, left, vi) +end +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp = tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - return value + + # Track loglikelihood value. + push!(context, vn, logp) + + return left end -function tilde_observe( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi -) - # This is slightly unfortunate since it is not completely generic... - # Ideally we would call `tilde_observe` recursively but then we don't get the - # loglikelihood value. - logp = tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe!(context.context, right, left, vi) +end +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + logp = dot_tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - # track loglikelihood value - push!(ctx, vname, logp) + # Track loglikelihood value. + push!(context, vn, logp) return left end @@ -150,30 +172,29 @@ Dict{VarName,Array{Float64,2}} with 4 entries: """ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once - spl = SampleFromPrior() vi = VarInfo(model) - ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) + context = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval_and_resample!(vi, chain, sample_idx, chain_idx) + setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, spl, ctx) + model(vi, context) end niters = size(chain, 1) nchains = size(chain, 3) loglikelihoods = Dict( varname => reshape(logliks, niters, nchains) for - (varname, logliks) in ctx.loglikelihoods + (varname, logliks) in context.loglikelihoods ) return loglikelihoods end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) - model(varinfo, SampleFromPrior(), ctx) - return ctx.loglikelihoods + context = PointwiseLikelihoodContext(Dict{VarName,Vector{Float64}}()) + model(varinfo, context) + return context.loglikelihoods end diff --git a/src/model.jl b/src/model.jl index 7189b590e..9ec047a44 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,12 +88,18 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) + return model(varinfo, SamplingContext(rng, sampler, context)) +end + +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end + function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end @@ -109,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -134,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ @@ -183,7 +189,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, DefaultContext()) return getlogp(varinfo) end @@ -195,7 +201,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, PriorContext()) return getlogp(varinfo) end @@ -207,7 +213,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, LikelihoodContext()) return getlogp(varinfo) end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 92584ae8b..1d574e286 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,22 +1,14 @@ macro submodel(expr) return quote - _evaluate( - $(esc(:__rng__)), - $(esc(expr)), - $(esc(:__varinfo__)), - $(esc(:__sampler__)), - $(esc(:__context__)), - ) + _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) end end macro submodel(prefix, expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), ) end diff --git a/src/varname.jl b/src/varname.jl index bb936a4ce..343bb0da8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -39,3 +39,6 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + +# HACK: Type-piracy. Is this really the way to go? +AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym diff --git a/test/compiler.jl b/test/compiler.jl index 78b472563..d219f91ea 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -172,10 +172,10 @@ end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end @@ -184,18 +184,17 @@ end @test getlogp(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model - @test sampler_ === SampleFromPrior() - @test context_ === DefaultContext() + @test context_ isa SamplingContext @test rng_ isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end false diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl new file mode 100644 index 000000000..74fb88d70 --- /dev/null +++ b/test/loglikelihoods.jl @@ -0,0 +1,123 @@ +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x=10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + return x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. +@model function gdemo5(x=10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + return x .~ Normal(m, 0.5) +end + +# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `observe` +# m ~ MvNormal(length(x), 1.0) +# [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +# end + +@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `dot_observe` +# m ~ Normal() +# [10.0, ] .~ Normal(m, 0.5) +# end + +@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + return x ~ MvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const gdemo_models = ( + gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() +) + +@testset "loglikelihoods.jl" begin + for m in gdemo_models + vi = VarInfo(m) + + vns = vi.metadata.m.vns + if length(vns) == 1 && length(vi[vns[1]]) == 1 + # Only have one latent variable. + DynamicPPL.setval!(vi, [1.0], ["m"]) + else + DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + end + + lls = pointwise_loglikelihoods(m, vi) + + if isempty(lls) + # One of the models with literal observations, so we just skip. + continue + end + + loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 + # Only have one observation, so we need to double it + # for comparison with other models. + 2 * sum(lls[first(keys(lls))]) + else + sum(sum, values(lls)) + end + + @test loglikelihood ≈ -324.45158270528947 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2b3d5d55c..d83be0eea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,8 @@ include("test_util.jl") include("threadsafe.jl") include("serialization.jl") + + include("loglikelihoods.jl") end @testset "compat" begin diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 746d6a5f8..83c53ccd6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,18 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @model function wothreads(x) @@ -96,14 +100,18 @@ # Ensure that we use `VarInfo`. DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) end end diff --git a/test/turing/Project.toml b/test/turing/Project.toml index a4f68621d..67b8d5645 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.11" +DynamicPPL = "0.12" Turing = "0.15, 0.16" julia = "1.3" From cb996c6fd002d88ec825bb0d9ca4fd428902a86f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 14:40:33 +0100 Subject: [PATCH 059/216] Update src/DynamicPPL.jl Co-authored-by: David Widmann --- src/DynamicPPL.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 914c0e12b..a46c941a1 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -130,17 +130,4 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -# Deprecations. -@deprecate tilde(rng, ctx, sampler, right, vn, inds, vi) tilde_assume( - rng, ctx, sampler, right, vn, inds, vi -) -@deprecate tilde(ctx, sampler, right, left, vi) tilde_observe(ctx, sampler, right, left, vi) - -@deprecate dot_tilde(rng, ctx, sampler, right, left, vn, inds, vi) dot_tilde_assume( - rng, ctx, sampler, right, left, vn, inds, vi -) -@deprecate dot_tilde(ctx, sampler, right, left, vi) dot_tilde_observe( - ctx, sampler, right, left, vi -) - end # module From 03c9285c23b03441700e172307f3b33913a665bd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 12:31:06 +0100 Subject: [PATCH 060/216] added initial impl of SimpleVarInfo --- src/DynamicPPL.jl | 1 + src/simple_varinfo.jl | 105 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 src/simple_varinfo.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..5cde57f91 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -122,6 +122,7 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") +include("simple_varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl new file mode 100644 index 000000000..6699a441b --- /dev/null +++ b/src/simple_varinfo.jl @@ -0,0 +1,105 @@ +""" + SimpleVarInfo{NT,T} <: AbstractVarInfo + +A simple wrapper of the parameters with a `logp` field for +accumulation of the logdensity. + +Currently only implemented for `NT <: NamedTuple`. + +## Notes +The major differences between this and `TypedVarInfo` are: +1. `SimpleVarInfo` does not require linearization. +2. `SimpleVarInfo` can use more efficient bijectors. +3. `SimpleVarInfo` only supports evaluation. +""" +struct SimpleVarInfo{NT,T} <: AbstractVarInfo + θ::NT + logp::Base.RefValue{T} +end + +SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, Ref(zero(T))) +SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) + +function setlogp!(vi::SimpleVarInfo, logp) + vi.logp[] = logp + return vi +end + +function acclogp!(vi::SimpleVarInfo, logp) + vi.logp[] += logp + return vi +end + +function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} + # Use `getproperty` instead of `getfield` + value = getproperty(nt, sym) + return _getindex(value, inds) +end + +getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} = _getvalue(vi.θ, Val{sym}(), vn.indexing) +# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than +# just `Vector`. +getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) +# To disambiguiate. +getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns) + +haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn)) + +istrans(::SimpleVarInfo, vn::VarName) = false + +getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ +getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ +getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ +getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) +getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) + +# Context implementations +# Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. +function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) + left = vi[vn] + return left, Distributions.loglikelihood(dist, left) +end + +# function dot_tilde_assume!(context, right, left, vn, inds, vi::SimpleVarInfo) +# throw(MethodError(dot_tilde_assume!, (context, right, left, vn, inds, vi))) +# end + +function dot_assume( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::SimpleVarInfo, +) + @assert length(dist) == size(var, 1) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = vi[vns] + lp = sum(zip(vns, eachcol(r))) do vn, ri + return Distributions.logpdf(dist, ri) + end + return r, lp +end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + vi::SimpleVarInfo{<:NamedTuple}, +) + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = vi[vns] + lp = sum(Distributions.logpdf.(dists, r)) + return r, lp +end + +# HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. +increment_num_produce!(::SimpleVarInfo) = nothing From f91952d9f71641c9a69896f4363cff881b17517b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 12:31:16 +0100 Subject: [PATCH 061/216] remove unnecessary debug statements to be compat with Zygote --- src/context_implementations.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6833a7856..ab8fc7cab 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -644,8 +644,6 @@ function dot_observe( end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) - @debug "dist = $dist" - @debug "value = $value" return Distributions.loglikelihood(dist, value) end function dot_observe( @@ -658,8 +656,6 @@ function dot_observe( end function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) - @debug "dists = $dists" - @debug "value = $value" return Distributions.loglikelihood(dists, value) end function dot_observe( @@ -672,7 +668,5 @@ function dot_observe( end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) - @debug "dists = $dists" - @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end From 4d4b4893085551be951f1c7d8c17c6767d91663d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 12:31:27 +0100 Subject: [PATCH 062/216] make reconstruct slightly more generic --- src/utils.jl | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index e77a4ecdd..95b7f6a9a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -93,11 +93,10 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) # otherwise we will have error for MatrixDistribution. # Note this is not the case for MultivariateDistribution so I guess this might be lack of # support for some types related to matrices (like PDMat). -reconstruct(d::UnivariateDistribution, val::AbstractVector) = val[1] -reconstruct(d::MultivariateDistribution, val::AbstractVector) = copy(val) -function reconstruct(d::MatrixDistribution, val::AbstractVector) - return reshape(copy(val), size(d)) -end +reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) +reconstruct(::Tuple{}, val::AbstractVector) = val[1] +reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) +reconstruct(s::NTuple{2}, val::AbstractVector) = reshape(copy(val), s) function reconstruct!(r, d::Distribution, val::AbstractVector) return reconstruct!(r, d, val) end @@ -106,17 +105,17 @@ function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector) return r end function reconstruct(d::Distribution, val::AbstractVector, n::Int) - return reconstruct(d, val, n) + return reconstruct(size(d), val, n) end -function reconstruct(d::UnivariateDistribution, val::AbstractVector, n::Int) +function reconstruct(::Tuple{}, val::AbstractVector, n::Int) return copy(val) end -function reconstruct(d::MultivariateDistribution, val::AbstractVector, n::Int) - return copy(reshape(val, size(d)[1], n)) +function reconstruct(s::NTuple{1}, val::AbstractVector, n::Int) + return copy(reshape(val, s[1], n)) end -function reconstruct(d::MatrixDistribution, val::AbstractVector, n::Int) - tmp = reshape(val, size(d)[1], size(d)[2], n) - orig = [tmp[:, :, i] for i in 1:size(tmp, 3)] +function reconstruct(s::NTuple{2}, val::AbstractVector, n::Int) + tmp = reshape(val, s..., n) + orig = [tmp[:, :, i] for i in 1:n] return orig end function reconstruct!(r, d::Distribution, val::AbstractVector, n::Int) From a68c045ead9ad23acd0dcd8fdf41bd3a8174a3b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 18 Jun 2021 10:04:23 +0100 Subject: [PATCH 063/216] added a couple of convenience constructors --- src/DynamicPPL.jl | 1 + src/simple_varinfo.jl | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5cde57f91..8659c3b4e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -32,6 +32,7 @@ export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, + SimpleVarInfo, getlogp, setlogp!, acclogp!, diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 6699a441b..3a2b01de1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -18,7 +18,7 @@ struct SimpleVarInfo{NT,T} <: AbstractVarInfo end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, Ref(zero(T))) -SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) +SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) function setlogp!(vi::SimpleVarInfo, logp) vi.logp[] = logp @@ -103,3 +103,20 @@ end # HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing + +# Interaction with `VarInfo` +SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) +function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real, names} + vals = map(names) do n + let md = getfield(vi.metadata, n) + x = map(enumerate(md.ranges)) do (i, r) + reconstruct(md.dists[i], md.vals[r]) + end + + # TODO: Doesn't support batches of `MultivariateDistribution`? + length(x) == 1 ? x[1] : x + end + end + + return SimpleVarInfo{T}(NamedTuple{names}(vals)) +end From 9766aecb82d9e5c9419047b14c96b6e78a1ff595 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 18 Jun 2021 14:37:10 +0100 Subject: [PATCH 064/216] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/simple_varinfo.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 3a2b01de1..1b46e0d0a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -36,7 +36,9 @@ function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} return _getindex(value, inds) end -getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} = _getvalue(vi.θ, Val{sym}(), vn.indexing) +function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} + return _getvalue(vi.θ, Val{sym}(), vn.indexing) +end # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than # just `Vector`. getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) @@ -106,7 +108,7 @@ increment_num_produce!(::SimpleVarInfo) = nothing # Interaction with `VarInfo` SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) -function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real, names} +function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} vals = map(names) do n let md = getfield(vi.metadata, n) x = map(enumerate(md.ranges)) do (i, r) From 46b1c7884b8bd1fadb9fc5e8701b647b21508c57 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 18 Jun 2021 14:41:31 +0100 Subject: [PATCH 065/216] small fix --- src/simple_varinfo.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 1b46e0d0a..ff7b8a7c2 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -51,9 +51,12 @@ istrans(::SimpleVarInfo, vn::VarName) = false getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ +# TODO: Should we do better? getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) +# HACK: Need to disambiguiate. +getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) # Context implementations # Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. From 3a645d623d99e35f897f56f807e2b77f528f2e43 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:00:47 +0100 Subject: [PATCH 066/216] return var_info from tilde-statements, allowing impl of immutable versions --- src/compiler.jl | 35 +++++++++++++++++++++++++++------- src/context_implementations.jl | 13 +++++-------- src/model.jl | 13 +++++++++++++ src/utils.jl | 2 +- 4 files changed, 47 insertions(+), 16 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..920262002 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -299,7 +299,7 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -313,7 +313,7 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume!)( + $left, __varinfo__ = $(DynamicPPL.tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn @@ -322,7 +322,7 @@ function generate_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, @@ -343,7 +343,7 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -357,7 +357,7 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_assume!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn @@ -366,7 +366,7 @@ function generate_dot_tilde(left, right) __varinfo__, ) else - $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo = $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, @@ -378,6 +378,27 @@ function generate_dot_tilde(left, right) end end +replace_returns(e) = e +replace_returns(e::Symbol) = e +function replace_returns(e::Expr) + if Meta.isexpr(e, :function) || Meta.isexpr(e, :->) + return e + end + + if Meta.isexpr(e, :return) + retval = if length(e.args) > 1 + Expr(:tuple, e.args...) + else + e.args[1] + end + return quote + return $retval, __varinfo__ + end + end + + return Expr(e.head, map(x -> replace_returns(x), e.args)...) +end + const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true @@ -409,7 +430,7 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = replace_returns(modelinfo[:body]) ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..b48520d03 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -129,8 +129,7 @@ probability of `vi` with the returned value. """ function tilde_assume!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) - acclogp!(vi, logp) - return value + return value, acclogp!(vi, logp) end # observe @@ -213,8 +212,7 @@ probability of `vi` with the returned value. """ function tilde_observe!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return left, acclogp!(vi, logp) end function assume(rng, spl::Sampler, dist) @@ -415,8 +413,8 @@ Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ function dot_tilde_assume!(context, right, left, vn, inds, vi) value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + left .= value + return value, acclogp!(vi, logp) end # `dot_assume` @@ -634,8 +632,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!(context, right, left, vi) logp = dot_tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return left, acclogp!(vi, logp) end # Falls back to non-sampler definition. diff --git a/src/model.jl b/src/model.jl index 9ec047a44..6fe2dcfa7 100644 --- a/src/model.jl +++ b/src/model.jl @@ -155,6 +155,19 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context +) where {_F,argnames} + unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] + return :((first ∘ model.f)(model, varinfo, context, $(unwrap_args...))) +end + +""" + _evaluate_with_varinfo(model::Model, varinfo, context) + +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object, +also returning the resulting `varinfo`. +""" +@generated function _evaluate_with_varinfo( + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] return :(model.f(model, varinfo, context, $(unwrap_args...))) diff --git a/src/utils.jl b/src/utils.jl index 95b7f6a9a..2ff537fe4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,7 @@ Add the result of the evaluation of `ex` to the joint log probability. """ macro addlogprob!(ex) return quote - acclogp!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = acclogp!($(esc(:(__varinfo__))), $(esc(ex))) end end From a2ec0bd4c3cb51ab9aff759012addf76f7b39fd2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:01:14 +0100 Subject: [PATCH 067/216] allow usage of non-Ref types in SimpleVarInfo --- src/simple_varinfo.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index ff7b8a7c2..696da63f3 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -14,18 +14,22 @@ The major differences between this and `TypedVarInfo` are: """ struct SimpleVarInfo{NT,T} <: AbstractVarInfo θ::NT - logp::Base.RefValue{T} + logp::T end -SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, Ref(zero(T))) +SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -function setlogp!(vi::SimpleVarInfo, logp) +getlogp(vi::SimpleVarInfo{<:Any, <:Real}) = vi.logp +setlogp!(vi::SimpleVarInfo{<:Any, <:Real}, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!(vi::SimpleVarInfo{<:Any, <:Real}, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) + +function setlogp!(vi::SimpleVarInfo{<:Any, <:Ref}, logp) vi.logp[] = logp return vi end -function acclogp!(vi::SimpleVarInfo, logp) +function acclogp!(vi::SimpleVarInfo{<:Any, <:Ref}, logp) vi.logp[] += logp return vi end From 1d9bc373cdb23fd70cca8cae26bc54af9a0d157e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:01:34 +0100 Subject: [PATCH 068/216] update submodel-macro --- src/submodel_macro.jl | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 1d574e286..96f1af6e6 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,15 +1,33 @@ macro submodel(expr) - return quote - _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) + args_tilde = getargs_tilde(expr) + return if args_tilde === nothing + # In this case we only want to get the `__varinfo__`. + quote + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) + end + else + # Here we also want the return-variable. + L, R = args_tilde + quote + $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(R)), $(esc(:__varinfo__)), $(esc(:__context__))) + end end end macro submodel(prefix, expr) - return quote - _evaluate( - $(esc(expr)), - $(esc(:__varinfo__)), - PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), - ) + ctx = :(PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))) + + args_tilde = getargs_tilde(expr) + return if args_tilde === nothing + # In this case we only want to get the `__varinfo__`. + quote + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(expr)), $(esc(:__varinfo__)), $(ctx)) + end + else + # Here we also want the return-variable. + L, R = args_tilde + quote + $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(R)), $(esc(:__varinfo__)), $(ctx)) + end end end From cfd7f219504bd1fdb3082032cb741e11644883a0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:26:44 +0100 Subject: [PATCH 069/216] formatting and docstring for submodel-macro --- src/simple_varinfo.jl | 10 +++++----- src/submodel_macro.jl | 26 ++++++++++++++++++++++---- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 696da63f3..2f029925b 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -20,16 +20,16 @@ end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -getlogp(vi::SimpleVarInfo{<:Any, <:Real}) = vi.logp -setlogp!(vi::SimpleVarInfo{<:Any, <:Real}, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo{<:Any, <:Real}, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) +getlogp(vi::SimpleVarInfo{<:Any,<:Real}) = vi.logp +setlogp!(vi::SimpleVarInfo{<:Any,<:Real}, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!(vi::SimpleVarInfo{<:Any,<:Real}, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) -function setlogp!(vi::SimpleVarInfo{<:Any, <:Ref}, logp) +function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp return vi end -function acclogp!(vi::SimpleVarInfo{<:Any, <:Ref}, logp) +function acclogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] += logp return vi end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 96f1af6e6..917d80cc4 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,15 +1,29 @@ +""" + @submodel x ~ model(args...) + @submodel prefix x ~ model(args...) + +Treats `model` as a distribution, where `x` is the return-value of `model`. + +If `prefix` is specified, then variables sampled within `model` will be +prefixed by `prefix`. This is useful if you have variables of same names in +several models used together. +""" macro submodel(expr) args_tilde = getargs_tilde(expr) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__)) + ) end else # Here we also want the return-variable. L, R = args_tilde quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(R)), $(esc(:__varinfo__)), $(esc(:__context__))) + $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(R)), $(esc(:__varinfo__)), $(esc(:__context__)) + ) end end end @@ -21,13 +35,17 @@ macro submodel(prefix, expr) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(expr)), $(esc(:__varinfo__)), $(ctx)) + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(expr)), $(esc(:__varinfo__)), $(ctx) + ) end else # Here we also want the return-variable. L, R = args_tilde quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo($(esc(R)), $(esc(:__varinfo__)), $(ctx)) + $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(R)), $(esc(:__varinfo__)), $(ctx) + ) end end end From c200e7362eb0717a81ea235dc2c17fc7ff07b357 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:42:51 +0100 Subject: [PATCH 070/216] attempt at supporting implicit returns too --- src/compiler.jl | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 920262002..1a9dc8929 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -386,12 +386,16 @@ function replace_returns(e::Expr) end if Meta.isexpr(e, :return) - retval = if length(e.args) > 1 + retval_expr = if length(e.args) > 1 Expr(:tuple, e.args...) else e.args[1] end + # Use intermediate variable since this expression + # can be more complex than just a value, e.g. `return if ... end`. + @gensym retval return quote + $retval = $retval_expr return $retval, __varinfo__ end end @@ -399,6 +403,18 @@ function replace_returns(e::Expr) return Expr(e.head, map(x -> replace_returns(x), e.args)...) end +# If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. +make_returns_explicit!(body) = Expr(:return, body) +function make_returns_explicit!(body::Expr) + # If it's already a return-statement, we return immediately. + if Meta.isexpr(body, :return) + return body + end + + body.args[end] = Expr(:return, body.args[end]) + return body +end + const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true @@ -430,7 +446,7 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = replace_returns(modelinfo[:body]) + evaluatordef[:body] = replace_returns(make_returns_explicit!(modelinfo[:body])) ## Build the model function. From efeb812c750ed0517e81f4b35dbcedf471f577fc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:44:14 +0100 Subject: [PATCH 071/216] added a small comment --- src/compiler.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compiler.jl b/src/compiler.jl index 1a9dc8929..200285f02 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -411,6 +411,7 @@ function make_returns_explicit!(body::Expr) return body end + # Otherwise we replace the last statement with a `return` statement. body.args[end] = Expr(:return, body.args[end]) return body end From 14b94956bdb14587b7321267dba20d98d54b0171 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 02:47:56 +0100 Subject: [PATCH 072/216] simplifed submodel macro a bit --- src/submodel_macro.jl | 21 ++++----------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 917d80cc4..8e59f3015 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -9,28 +9,15 @@ prefixed by `prefix`. This is useful if you have variables of same names in several models used together. """ macro submodel(expr) - args_tilde = getargs_tilde(expr) - return if args_tilde === nothing - # In this case we only want to get the `__varinfo__`. - quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( - $(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__)) - ) - end - else - # Here we also want the return-variable. - L, R = args_tilde - quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( - $(esc(R)), $(esc(:__varinfo__)), $(esc(:__context__)) - ) - end - end + return submodel(expr) end macro submodel(prefix, expr) ctx = :(PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))) + return submodel(expr, ctx) +end +function submodel(expr, ctx = esc(:__context__)) args_tilde = getargs_tilde(expr) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. From c3d9e7b09aa3bec696b6ab06617d3088342bad66 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 03:31:51 +0100 Subject: [PATCH 073/216] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/submodel_macro.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 8e59f3015..32a2bd583 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -17,7 +17,7 @@ macro submodel(prefix, expr) return submodel(expr, ctx) end -function submodel(expr, ctx = esc(:__context__)) +function submodel(expr, ctx=esc(:__context__)) args_tilde = getargs_tilde(expr) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. From 416e7736be46b203e71f0660c544de989a78aa9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 12:51:17 +0100 Subject: [PATCH 074/216] fixed typo --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 200285f02..edf037b03 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -406,8 +406,8 @@ end # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. make_returns_explicit!(body) = Expr(:return, body) function make_returns_explicit!(body::Expr) - # If it's already a return-statement, we return immediately. - if Meta.isexpr(body, :return) + # If the last statement is a return-statement, we don't do anything. + if Meta.isexpr(body.args[end], :return) return body end From b4b8b03edede2d3fee38c7be680f7117df8ff9b4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 12:54:11 +0100 Subject: [PATCH 075/216] use bang-bang convention --- src/compiler.jl | 12 ++++++------ src/context_implementations.jl | 30 +++++++++++++++--------------- src/loglikelihoods.jl | 16 ++++++++-------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index edf037b03..2b7945c9c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -299,7 +299,7 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - _, __varinfo__ = $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -313,7 +313,7 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left, __varinfo__ = $(DynamicPPL.tilde_assume!)( + $left, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn @@ -322,7 +322,7 @@ function generate_tilde(left, right) __varinfo__, ) else - _, __varinfo__ = $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, @@ -343,7 +343,7 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -357,7 +357,7 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - _, __varinfo__ = $(DynamicPPL.dot_tilde_assume!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn @@ -366,7 +366,7 @@ function generate_dot_tilde(left, right) __varinfo__, ) else - _, __varinfo = $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b48520d03..347d3403d 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -119,7 +119,7 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) end """ - tilde_assume!(context, right, vn, inds, vi) + tilde_assume!!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. @@ -127,7 +127,7 @@ accumulate the log probability, and return the sampled value. By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_assume!(context, right, vn, inds, vi) +function tilde_assume!!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) return value, acclogp!(vi, logp) end @@ -189,16 +189,16 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(context, right, left, vname, vinds, vi) + tilde_observe!!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(context, right, left, vname, vinds, vi) - return tilde_observe!(context, right, left, vi) +function tilde_observe!!(context, right, left, vname, vinds, vi) + return tilde_observe!!(context, right, left, vi) end """ @@ -210,7 +210,7 @@ return the observed value. By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_observe!(context, right, left, vi) +function tilde_observe!!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) return left, acclogp!(vi, logp) end @@ -404,14 +404,14 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, end """ - dot_tilde_assume!(context, right, left, vn, inds, vi) + dot_tilde_assume!!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(context, right, left, vn, inds, vi) +function dot_tilde_assume!!(context, right, left, vn, inds, vi) value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) left .= value return value, acclogp!(vi, logp) @@ -610,27 +610,27 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vname, vinds, vi) + dot_tilde_observe!!(context, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(context, right, left, vn, inds, vi) - return dot_tilde_observe!(context, right, left, vi) +function dot_tilde_observe!!(context, right, left, vn, inds, vi) + return dot_tilde_observe!!(context, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vi) + dot_tilde_observe!!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(context, right, left, vi) +function dot_tilde_observe!!(context, right, left, vi) logp = dot_tilde_observe(context, right, left, vi) return left, acclogp!(vi, logp) end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6c66e4ec4..4b1a16486 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -69,13 +69,13 @@ function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, return dot_tilde_assume(context.context, right, left, vn, inds, vi) end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) +function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return tilde_observe!(context.context, right, left, vi) + return tilde_observe!!(context.context, right, left, vi) end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) +function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. + # we have to intercept the call to `tilde_observe!!`. logp = tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) @@ -85,13 +85,13 @@ function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi return left end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) +function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return dot_tilde_observe!(context.context, right, left, vi) + return dot_tilde_observe!!(context.context, right, left, vi) end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `dot_tilde_observe!`. + # we have to intercept the call to `dot_tilde_observe!!`. logp = dot_tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) From a725a27c1a84c38de3e5921f76d3cb6897f93449 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 12:55:03 +0100 Subject: [PATCH 076/216] updated PointwiseLikelihoodContext --- src/loglikelihoods.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 4b1a16486..a12c8103c 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -77,12 +77,11 @@ function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, v # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!!`. logp = tilde_observe(context.context, right, left, vi) - acclogp!(vi, logp) # Track loglikelihood value. push!(context, vn, logp) - return left + return left, acclogp!(vi, logp) end function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) @@ -93,12 +92,11 @@ function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, v # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!!`. logp = dot_tilde_observe(context.context, right, left, vi) - acclogp!(vi, logp) # Track loglikelihood value. push!(context, vn, logp) - return left + return left, acclogp!(vi, logp) end """ From 5512670ad6a1aa1d9b6e5df544202d3039d1e65c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 13:24:31 +0100 Subject: [PATCH 077/216] fixed issue where we unnecessarily replace the return-statement --- src/compiler.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2b7945c9c..09d3e9c9f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -386,6 +386,9 @@ function replace_returns(e::Expr) end if Meta.isexpr(e, :return) + # NOTE: `return` always has an argument. In the case of + # `return`, the parsed expression will be `return nothing`. + # Hence we don't need any special handling for empty returns. retval_expr = if length(e.args) > 1 Expr(:tuple, e.args...) else @@ -394,9 +397,15 @@ function replace_returns(e::Expr) # Use intermediate variable since this expression # can be more complex than just a value, e.g. `return if ... end`. @gensym retval + + # If the return-value is already of the form we want, we don't do anything. return quote $retval = $retval_expr - return $retval, __varinfo__ + return if $retval isa Tuple{<:Any, $(DynamicPPL.AbstractVarInfo)} + $retval + else + $retval, __varinfo__ + end end end From 4c1ee70489b4cea67cb8a5795d7f664e4625897d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 13:27:17 +0100 Subject: [PATCH 078/216] check subtype in the retval --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 09d3e9c9f..937d6a368 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -401,7 +401,7 @@ function replace_returns(e::Expr) # If the return-value is already of the form we want, we don't do anything. return quote $retval = $retval_expr - return if $retval isa Tuple{<:Any, $(DynamicPPL.AbstractVarInfo)} + return if $retval isa Tuple{<:Any, <:$(DynamicPPL.AbstractVarInfo)} $retval else $retval, __varinfo__ From 26590b5c1f59851e2b748f7b83a35356fc094ed5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 13:27:34 +0100 Subject: [PATCH 079/216] formatting --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 937d6a368..561ea25af 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -401,7 +401,7 @@ function replace_returns(e::Expr) # If the return-value is already of the form we want, we don't do anything. return quote $retval = $retval_expr - return if $retval isa Tuple{<:Any, <:$(DynamicPPL.AbstractVarInfo)} + return if $retval isa Tuple{<:Any,<:$(DynamicPPL.AbstractVarInfo)} $retval else $retval, __varinfo__ From 42fd4144fb0dc7fc435553271a984d81655e56f4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 14:19:05 +0100 Subject: [PATCH 080/216] fixed type-instability in retval check --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 561ea25af..1cba5181b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -401,7 +401,7 @@ function replace_returns(e::Expr) # If the return-value is already of the form we want, we don't do anything. return quote $retval = $retval_expr - return if $retval isa Tuple{<:Any,<:$(DynamicPPL.AbstractVarInfo)} + return if $retval isa Tuple{Any,$(DynamicPPL.AbstractVarInfo)} $retval else $retval, __varinfo__ From f219545c49b59aff2665b95d5b3caab7cd2e1b89 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 14:53:39 +0100 Subject: [PATCH 081/216] introduced evaluate method for model --- src/model.jl | 60 ++++++++++++++++++++++--------------------- src/submodel_macro.jl | 4 +-- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/model.jl b/src/model.jl index 6fe2dcfa7..d169aac8e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -82,17 +82,20 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( - rng::Random.AbstractRNG, - varinfo::AbstractVarInfo=VarInfo(), - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return model(varinfo, SamplingContext(rng, sampler, context)) -end +(model::Model)(args...) = (first ∘ evaluate)(model, args...) + +""" + evaluate(model::Model[, rng, varinfo, sampler, context]) -(model::Model)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) +Sample from the `model` using the `sampler` with random number generator `rng` and the +`context`, and store the sample and log joint probability in `varinfo`. + +Returns both the return-value of the original model, and the resulting varinfo. + +The method resets the log joint probability of `varinfo` and increases the evaluation +number of `sampler`. +""" +function evaluate(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else @@ -100,18 +103,30 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) - return model(Random.GLOBAL_RNG, args...) +function evaluate( + model::Model, + rng::Random.AbstractRNG, + varinfo::AbstractVarInfo=VarInfo(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return evaluate(model, varinfo, SamplingContext(rng, sampler, context)) +end + +evaluate(model::Model, context::AbstractContext) = evaluate(model, VarInfo(), context) + +function evaluate(model::Model, args...) + return evaluate(model, Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) - return model(rng, VarInfo(), sampler, args...) +function evaluate(model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args...) + return evaluate(model, rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) - return model(rng, VarInfo(), SampleFromPrior(), context) +function evaluate(model::Model, rng::Random.AbstractRNG, context::AbstractContext) + return evaluate(model, rng, VarInfo(), SampleFromPrior(), context) end """ @@ -155,19 +170,6 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context -) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :((first ∘ model.f)(model, varinfo, context, $(unwrap_args...))) -end - -""" - _evaluate_with_varinfo(model::Model, varinfo, context) - -Evaluate the `model` with the arguments matching the given `context` and `varinfo` object, -also returning the resulting `varinfo`. -""" -@generated function _evaluate_with_varinfo( - model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] return :(model.f(model, varinfo, context, $(unwrap_args...))) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 32a2bd583..f9356a3c9 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -22,7 +22,7 @@ function submodel(expr, ctx=esc(:__context__)) return if args_tilde === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate( $(esc(expr)), $(esc(:__varinfo__)), $(ctx) ) end @@ -30,7 +30,7 @@ function submodel(expr, ctx=esc(:__context__)) # Here we also want the return-variable. L, R = args_tilde quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate_with_varinfo( + $(esc(L)), $(esc(:__varinfo__)) = _evaluate( $(esc(R)), $(esc(:__varinfo__)), $(ctx) ) end From ce1356629de825529544c1f9bfb0bbaa8493328a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 15:51:16 +0100 Subject: [PATCH 082/216] remove unnecessary type-requirement --- src/simple_varinfo.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2f029925b..4e848e291 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -20,9 +20,9 @@ end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -getlogp(vi::SimpleVarInfo{<:Any,<:Real}) = vi.logp -setlogp!(vi::SimpleVarInfo{<:Any,<:Real}, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo{<:Any,<:Real}, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) +getlogp(vi::SimpleVarInfo) = vi.logp +setlogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp From 3556b111e612895a3848418f5db83b7c9a0ce7e8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 16:00:21 +0100 Subject: [PATCH 083/216] make return-value check much nicer --- src/compiler.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 1cba5181b..af7100772 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -394,24 +394,16 @@ function replace_returns(e::Expr) else e.args[1] end - # Use intermediate variable since this expression - # can be more complex than just a value, e.g. `return if ... end`. - @gensym retval - # If the return-value is already of the form we want, we don't do anything. - return quote - $retval = $retval_expr - return if $retval isa Tuple{Any,$(DynamicPPL.AbstractVarInfo)} - $retval - else - $retval, __varinfo__ - end - end + return :($(DynamicPPL.return_values)($retval_expr, __varinfo__)) end return Expr(e.head, map(x -> replace_returns(x), e.args)...) end +return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) +return_values(retval::Tuple{Any,AbstractVarInfo}, ::AbstractVarInfo) = retval + # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. make_returns_explicit!(body) = Expr(:return, body) function make_returns_explicit!(body::Expr) From 599d09443a148f366f9a0f09f3defbeb3d5cd168 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 20 Jun 2021 16:14:16 +0100 Subject: [PATCH 084/216] removed redundant creation of anonymous function --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index af7100772..13318ff87 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -398,7 +398,7 @@ function replace_returns(e::Expr) return :($(DynamicPPL.return_values)($retval_expr, __varinfo__)) end - return Expr(e.head, map(x -> replace_returns(x), e.args)...) + return Expr(e.head, map(replace_returns, e.args)...) end return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) From 22b170c09a33f617c88e0d640c73c8c8594f5490 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Jun 2021 21:24:42 +0100 Subject: [PATCH 085/216] dont use UnionAll in return_values --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 13318ff87..bea869712 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -402,7 +402,7 @@ function replace_returns(e::Expr) end return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) -return_values(retval::Tuple{Any,AbstractVarInfo}, ::AbstractVarInfo) = retval +return_values(retval::Tuple{<:Any,<:AbstractVarInfo}, ::AbstractVarInfo) = retval # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. make_returns_explicit!(body) = Expr(:return, body) From 4606f163b867922e5e1bf7b0da8d3d3f156f8bdf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 28 Jun 2021 22:39:44 +0100 Subject: [PATCH 086/216] updated tests for submodel to reflect new syntax --- test/compiler.jl | 6 +++--- test/loglikelihoods.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 6f85e9453..703027fda 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -364,8 +364,8 @@ end end @model function demo_useval(x, y) - x1 = @submodel sub1 demo_return(x) - x2 = @submodel sub2 demo_return(y) + @submodel sub1 x1 ~ demo_return(x) + @submodel sub2 x2 ~ demo_return(y) return z ~ Normal(x1 + x2 + 100, 1.0) end @@ -399,7 +399,7 @@ end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ) + @submodel $(Symbol("ar1_$i")) x ~ AR1(num_steps, α, μ, σ) y[i] ~ MvNormal(x, 0.1) end end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 74fb88d70..f1ded1a0f 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -69,7 +69,7 @@ end @model function gdemo9() # Submodel prior - m = @submodel _prior_dot_assume() + @submodel m ~ _prior_dot_assume() for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end From 68cb021885b1c0f603fd08e3e774fef1d42bf2bc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 14:50:18 +0100 Subject: [PATCH 087/216] moved to using BangBang-convention for most methods --- Project.toml | 1 + src/DynamicPPL.jl | 35 +++++++++ src/compat/ad.jl | 4 +- src/context_implementations.jl | 48 ++++++------- src/loglikelihoods.jl | 8 +-- src/model.jl | 6 +- src/submodel_macro.jl | 1 + src/threadsafe.jl | 30 ++++---- src/utils.jl | 2 +- src/varinfo.jl | 125 +++++++++++++++++---------------- test/threadsafe.jl | 6 +- test/varinfo.jl | 30 ++++---- 12 files changed, 167 insertions(+), 129 deletions(-) diff --git a/Project.toml b/Project.toml index 921dc054d..7cc3db31d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 8659c3b4e..88a5ad89f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -27,16 +27,25 @@ import Base: keys, haskey +import BangBang: + push!!, + empty!! + # VarInfo export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, SimpleVarInfo, + push!!, + empty!!, getlogp, setlogp!, acclogp!, resetlogp!, + setlogp!!, + acclogp!!, + resetlogp!!, get_num_produce, set_num_produce!, reset_num_produce!, @@ -45,12 +54,18 @@ export AbstractVarInfo, is_flagged, set_flag!, unset_flag!, + set_flag!!, + unset_flag!!, setgid!, updategid!, + setgid!!, + updategid!!, setorder!, istrans, link!, invlink!, + link!!, + invlink!!, tonamedtuple, # VarName (reexport from AbstractPPL) VarName, @@ -132,4 +147,24 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +# Deprecations +@deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) + +@deprecate setlogp!(vi, logp) setlogp!!(vi, logp) +@deprecate acclogp!(vi, logp) acclogp!!(vi, logp) +@deprecate resetlogp!(vi) resetlogp!!(vi) + +@deprecate link!(vi, spl) link!!(vi, spl) +@deprecate invlink!(vi, spl) invlink!!(vi, spl) + +@deprecate set_flag!(vi, vn, flag) set_flag!!(vi, vn, flag) +@deprecate unset_flag!(vi, vn, flag) unset_flag!!(vi, vn, flag) + +@deprecate setgid!(vi, gid, vn) setgid!!(vi, gid, vn) +@deprecate updategid!(vi, vn, spl) updategid!!(vi, vn, spl) + end # module diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 47a627506..664ce2b33 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -1,9 +1,9 @@ # See https://github.com/TuringLang/Turing.jl/issues/1199 -ChainRulesCore.@non_differentiable push!( +ChainRulesCore.@non_differentiable push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) -ChainRulesCore.@non_differentiable updategid!( +ChainRulesCore.@non_differentiable updategid!!( vi::AbstractVarInfo, vn::VarName, spl::Sampler ) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 347d3403d..64a958644 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -45,7 +45,7 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, inds, vi) end @@ -60,7 +60,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end @@ -74,7 +74,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, inds, vi) end @@ -89,7 +89,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) end @@ -129,7 +129,7 @@ probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) - return value, acclogp!(vi, logp) + return value, acclogp!!(vi, logp) end # observe @@ -212,7 +212,7 @@ probability of `vi` with the returned value. """ function tilde_observe!!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) - return left, acclogp!(vi, logp) + return left, acclogp!!(vi, logp) end function assume(rng, spl::Sampler, dist) @@ -243,18 +243,18 @@ function assume( if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") + unset_flag!!(vi, vn, "del") r = init(rng, dist, sampler) vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) else r = vi[vn] end else r = init(rng, dist, sampler) - push!(vi, vn, r, dist, sampler) - settrans!(vi, false, vn) + push!!(vi, vn, r, dist, sampler) + settrans!!(vi, false, vn) end return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) @@ -305,7 +305,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) @@ -325,7 +325,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) @@ -346,7 +346,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) else dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) @@ -366,7 +366,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!!.(Ref(vi), false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) @@ -414,7 +414,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. function dot_tilde_assume!!(context, right, left, vn, inds, vi) value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) left .= value - return value, acclogp!(vi, logp) + return value, acclogp!!(vi, logp) end # `dot_assume` @@ -495,12 +495,12 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + unset_flag!!(vi, vns[1], "del") r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] vi[vn] = vectorize(dist, r[:, i]) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -510,8 +510,8 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - push!(vi, vn, r[:, i], dist, spl) - settrans!(vi, false, vn) + push!!(vi, vn, r[:, i], dist, spl) + settrans!!(vi, false, vn) end end return r @@ -527,14 +527,14 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!(vi, vns[1], "del") + unset_flag!!(vi, vns[1], "del") f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists vi[vn] = vectorize(dist, r[i]) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -543,8 +543,8 @@ function get_and_set_val!( else f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) - push!.(Ref(vi), vns, r, dists, Ref(spl)) - settrans!.(Ref(vi), false, vns) + push!!.(Ref(vi), vns, r, dists, Ref(spl)) + settrans!!.(Ref(vi), false, vns) end return r end @@ -632,7 +632,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!!(context, right, left, vi) logp = dot_tilde_observe(context, right, left, vi) - return left, acclogp!(vi, logp) + return left, acclogp!!(vi, logp) end # Falls back to non-sampler definition. diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index a12c8103c..4ca015f3e 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -74,14 +74,14 @@ function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) return tilde_observe!!(context.context, right, left, vi) end function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # Need the `logp` value, so we cannot defer `acclogp!!` to child-context, i.e. # we have to intercept the call to `tilde_observe!!`. logp = tilde_observe(context.context, right, left, vi) # Track loglikelihood value. push!(context, vn, logp) - return left, acclogp!(vi, logp) + return left, acclogp!!(vi, logp) end function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vi) @@ -89,14 +89,14 @@ function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, v return dot_tilde_observe!!(context.context, right, left, vi) end function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # Need the `logp` value, so we cannot defer `acclogp!!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!!`. logp = dot_tilde_observe(context.context, right, left, vi) # Track loglikelihood value. push!(context, vn, logp) - return left, acclogp!(vi, logp) + return left, acclogp!!(vi, logp) end """ diff --git a/src/model.jl b/src/model.jl index d169aac8e..929f420ff 100644 --- a/src/model.jl +++ b/src/model.jl @@ -140,7 +140,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ function evaluate_threadunsafe(model, varinfo, context) - resetlogp!(varinfo) + resetlogp!!(varinfo) return _evaluate(model, varinfo, context) end @@ -156,10 +156,10 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ function evaluate_threadsafe(model, varinfo, context) - resetlogp!(varinfo) + resetlogp!!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) result = _evaluate(model, wrapper, context) - setlogp!(varinfo, getlogp(wrapper)) + setlogp!!(varinfo, getlogp(wrapper)) return result end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index f9356a3c9..25460eada 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -28,6 +28,7 @@ function submodel(expr, ctx=esc(:__context__)) end else # Here we also want the return-variable. + # TODO: Should we prefix by `L` by default? L, R = args_tilde quote $(esc(L)), $(esc(:__varinfo__)) = _evaluate( diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c940f9e3f..9c59fa507 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -15,7 +15,7 @@ ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi # Instead of updating the log probability of the underlying variables we # just update the array of log probabilities. -function acclogp!(vi::ThreadSafeVarInfo, logp) +function acclogp!!(vi::ThreadSafeVarInfo, logp) vi.logps[Threads.threadid()][] += logp return vi end @@ -26,17 +26,17 @@ getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps) # TODO: Make remaining methods thread-safe. -function resetlogp!(vi::ThreadSafeVarInfo) +function resetlogp!!(vi::ThreadSafeVarInfo) for x in vi.logps x[] = zero(x[]) end - return resetlogp!(vi.varinfo) + return resetlogp!!(vi.varinfo) end -function setlogp!(vi::ThreadSafeVarInfo, logp) +function setlogp!!(vi::ThreadSafeVarInfo, logp) for x in vi.logps x[] = zero(x[]) end - return setlogp!(vi.varinfo, logp) + return setlogp!!(vi.varinfo, logp) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) @@ -46,8 +46,8 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) - return setgid!(vi.varinfo, gid, vn) +function setgid!!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) + return setgid!!(vi.varinfo, gid, vn) end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) @@ -55,8 +55,8 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) -invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) +link!!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!!(vi.varinfo, spl) +invlink!!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) @@ -80,20 +80,20 @@ function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) -function empty!(vi::ThreadSafeVarInfo) - empty!(vi.varinfo) +function empty!!(vi::ThreadSafeVarInfo) + empty!!(vi.varinfo) fill!(vi.logps, zero(getlogp(vi))) return vi end -function push!( +function push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - return push!(vi.varinfo, vn, r, dist, gidset) + return push!!(vi.varinfo, vn, r, dist, gidset) end -function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return unset_flag!(vi.varinfo, vn, flag) +function unset_flag!!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) + return unset_flag!!(vi.varinfo, vn, flag) end function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) diff --git a/src/utils.jl b/src/utils.jl index 2ff537fe4..14b7650fb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,7 @@ Add the result of the evaluation of `ex` to the joint log probability. """ macro addlogprob!(ex) return quote - $(esc(:(__varinfo__))) = acclogp!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) end end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..1f81f692d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -335,14 +335,15 @@ getall(vi::TypedVarInfo) = vcat(_getall(vi.metadata)...) end """ - setall!(vi::VarInfo, val) + setall!!(vi::VarInfo, val) -Set the values of all the variables in `vi` to `val`. +Set the values of all the variables in `vi` to `val`, +mutating if it makese sense. The values may or may not be transformed to Euclidean space. """ -setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val -setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) +setall!!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val +setall!!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) @generated function _setall!(metadata::NamedTuple{names}, val, start=0) where {names} expr = Expr(:block) start = :(1) @@ -363,12 +364,12 @@ Return the set of sampler selectors associated with `vn` in `vi`. getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ - settrans!(vi::VarInfo, trans::Bool, vn::VarName) + settrans!!(vi::VarInfo, trans::Bool, vn::VarName) -Set the `trans` flag value of `vn` in `vi`. +Set the `trans` flag value of `vn` in `vi`, mutating if it makes sense. """ -function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - return trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") +function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) + return trans ? set_flag!!(vi, vn, "trans") : unset_flag!!(vi, vn, "trans") end """ @@ -504,11 +505,11 @@ end end """ - set_flag!(vi::VarInfo, vn::VarName, flag::String) + set_flag!!(vi::VarInfo, vn::VarName, flag::String) Set `vn`'s value for `flag` to `true` in `vi`. """ -function set_flag!(vi::VarInfo, vn::VarName, flag::String) +function set_flag!!(vi::VarInfo, vn::VarName, flag::String) return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true end @@ -586,16 +587,16 @@ end TypedVarInfo(vi::TypedVarInfo) = vi """ - empty!(vi::VarInfo) + empty!!(vi::VarInfo) Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to -zeros. +zeros, mutating if it makes sense. This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. """ -function empty!(vi::VarInfo) +function empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!(vi) + resetlogp!!(vi) reset_num_produce!(vi) return vi end @@ -628,11 +629,11 @@ Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) end """ - setgid!(vi::VarInfo, gid::Selector, vn::VarName) + setgid!!(vi::VarInfo, gid::Selector, vn::VarName) Add `gid` to the set of sampler selectors associated with `vn` in `vi`. """ -function setgid!(vi::VarInfo, gid::Selector, vn::VarName) +function setgid!!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end @@ -653,34 +654,34 @@ Return the log of the joint probability of the observed data and parameters samp getlogp(vi::AbstractVarInfo) = vi.logp[] """ - setlogp!(vi::VarInfo, logp) + setlogp!!(vi::VarInfo, logp) Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`. +`vi` to `logp`, mutating if it makes sense. """ -function setlogp!(vi::VarInfo, logp) +function setlogp!!(vi::VarInfo, logp) vi.logp[] = logp return vi end """ - acclogp!(vi::VarInfo, logp) + acclogp!!(vi::VarInfo, logp) Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`. +parameters sampled in `vi`, mutating if it makes sense. """ -function acclogp!(vi::VarInfo, logp) +function acclogp!!(vi::VarInfo, logp) vi.logp[] += logp return vi end """ - resetlogp!(vi::AbstractVarInfo) + resetlogp!!(vi::AbstractVarInfo) Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0. +sampled in `vi` to 0, mutating if it makes sense. """ -resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, zero(getlogp(vi))) +resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) """ get_num_produce(vi::VarInfo) @@ -728,13 +729,13 @@ end # X -> R for all variables associated with given sampler """ - link!(vi::VarInfo, spl::Sampler) + link!!(vi::VarInfo, spl::Sampler) Transform the values of the random variables sampled by `spl` in `vi` from the support of their distributions to the Euclidean space and set their corresponding `"trans"` flag values to `true`. """ -function link!(vi::UntypedVarInfo, spl::Sampler) +function link!!(vi::UntypedVarInfo, spl::Sampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) @@ -747,16 +748,16 @@ function link!(vi::UntypedVarInfo, spl::Sampler) vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") end end -function link!(vi::TypedVarInfo, spl::AbstractSampler) - return link!(vi, spl, Val(getspace(spl))) +function link!!(vi::TypedVarInfo, spl::AbstractSampler) + return link!!(vi, spl, Val(getspace(spl))) end -function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function link!!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _link!(vi.metadata, vi, vns, spaceval) end @@ -783,7 +784,7 @@ end ), vn, ) - settrans!(vi, true, vn) + settrans!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -797,13 +798,13 @@ end # R -> X for all variables associated with given sampler """ - invlink!(vi::VarInfo, spl::AbstractSampler) + invlink!!(vi::VarInfo, spl::AbstractSampler) Transform the values of the random variables sampled by `spl` in `vi` from the Euclidean space back to the support of their distributions and sets their corresponding `"trans"` flag values to `false`. """ -function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) +function invlink!!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns @@ -814,16 +815,16 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler) - return invlink!(vi, spl, Val(getspace(spl))) +function invlink!!(vi::TypedVarInfo, spl::AbstractSampler) + return invlink!!(vi, spl, Val(getspace(spl))) end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function invlink!!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _invlink!(vi.metadata, vi, vns, spaceval) end @@ -852,7 +853,7 @@ end ), vn, ) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -962,7 +963,7 @@ Set the current value(s) of the random variables sampled by `spl` in `vi` to `va The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!(vi, val) +setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!!(vi, val) setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` @@ -1086,42 +1087,42 @@ function Base.show(io::IO, vi::UntypedVarInfo) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`. +the `VarInfo` `vi`, mutating if it makes sense. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return push!(vi, vn, r, dist, Set{Selector}([])) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + return push!!(vi, vn, r, dist, Set{Selector}([])) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`. +from a distribution `dist` to `VarInfo` `vi`, if it makes sense. The sampler is passed here to invalidate its cache where defined. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) - return push!(vi, vn, r, dist, spl.selector) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) + return push!!(vi, vn, r, dist, spl.selector) end -function push!( +function push!!( vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler ) - return push!(vi, vn, r, dist) + return push!!(vi, vn, r, dist) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) Push a new random variable `vn` with a sampled value `r` sampled with a sampler of selector `gid` from a distribution `dist` to `VarInfo` `vi`. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - return push!(vi, vn, r, dist, Set([gid])) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + return push!!(vi, vn, r, dist, Set([gid])) end -function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo @@ -1174,11 +1175,11 @@ function is_flagged(vi::VarInfo, vn::VarName, flag::String) end """ - unset_flag!(vi::VarInfo, vn::VarName, flag::String) + unset_flag!!(vi::VarInfo, vn::VarName, flag::String) Set `vn`'s value for `flag` to `false` in `vi`. """ -function unset_flag!(vi::VarInfo, vn::VarName, flag::String) +function unset_flag!!(vi::VarInfo, vn::VarName, flag::String) return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false end @@ -1238,14 +1239,14 @@ end end """ - updategid!(vi::VarInfo, vn::VarName, spl::Sampler) + updategid!!(vi::VarInfo, vn::VarName, spl::Sampler) Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selector linked and `vn`'s symbol is in the space of `spl`. """ -function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) +function updategid!!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) if inspace(vn, getspace(spl)) - setgid!(vi, spl.selector, vn) + setgid!!(vi, spl.selector, vn) end end @@ -1393,7 +1394,7 @@ function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) end return indices @@ -1474,11 +1475,11 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!(vi, false, vn) + settrans!!(vi, false, vn) else # Ensures that we'll resample the variable corresponding to `vn` if we run # the model on `vi` again. - set_flag!(vi, vn, "del") + set_flag!!(vi, vn, "del") end return indices diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 83c53ccd6..bd1f4f154 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -17,17 +17,17 @@ lp = getlogp(vi) @test getlogp(threadsafe_vi) == lp - acclogp!(threadsafe_vi, 42) + acclogp!!(threadsafe_vi, 42) @test threadsafe_vi.logps[Threads.threadid()][] == 42 @test getlogp(vi) == lp @test getlogp(threadsafe_vi) == lp + 42 - resetlogp!(threadsafe_vi) + resetlogp!!(threadsafe_vi) @test iszero(getlogp(vi)) @test iszero(getlogp(threadsafe_vi)) @test all(iszero(x[]) for x in threadsafe_vi.logps) - setlogp!(threadsafe_vi, 42) + setlogp!!(threadsafe_vi, 42) @test getlogp(vi) == 42 @test getlogp(threadsafe_vi) == 42 @test all(iszero(x[]) for x in threadsafe_vi.logps) diff --git a/test/varinfo.jl b/test/varinfo.jl index 4c8ec43cb..f1cadfa8f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -33,7 +33,7 @@ end @testset "Base" begin # Test Base functions: - # string, Symbol, ==, hash, in, keys, haskey, isempty, push!, empty!, + # string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!, # getindex, setindex!, getproperty, setproperty! csym = gensym() vn1 = @varname x[1][2] @@ -46,7 +46,7 @@ @test inspace(vn1, (:x,)) function test_base!(vi) - empty!(vi) + empty!!(vi) @test getlogp(vi) == 0 @test get_num_produce(vi) == 0 @@ -58,7 +58,7 @@ @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - push!(vi, vn, r, dist, gid) + push!!(vi, vn, r, dist, gid) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -75,9 +75,9 @@ @test vi[vn] == 3 * r @test vi[SampleFromPrior()][1] == 3 * r - empty!(vi) + empty!!(vi) @test isempty(vi) - push!(vi, vn, r, dist, gid) + push!!(vi, vn, r, dist, gid) function test_inspace() space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) @@ -98,7 +98,7 @@ end vi = VarInfo() test_base!(vi) - test_base!(empty!(TypedVarInfo(vi))) + test_base!(empty!!(TypedVarInfo(vi))) end @testset "flags" begin # Test flag setting: @@ -109,20 +109,20 @@ r = rand(dist) gid = Selector() - push!(vi, vn_x, r, dist, gid) + push!!(vi, vn_x, r, dist, gid) # del is set by default @test !is_flagged(vi, vn_x, "del") - set_flag!(vi, vn_x, "del") + set_flag!!(vi, vn_x, "del") @test is_flagged(vi, vn_x, "del") - unset_flag!(vi, vn_x, "del") + unset_flag!!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end vi = VarInfo() test_varinfo!(vi) - test_varinfo!(empty!(TypedVarInfo(vi))) + test_varinfo!(empty!!(TypedVarInfo(vi))) end @testset "setgid!" begin vi = VarInfo() @@ -133,16 +133,16 @@ gid1 = Selector() gid2 = Selector(2, :HMC) - push!(vi, vn, r, dist, gid1) + push!!(vi, vn, r, dist, gid1) @test meta.gids[meta.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) + setgid!!(vi, gid2, vn) @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) - vi = empty!(TypedVarInfo(vi)) + vi = empty!!(TypedVarInfo(vi)) meta = vi.metadata - push!(vi, vn, r, dist, gid1) + push!!(vi, vn, r, dist, gid1) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) - setgid!(vi, gid2, vn) + setgid!!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end @testset "setval! & setval_and_resample!" begin From cb1fd8bd2f6da1a776082add522e0aa4a8d76b90 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 14:51:18 +0100 Subject: [PATCH 088/216] remove SimpleVarInfo from this branch --- src/DynamicPPL.jl | 2 - src/simple_varinfo.jl | 131 ------------------------------------------ 2 files changed, 133 deletions(-) delete mode 100644 src/simple_varinfo.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 88a5ad89f..bda0b897b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -36,7 +36,6 @@ export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, - SimpleVarInfo, push!!, empty!!, getlogp, @@ -138,7 +137,6 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") -include("simple_varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl deleted file mode 100644 index 4e848e291..000000000 --- a/src/simple_varinfo.jl +++ /dev/null @@ -1,131 +0,0 @@ -""" - SimpleVarInfo{NT,T} <: AbstractVarInfo - -A simple wrapper of the parameters with a `logp` field for -accumulation of the logdensity. - -Currently only implemented for `NT <: NamedTuple`. - -## Notes -The major differences between this and `TypedVarInfo` are: -1. `SimpleVarInfo` does not require linearization. -2. `SimpleVarInfo` can use more efficient bijectors. -3. `SimpleVarInfo` only supports evaluation. -""" -struct SimpleVarInfo{NT,T} <: AbstractVarInfo - θ::NT - logp::T -end - -SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) -SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) - -getlogp(vi::SimpleVarInfo) = vi.logp -setlogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) - -function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] = logp - return vi -end - -function acclogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) - vi.logp[] += logp - return vi -end - -function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} - # Use `getproperty` instead of `getfield` - value = getproperty(nt, sym) - return _getindex(value, inds) -end - -function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} - return _getvalue(vi.θ, Val{sym}(), vn.indexing) -end -# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than -# just `Vector`. -getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) -# To disambiguiate. -getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns) - -haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn)) - -istrans(::SimpleVarInfo, vn::VarName) = false - -getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ -getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ -# TODO: Should we do better? -getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ -getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) -getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) -# HACK: Need to disambiguiate. -getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) - -# Context implementations -# Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. -function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) - left = vi[vn] - return left, Distributions.loglikelihood(dist, left) -end - -# function dot_tilde_assume!(context, right, left, vn, inds, vi::SimpleVarInfo) -# throw(MethodError(dot_tilde_assume!, (context, right, left, vn, inds, vi))) -# end - -function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi::SimpleVarInfo, -) - @assert length(dist) == size(var, 1) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] - lp = sum(zip(vns, eachcol(r))) do vn, ri - return Distributions.logpdf(dist, ri) - end - return r, lp -end - -function dot_assume( - dists::Union{Distribution,AbstractArray{<:Distribution}}, - var::AbstractArray, - vns::AbstractArray{<:VarName}, - vi::SimpleVarInfo{<:NamedTuple}, -) - # NOTE: We cannot work with `var` here because we might have a model of the form - # - # m = Vector{Float64}(undef, n) - # m .~ Normal() - # - # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] - lp = sum(Distributions.logpdf.(dists, r)) - return r, lp -end - -# HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. -increment_num_produce!(::SimpleVarInfo) = nothing - -# Interaction with `VarInfo` -SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) -function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} - vals = map(names) do n - let md = getfield(vi.metadata, n) - x = map(enumerate(md.ranges)) do (i, r) - reconstruct(md.dists[i], md.vals[r]) - end - - # TODO: Doesn't support batches of `MultivariateDistribution`? - length(x) == 1 ? x[1] : x - end - end - - return SimpleVarInfo{T}(NamedTuple{names}(vals)) -end From 5936dd059f0dc5815832c2b8a6c5a421a97c994e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 15:51:54 +0100 Subject: [PATCH 089/216] added a comment --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 64a958644..278d7a5ad 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -413,6 +413,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, inds, vi) value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) + # Mutation of `value` no longer occurs in main body, so we do it here. left .= value return value, acclogp!!(vi, logp) end From 426c465d4206e2f6e8403f6aa602262f8663c815 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 15:51:59 +0100 Subject: [PATCH 090/216] reverted submodel macro to use = rather than ~ --- src/submodel_macro.jl | 10 +++++----- src/utils.jl | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 25460eada..23b6245ec 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,6 +1,6 @@ """ - @submodel x ~ model(args...) - @submodel prefix x ~ model(args...) + @submodel x = model(args...) + @submodel prefix x = model(args...) Treats `model` as a distribution, where `x` is the return-value of `model`. @@ -18,8 +18,8 @@ macro submodel(prefix, expr) end function submodel(expr, ctx=esc(:__context__)) - args_tilde = getargs_tilde(expr) - return if args_tilde === nothing + args_assign = getargs_assignment(expr) + return if args_assign === nothing # In this case we only want to get the `__varinfo__`. quote $(esc(:_)), $(esc(:__varinfo__)) = _evaluate( @@ -29,7 +29,7 @@ function submodel(expr, ctx=esc(:__context__)) else # Here we also want the return-variable. # TODO: Should we prefix by `L` by default? - L, R = args_tilde + L, R = args_assign quote $(esc(L)), $(esc(:__varinfo__)) = _evaluate( $(esc(R)), $(esc(:__varinfo__)), $(ctx) diff --git a/src/utils.jl b/src/utils.jl index 14b7650fb..76efe2298 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,6 +44,21 @@ function getargs_tilde(expr::Expr) end end +""" + getargs_assignment(x) + +Return the arguments `L` and `R`, if `x` is an expression of the form `L = R`, or `nothing` +otherwise. +""" +getargs_assignment(x) = nothing +function getargs_assignment(expr::Expr) + return MacroTools.@match expr begin + (L_ = R_) => (L, R) + x_ => nothing + end +end + + ############################################ # Julia 1.2 temporary fix - Julia PR 33303 # ############################################ From a8e55bd0a6cb0798b980dbb36be9f3b263d3875d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:10:20 +0100 Subject: [PATCH 091/216] updated SimpleVarInfo impl --- src/DynamicPPL.jl | 1 + src/simple_varinfo.jl | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a26516283..88a5ad89f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -138,6 +138,7 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") +include("simple_varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") include("compiler.jl") diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4e848e291..501aa2185 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -21,15 +21,15 @@ SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) getlogp(vi::SimpleVarInfo) = vi.logp -setlogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) +setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) +acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) -function setlogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) +function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp return vi end -function acclogp!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) +function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] += logp return vi end @@ -69,8 +69,8 @@ function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple} return left, Distributions.loglikelihood(dist, left) end -# function dot_tilde_assume!(context, right, left, vn, inds, vi::SimpleVarInfo) -# throw(MethodError(dot_tilde_assume!, (context, right, left, vn, inds, vi))) +# function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) +# throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi))) # end function dot_assume( From 149229f0dacd265e46691ee9c8c3f18af867712c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:21:00 +0100 Subject: [PATCH 092/216] added a couple of missing deprecations --- src/DynamicPPL.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bda0b897b..62bf73b83 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -162,6 +162,10 @@ include("submodel_macro.jl") @deprecate set_flag!(vi, vn, flag) set_flag!!(vi, vn, flag) @deprecate unset_flag!(vi, vn, flag) unset_flag!!(vi, vn, flag) +@deprecate settrans!(vi, trans, vn) settrans!!(vi, trans, vn) + +@deprecate setall!(vi, val) setall!!(vi, val) + @deprecate setgid!(vi, gid, vn) setgid!!(vi, gid, vn) @deprecate updategid!(vi, vn, spl) updategid!!(vi, vn, spl) From 809d23fbf58bbc9d08fafea9d28bccbe195afac0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:21:35 +0100 Subject: [PATCH 093/216] updated tests --- test/compiler.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 703027fda..2b9a2273b 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -364,8 +364,8 @@ end end @model function demo_useval(x, y) - @submodel sub1 x1 ~ demo_return(x) - @submodel sub2 x2 ~ demo_return(y) + @submodel sub1 x1 = demo_return(x) + @submodel sub2 x2 = demo_return(y) return z ~ Normal(x1 + x2 + 100, 1.0) end @@ -399,7 +399,7 @@ end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - @submodel $(Symbol("ar1_$i")) x ~ AR1(num_steps, α, μ, σ) + @submodel $(Symbol("ar1_$i")) x = AR1(num_steps, α, μ, σ) y[i] ~ MvNormal(x, 0.1) end end From 07f684b0e9152172915d9dcf2226541c8ebaa5ce Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:47:08 +0100 Subject: [PATCH 094/216] updated implementations of logjoint and others --- src/model.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/model.jl b/src/model.jl index 929f420ff..e3c83528a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -204,8 +204,8 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, DefaultContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, DefaultContext()) + return getlogp(varinfo_new) end """ @@ -216,8 +216,8 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, PriorContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, PriorContext()) + return getlogp(varinfo_new) end """ @@ -228,8 +228,8 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, LikelihoodContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, LikelihoodContext()) + return getlogp(varinfo_new) end """ From b00ae474c19e1e1d9c3580576e4b17db803f3174 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 30 Jun 2021 16:48:33 +0100 Subject: [PATCH 095/216] formatting --- src/DynamicPPL.jl | 20 +++++++++++++------- src/utils.jl | 1 - 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 62bf73b83..389447344 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -27,9 +27,7 @@ import Base: keys, haskey -import BangBang: - push!!, - empty!! +import BangBang: push!!, empty!! # VarInfo export AbstractVarInfo, @@ -147,10 +145,18 @@ include("submodel_macro.jl") # Deprecations @deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo) -@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) -@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) -@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) -@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution +) +@deprecate push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler +) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector +) +@deprecate push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector} +) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) @deprecate setlogp!(vi, logp) setlogp!!(vi, logp) @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) diff --git a/src/utils.jl b/src/utils.jl index 76efe2298..de1281ac8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -58,7 +58,6 @@ function getargs_assignment(expr::Expr) end end - ############################################ # Julia 1.2 temporary fix - Julia PR 33303 # ############################################ From bfd7c789639df0395e33e0c1bf20557c27f60aa1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 2 Jul 2021 03:42:50 +0100 Subject: [PATCH 096/216] added eltype impl for SimpleVarInfo --- src/simple_varinfo.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 501aa2185..d5ca2fc13 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -62,6 +62,11 @@ getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) # HACK: Need to disambiguiate. getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) +# Necessary for `matchingvalue` to work properly. +function Base.eltype(vi::SimpleVarInfo{<:Any, T}, spl::Union{AbstractSampler,SampleFromPrior}) + return T +end + # Context implementations # Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) From acb15eb9b9525eda0b57036f2b6864623d8ab1d3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 2 Jul 2021 03:45:38 +0100 Subject: [PATCH 097/216] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d5ca2fc13..12437844e 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -63,7 +63,9 @@ getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) # Necessary for `matchingvalue` to work properly. -function Base.eltype(vi::SimpleVarInfo{<:Any, T}, spl::Union{AbstractSampler,SampleFromPrior}) +function Base.eltype( + vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} +) return T end From 4828aab2f3ee108f286b461a057f8909c9dfbc4e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Jul 2021 10:56:52 +0100 Subject: [PATCH 098/216] fixed eltype for SimpleVarInfo --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 12437844e..c88bf0192 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -65,7 +65,7 @@ getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) # Necessary for `matchingvalue` to work properly. function Base.eltype( vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} -) +) where {T} return T end @@ -136,3 +136,5 @@ function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names return SimpleVarInfo{T}(NamedTuple{names}(vals)) end + +SimpleVarInfo(model::Model, args...) = SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) From 167976f64e1134043a98bc74ebbe7692285c01c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 22:49:20 +0100 Subject: [PATCH 099/216] implement setindex!! in prep for allowing sampling with immutable vi --- src/DynamicPPL.jl | 2 +- src/context_implementations.jl | 24 ++++++++++++++++-------- src/varinfo.jl | 1 + 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 389447344..77d271923 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -27,7 +27,7 @@ import Base: keys, haskey -import BangBang: push!!, empty!! +import BangBang: push!!, empty!!, setindex!! # VarInfo export AbstractVarInfo, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 278d7a5ad..44fc74d4f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -44,8 +44,10 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!!(vi, false, vn) + vi = setindex!!( + vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn + ) + vi = settrans!!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, inds, vi) end @@ -59,8 +61,10 @@ function tilde_assume( vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!!(vi, false, vn) + vi = setindex!!( + vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn + ) + vi = settrans!!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end @@ -73,8 +77,10 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!!(vi, false, vn) + vi = setindex!!( + vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn + ) + vi = settrans!!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, inds, vi) end @@ -88,8 +94,10 @@ function tilde_assume( vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!!(vi, false, vn) + vi = setindex!!( + vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn + ) + vi = settrans!!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 1f81f692d..a755cc747 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -955,6 +955,7 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ setindex!(vi::AbstractVarInfo, val, vn::VarName) = setval!(vi, val, vn) +setindex!!(vi::AbstractVarInfo, val, vn::VarName) = setindex!(vi, val, vn) """ setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) From e4f0ad263a862b3157df20fc39e795e88c860a3f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 22:50:16 +0100 Subject: [PATCH 100/216] formatting --- src/simple_varinfo.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c88bf0192..147865cc0 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -137,4 +137,6 @@ function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names return SimpleVarInfo{T}(NamedTuple{names}(vals)) end -SimpleVarInfo(model::Model, args...) = SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) +function SimpleVarInfo(model::Model, args...) + return SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) +end From ccfd112d3b55d67e8c886bb435358d5b2b3f38dd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 22:50:58 +0100 Subject: [PATCH 101/216] initial work on allowing sampling using SimpleVarInfo --- Project.toml | 1 + src/simple_varinfo.jl | 56 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 7cc3db31d..cf8dfab90 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 147865cc0..859ef0540 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,3 +1,5 @@ +using Setfield + """ SimpleVarInfo{NT,T} <: AbstractVarInfo @@ -19,6 +21,8 @@ end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) +SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(nothing) +SimpleVarInfo() = SimpleVarInfo{Float64}() getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) @@ -69,17 +73,48 @@ function Base.eltype( return T end +function push!!(vi::SimpleVarInfo{Nothing}, vn::VarName{sym, Tuple{}}, value, dist::Distribution) where {sym} + @set vi.θ = NamedTuple{(sym, )}((value, )) +end +function push!!(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym, Tuple{}}, value, dist::Distribution) where {sym} + @set vi.θ = merge(vi.θ, NamedTuple{(sym, )}((value, ))) +end + # Context implementations -# Only evaluation makes sense for `SimpleVarInfo`, so we only implement this. -function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo{<:NamedTuple}) +function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo) + value, logp, vi_new = tilde_assume(context, right, vn, inds, vi) + return value, acclogp!!(vi_new, logp) +end + +function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo) left = vi[vn] - return left, Distributions.loglikelihood(dist, left) + return left, Distributions.loglikelihood(dist, left), vi +end + +function assume( + rng::Random.AbstractRNG, + sampler::SampleFromPrior, + dist::Distribution, + vn::VarName, + vi::SimpleVarInfo +) + value = init(rng, dist, sampler) + vi = push!!(vi, vn, value, dist, sampler) + vi = settrans!!(vi, false, vn) + return value, Distributions.loglikelihood(dist, value), vi end # function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) # throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi))) # end +function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) + value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi) + # Mutation of `value` no longer occurs in main body, so we do it here. + left .= value + return value, acclogp!!(vi_new, logp) +end + function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, @@ -93,11 +128,11 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] - lp = sum(zip(vns, eachcol(r))) do vn, ri - return Distributions.logpdf(dist, ri) + value = vi[vns] + lp = sum(zip(vns, eachcol(value))) do vn, val + return Distributions.logpdf(dist, val) end - return r, lp + return value, lp, vi end function dot_assume( @@ -112,13 +147,14 @@ function dot_assume( # m .~ Normal() # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. - r = vi[vns] - lp = sum(Distributions.logpdf.(dists, r)) - return r, lp + value = vi[vns] + lp = sum(Distributions.logpdf.(dists, value)) + return value, lp, vi end # HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing +settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi # Interaction with `VarInfo` SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) From d660433c4a8620a118d19a65caa192ff7baeebbb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 9 Jul 2021 23:22:48 +0100 Subject: [PATCH 102/216] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/simple_varinfo.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 859ef0540..2796a015c 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -73,11 +73,15 @@ function Base.eltype( return T end -function push!!(vi::SimpleVarInfo{Nothing}, vn::VarName{sym, Tuple{}}, value, dist::Distribution) where {sym} - @set vi.θ = NamedTuple{(sym, )}((value, )) +function push!!( + vi::SimpleVarInfo{Nothing}, vn::VarName{sym,Tuple{}}, value, dist::Distribution +) where {sym} + @set vi.θ = NamedTuple{(sym,)}((value,)) end -function push!!(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym, Tuple{}}, value, dist::Distribution) where {sym} - @set vi.θ = merge(vi.θ, NamedTuple{(sym, )}((value, ))) +function push!!( + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution +) where {sym} + @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end # Context implementations @@ -96,7 +100,7 @@ function assume( sampler::SampleFromPrior, dist::Distribution, vn::VarName, - vi::SimpleVarInfo + vi::SimpleVarInfo, ) value = init(rng, dist, sampler) vi = push!!(vi, vn, value, dist, sampler) From 90cf754b3a51662b84ebef33a9b277218c0597e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 08:06:33 +0100 Subject: [PATCH 103/216] add constructor for SimpleVarInfo using model --- src/simple_varinfo.jl | 44 ++++++++++++++++++++++--------------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2796a015c..87d9b0516 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -24,6 +24,29 @@ SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(nothing) SimpleVarInfo() = SimpleVarInfo{Float64}() +# Interaction with `VarInfo` +SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) +function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} + vals = map(names) do n + let md = getfield(vi.metadata, n) + x = map(enumerate(md.ranges)) do (i, r) + reconstruct(md.dists[i], md.vals[r]) + end + + # TODO: Doesn't support batches of `MultivariateDistribution`? + length(x) == 1 ? x[1] : x + end + end + + return SimpleVarInfo{T}(NamedTuple{names}(vals)) +end + +SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) +function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} + _, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...) + return svi +end + getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) @@ -159,24 +182,3 @@ end # HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi - -# Interaction with `VarInfo` -SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) -function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} - vals = map(names) do n - let md = getfield(vi.metadata, n) - x = map(enumerate(md.ranges)) do (i, r) - reconstruct(md.dists[i], md.vals[r]) - end - - # TODO: Doesn't support batches of `MultivariateDistribution`? - length(x) == 1 ? x[1] : x - end - end - - return SimpleVarInfo{T}(NamedTuple{names}(vals)) -end - -function SimpleVarInfo(model::Model, args...) - return SimpleVarInfo(VarInfo(Random.GLOBAL_RNG, model, args...)) -end From 0ab9d8b40a39c83a2f8ed473ac45306f22775ddd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 08:08:39 +0100 Subject: [PATCH 104/216] improved leftover to_namedtuple_expr, fixing a bug when used with Zygote --- src/utils.jl | 37 +++++-------------------------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index e77a4ecdd..db7faabbd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,39 +44,12 @@ function getargs_tilde(expr::Expr) end end -############################################ -# Julia 1.2 temporary fix - Julia PR 33303 # -############################################ function to_namedtuple_expr(syms, vals=syms) - if length(syms) == 0 - nt = :(NamedTuple()) - else - nt_type = Expr( - :curly, - :NamedTuple, - Expr(:tuple, QuoteNode.(syms)...), - Expr(:curly, :Tuple, [:(Core.Typeof($x)) for x in vals]...), - ) - nt = Expr(:call, :($(DynamicPPL.namedtuple)), nt_type, Expr(:tuple, vals...)) - end - return nt -end - -if VERSION == v"1.2" - @eval function namedtuple( - ::Type{NamedTuple{names,T}}, args::Tuple - ) where {names,T<:Tuple} - if length(args) != length(names) - throw(ArgumentError("Wrong number of arguments to named tuple constructor.")) - end - # Note T(args) might not return something of type T; e.g. - # Tuple{Type{Float64}}((Float64,)) returns a Tuple{DataType} - return $(Expr(:splatnew, :(NamedTuple{names,T}), :(T(args)))) - end -else - function namedtuple(::Type{NamedTuple{names,T}}, args::Tuple) where {names,T<:Tuple} - return NamedTuple{names,T}(args) - end + length(syms) == 0 && return :(NamedTuple()) + + names_expr = Expr(:tuple, QuoteNode.(syms)...) + vals_expr = Expr(:tuple, vals...) + return :(NamedTuple{$names_expr}($vals_expr)) end ##################################################### From 42ad5524d3a29f650e405079c1d8921d4845b536 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 08:09:22 +0100 Subject: [PATCH 105/216] bumped patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e17343312..e2c1dd29e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.12.2" +version = "0.12.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 0f1def98e2205dea9ca49925db3591b505f04985 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 20 Jul 2021 22:16:38 +0100 Subject: [PATCH 106/216] fixed set_flag!! --- src/varinfo.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 527e7a693..007e3cf3d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -515,7 +515,8 @@ end Set `vn`'s value for `flag` to `true` in `vi`. """ function set_flag!!(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true + getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true + return vi end #### From 53596fb29b173a2387ab098dfaed84d99851900b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 23 Jul 2021 16:20:03 +0100 Subject: [PATCH 107/216] forgot the return in the replace_returns --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index a1aafe4d0..a996365c2 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -405,7 +405,7 @@ function replace_returns(e::Expr) e.args[1] end - return :($(DynamicPPL.return_values)($retval_expr, __varinfo__)) + return :(return $(DynamicPPL.return_values)($retval_expr, __varinfo__)) end return Expr(e.head, map(replace_returns, e.args)...) From 57b5d47635b52fb40a64ca230588f980cbbe4b1e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Aug 2021 02:08:46 +0100 Subject: [PATCH 108/216] bigboy update to benchmarks --- benchmarks/Project.toml | 4 + benchmarks/benchmark_body.jmd | 19 ++- benchmarks/benchmarks.jmd | 104 ++++++++++++-- benchmarks/src/DynamicPPLBenchmarks.jl | 4 + benchmarks/src/tables.jl | 187 +++++++++++++++++++++++++ 5 files changed, 302 insertions(+), 16 deletions(-) create mode 100644 benchmarks/src/tables.jl diff --git a/benchmarks/Project.toml b/benchmarks/Project.toml index a8e8f09a2..84e582e43 100644 --- a/benchmarks/Project.toml +++ b/benchmarks/Project.toml @@ -3,11 +3,15 @@ uuid = "d94a1522-c11e-44a7-981a-42bf5dc1a001" version = "0.1.0" [deps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" DiffUtils = "8294860b-85a6-42f8-8c35-d911f667b5f6" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DrWatson = "634d3b9d-ee7a-5ddf-bec9-22491ea816e1" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LibGit2 = "76f85450-5226-5b5a-8eaa-529ad045b433" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Weave = "44d3d7a6-8a23-5bf8-98c5-b353f8df5ec9" diff --git a/benchmarks/benchmark_body.jmd b/benchmarks/benchmark_body.jmd index f9c994dc9..9d140ad82 100644 --- a/benchmarks/benchmark_body.jmd +++ b/benchmarks/benchmark_body.jmd @@ -1,22 +1,29 @@ ```julia -@time model_def(data)(); +@time model_def(data...)(); ``` ```julia -m = time_model_def(model_def, data); +m = time_model_def(model_def, data...); ``` ```julia suite = make_suite(m); -results = run(suite) -results +results = run(suite, seconds=WEAVE_ARGS[:seconds]); +``` + +```julia; displaysize=(100, 300) +results["evaluation_untyped"] +``` + +```julia; displaysize=(100, 300) +results["evaluation_typed"] ``` ```julia; echo=false; results="hidden"; BenchmarkTools.save(joinpath("results", WEAVE_ARGS[:name], "$(m.name)_benchmarks.json"), results) ``` -```julia; wrap=false +```julia; wrap=false; echo=false if WEAVE_ARGS[:include_typed_code] typed = typed_code(m) end @@ -30,7 +37,7 @@ end ``` ```julia; wrap=false; echo=false; -if haskey(WEAVE_ARGS, :name_old) +if WEAVE_ARGS[:include_typed_code] && haskey(WEAVE_ARGS, :name_old) # We want to compare the generated code to the previous version. import DiffUtils typed_old = deserialize(joinpath("results", WEAVE_ARGS[:name_old], "$(m.name).jls")); diff --git a/benchmarks/benchmarks.jmd b/benchmarks/benchmarks.jmd index 614afb2e9..448e6d68a 100644 --- a/benchmarks/benchmarks.jmd +++ b/benchmarks/benchmarks.jmd @@ -1,18 +1,19 @@ -# Benchmarks +`j display("text/markdown", "## $(WEAVE_ARGS[:name]) ##")` -## Setup +### Setup ### ```julia using BenchmarkTools, DynamicPPL, Distributions, Serialization ``` ```julia -import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child +using DynamicPPLBenchmarks +using DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child ``` -## Models +### Models ### -### `demo1` +#### `demo1` #### ```julia @model function demo1(x) @@ -23,14 +24,14 @@ import DynamicPPLBenchmarks: time_model_def, make_suite, typed_code, weave_child end model_def = demo1; -data = 1.0; +data = (1.0, ); ``` ```julia; results="markup"; echo=false weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) ``` -### `demo2` +#### `demo2` #### ```julia @model function demo2(y) @@ -46,14 +47,14 @@ weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) end model_def = demo2; -data = rand(0:1, 10); +data = (rand(0:1, 10), ); ``` ```julia; results="markup"; echo=false weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) ``` -### `demo3` +#### `demo3` #### ```julia @model function demo3(x) @@ -88,9 +89,92 @@ N = 30 μs = [-3.5, 0.0] # Construct the data points. -data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); +data = (mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2), ); ``` ```julia; echo=false weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) ``` + +#### `demo4`: lots of variables + +```julia +@model function demo4_1k(::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, 1_000) + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end +end + +model_def = demo4_1k +data = (); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + +```julia +@model function demo4_10k(::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, 10_000) + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end +end + +model_def = demo4_10k +data = (); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + +```julia +@model function demo4_100k(::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, 100_000) + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end +end + +model_def = demo4_100k +data = (); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + + +#### `demo4_dotted`: `.~` for large number of variables + +```julia +@model function demo4_100k_dotted(::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, 100_000) + x .~ Normal(m, 1.0) +end + +model_def = demo4_100k_dotted +data = (); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + +```julia; echo=false +if haskey(WEAVE_ARGS, :name_old) + display(MIME"text/markdown"(), "## Comparison with $(WEAVE_ARGS[:name_old]) ##") +end +``` + +```julia; echo=false; displaysize=(30, 200) +if haskey(WEAVE_ARGS, :name_old) + DynamicPPLBenchmarks.judgementtable(WEAVE_ARGS[:name], WEAVE_ARGS[:name_old]) +end +``` diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 362a8940f..2cb55ea8b 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -151,6 +151,7 @@ function weave_benchmarks( name=default_name(; include_commit_id=include_commit_id), name_old=nothing, include_typed_code=false, + seconds=10, doctype="github", outpath="results/$(name)/", kwargs..., @@ -159,6 +160,7 @@ function weave_benchmarks( :benchmarkbody => benchmarkbody, :name => name, :include_typed_code => include_typed_code, + :seconds => seconds ) if !isnothing(name_old) args[:name_old] = name_old @@ -168,4 +170,6 @@ function weave_benchmarks( return Weave.weave(input, doctype; out_path=outpath, args=args, kwargs...) end +include("tables.jl") + end # module diff --git a/benchmarks/src/tables.jl b/benchmarks/src/tables.jl new file mode 100644 index 000000000..ba9f8cc0c --- /dev/null +++ b/benchmarks/src/tables.jl @@ -0,0 +1,187 @@ +using BenchmarkTools, Tables, PrettyTables, DrWatson + +####################### +# `TrialJudgementRow` # +####################### +struct TrialJudgementRow{names,Textras} <: Tables.AbstractRow + group::String + judgement::BenchmarkTools.TrialJudgement + extras::NamedTuple{names,Textras} +end + +function TrialJudgementRow(group::String, judgement::BenchmarkTools.TrialJudgement) + return TrialJudgementRow(group, judgement, NamedTuple()) +end + +function Tables.columnnames(::Type{TrialJudgementRow}) + return ( + :group, + :time, + :time_judgement, + :gctime, + :memory, + :memory_judgement, + :allocs, + :time_tolerance, + :memory_tolerance + ) +end +# Dispatch needs to include all type-parameters because Tables.jl is a bit too aggressive +# when it comes to overloading this. +function Tables.columnnames(::Type{TrialJudgementRow{names,Textras}}) where {names,Textras} + return (Tables.columnnames(TrialJudgementRow)..., names...) +end +Tables.columnnames(row::TrialJudgementRow) = Tables.columnnames(typeof(row)) +Tables.getcolumn(row::TrialJudgementRow, i::Int) = Tables.getcolumn(row, Tables.columnnames(row)[i]) +function Tables.getcolumn(row::TrialJudgementRow, name::Symbol) + # NOTE: We need to use `getfield` below because `getproperty` is overloaded by Tables.jl + # and so we'll get a `StackOverflowError` if we try to do something like `row.group`. + return if name === :group + getfield(row, name) + elseif name in Tables.columnnames(TrialJudgementRow) + # `name` is one of the default columns + j = getfield(row, :judgement) + if name === :time_judgement + j.time + elseif name === :memory_judgement + j.memory + elseif name === :time_tolerance + params(j).time_tolerance + elseif name === :memory_tolerance + params(j).memory_tolerance + else + # Defer the rest to the `TrialRatio`. + r = j.ratio + getfield(r, name) + end + else + # One of `row.extras` + extras = getfield(row, :extras) + getfield(extras, name) + end +end + +Tables.istable(rows::Vector{<:TrialJudgementRow}) = true + +Tables.rows(rows::Vector{<:TrialJudgementRow}) = rows +Tables.rowaccess(rows::Vector{<:TrialJudgementRow}) = true + +# Because DataFrames.jl doesn't respect the `columnaccess`: +# https://github.com/JuliaData/DataFrames.jl/blob/2b9f6673547259bab9fb3bf3b5224eebc7b11ecd/src/other/tables.jl#L48-L61. +Tables.columnaccess(rows::Vector{<:TrialJudgementRow}) = true +function Tables.columns(rows::Vector{<:TrialJudgementRow}) + return (; ((name, getproperty.(rows, name)) for name in Tables.columnnames(eltype(rows)))...) +end + +######################### +# PrettyTables.jl usage # +######################### +function make_highlighter_judgement(isgood) + function highlighter_judgement(data::Vector{<:TrialJudgementRow}, i, j) + names = Tables.columnnames(eltype(data)) + name = names[j] + row = data[i] + x = row[j] + + if name === :time || name === :time_judgement + j = row[:time_judgement] + if j === :improvement + return isgood + elseif j === :regression + return !isgood + end + elseif name === :memory || name === :memory_judgement + j = row[:memory_judgement] + if j === :improvement + return isgood + elseif j === :regression + return !isgood + end + end + + return false + end + + return highlighter_judgement +end + +function make_formatter(data::Vector{<:TrialJudgementRow}) + names = Tables.columnnames(eltype(data)) + function formatter_judgement(x, i, j) + name = names[j] + + if name in (:time, :memory, :allocs, :gctime) + return BenchmarkTools.prettydiff(x) + elseif name in (:time_tolerance, :memory_tolerance) + return BenchmarkTools.prettypercent(x) + end + + return x + end + + return formatter_judgement +end + +function Base.show(io::IO, ::MIME"text/plain", rows::Vector{<:TrialJudgementRow}) + hgood = Highlighter(make_highlighter_judgement(true), foreground=:green, bold=true) + hbad = Highlighter(make_highlighter_judgement(false), foreground=:red, bold=true) + formatter = make_formatter(rows) + pretty_table( + io, rows, + highlighters=(hgood, hbad), + formatters=(formatter, ) + ) +end + +function Base.show(io::IO, ::MIME"text/html", rows::Vector{<:TrialJudgementRow}) + hgood = HTMLHighlighter(make_highlighter_judgement(true), HTMLDecoration(color="green", font_weight="bold")) + hbad = HTMLHighlighter(make_highlighter_judgement(false), HTMLDecoration(color="red", font_weight="bold")) + formatter = make_formatter(rows) + pretty_table( + io, rows, + backend=Val(:html), + highlighters=(hgood, hbad), + formatters=(formatter, ), + tf=PrettyTables.tf_html_minimalist + ) +end + +######################################################### +# Make it more convenient to load benchmarks into table # +######################################################### +function judgementtable( + results::AbstractVector, results_old::AbstractVector, + extras=fill(NamedTuple(), length(results)); + stat = minimum +) + @assert length(results_old) == length(results) "benchmarks have different lengths" + + return collect( + TrialJudgementRow( + groupname, + judge(stat(results[i][groupname]), stat(results_old[i][groupname])), + extras[i] + ) + for i in eachindex(results) + for groupname in keys(results[i]) + ) +end + +function judgementtable(name::String, name_old::String; kwargs...) + model_names = map(filter(endswith("_benchmarks.json"), readdir(projectdir("results", name)))) do x + # Strip the suffix. + x[1:end-5] + end + + results = [] + results_old = [] + for model_name in model_names + append!(results, BenchmarkTools.load(projectdir("results", name, "$(model_name).json"))) + append!(results_old, BenchmarkTools.load(projectdir("results", name_old, "$(model_name).json"))) + end + + extras = [(model_name = model_name, ) for model_name in model_names] + + return judgementtable(results, results_old, extras; kwargs...) +end + From d0a08f694752ccc9fd55cd2cb9f9f53c1ce6c96c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Aug 2021 04:13:39 +0100 Subject: [PATCH 109/216] fixed some issues and added support for usage of Dict in SimpleVarInfo --- src/simple_varinfo.jl | 67 ++++++++++++++++++++++++++++--------------- src/varinfo.jl | 20 +++++++++++++ 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 87d9b0516..879b40a65 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -21,32 +21,27 @@ end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(nothing) +SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(NamedTuple()) SimpleVarInfo() = SimpleVarInfo{Float64}() -# Interaction with `VarInfo` -SimpleVarInfo(vi::TypedVarInfo) = SimpleVarInfo{eltype(getlogp(vi))}(vi) -function SimpleVarInfo{T}(vi::VarInfo{<:NamedTuple{names}}) where {T<:Real,names} - vals = map(names) do n - let md = getfield(vi.metadata, n) - x = map(enumerate(md.ranges)) do (i, r) - reconstruct(md.dists[i], md.vals[r]) - end - - # TODO: Doesn't support batches of `MultivariateDistribution`? - length(x) == 1 ? x[1] : x - end - end - - return SimpleVarInfo{T}(NamedTuple{names}(vals)) -end - +# Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} _, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...) return svi end +# Constructor from `VarInfo`. +function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} + return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...) +end +function SimpleVarInfo{T}( + vi::VarInfo{<:NamedTuple{names}}, ::Type{D} +) where {T<:Real,names,D} + values = values_as(vi, D) + return SimpleVarInfo{T}(values) +end + getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) @@ -67,9 +62,28 @@ function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} return _getindex(value, inds) end +# `NamedTuple` function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} - return _getvalue(vi.θ, Val{sym}(), vn.indexing) + # If `sym` is found in `vi.θ` we assume it will be of correct + # shape to support `getindex` for `vn.indexing`. + # If `sym` is NOT found in `vi.θ`, we try `Symbol(vn)`. + # This means that we support both the following cases: + # 1. `x[1]` has been provided by the user and can be assumed to be + # of shape that allows us to call `_getvalue` on it. + # 2. `x[1]` was not provided by the user, e.g. possibly obtained by + # sampling with a `SimpleVarInfo` which then produced the key `var"x[1]"`. + return if haskey(vi.θ, sym) + _getvalue(vi.θ, Val{sym}(), vn.indexing) + else + getproperty(vi.θ, Symbol(vn)) + end end + +# `Dict` +function getval(vi::SimpleVarInfo{<:Dict}, vn::VarName) + return vi.θ[vn] +end + # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than # just `Vector`. getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) @@ -96,15 +110,22 @@ function Base.eltype( return T end +# `NamedTuple` function push!!( - vi::SimpleVarInfo{Nothing}, vn::VarName{sym,Tuple{}}, value, dist::Distribution + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution, gidset::Set{Selector} ) where {sym} - @set vi.θ = NamedTuple{(sym,)}((value,)) + @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end function push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName, value, dist::Distribution, gidset::Set{Selector} ) where {sym} - @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) + @set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,))) +end + +# `Dict` +function push!!(vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) + vi.θ[vn] = r + return vi end # Context implementations diff --git a/src/varinfo.jl b/src/varinfo.jl index 007e3cf3d..af5cb0622 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1491,3 +1491,23 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, return indices end + +""" + values_as(vi::TypedVarInfo, ::Type{NamedTuple}) + values_as(vi::TypedVarInfo, ::Type{Dict}) + +Return values in `vi` as the specified type, e.g. `NamedTuple` is returned if +""" +function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} + iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) + return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) +end + +function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names} + iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) + return Dict(iter) +end + +function values_from_metadata(md::Metadata) + return (vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for vn in md.vns) +end From ff75ddc0be85532259ace44d86e8f88436195ac5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Aug 2021 04:45:03 +0100 Subject: [PATCH 110/216] added docstring and improved indexing behvaior for SimpleVarInfo --- src/simple_varinfo.jl | 96 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 90 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 879b40a65..c01263ae1 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -6,13 +6,90 @@ using Setfield A simple wrapper of the parameters with a `logp` field for accumulation of the logdensity. -Currently only implemented for `NT <: NamedTuple`. +Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`. -## Notes +# Notes The major differences between this and `TypedVarInfo` are: 1. `SimpleVarInfo` does not require linearization. 2. `SimpleVarInfo` can use more efficient bijectors. -3. `SimpleVarInfo` only supports evaluation. +3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either + a) no indexing is used in tilde-statements, or + b) the values have been specified with the corret shapes. + +# Examples +```jldoctest; setup=:(using Distributions, Random) +julia> @model function demo() + x = Vector{Float64}(undef, 2) + for i in eachindex(x) + x[i] ~ Normal() + end + return x + end +demo (generic function with 1 method) + +julia> m = demo(); + +julia> ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()); + +julia> # Notice how the resulting `vi` has keys `(var"x[1]", var"x[2]")` + # and thus accessing these values will be type-unstable and slower. + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi +SimpleVarInfo{NamedTuple{(Symbol("x[1]"), Symbol("x[2]")), Tuple{Float64, Float64}}, Float64}((x[1] = 0.14447203090358265, x[2] = 0.21780448216717593), -1.8720325464921044) + +julia> # (×) SLOW!!! + DynamicPPL.getval(vi, @varname(x[1])) +0.14447203090358265 + +julia> # In addtion, we can only access varnames as they appear in the model! + DynamicPPL.getval(vi, @varname(x)) +ERROR: type NamedTuple has no field x +[...] + +julia> julia> DynamicPPL.getval(vi, @varname(x[1:2])) +ERROR: type NamedTuple has no field x[1:2] +[...] + +julia> # In contrast, if we provide the container for `x`, the `vi` now only + # has the key `x` and we access parts of it using indices. + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi +SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64}((x = [-0.6538238172778861, 0.10742338922309654],), -2.0573897507053474) + +julia> # (✓) Vroom, vroom! FAST!!! + DynamicPPL.getval(vi, @varname(x[1])) +-0.6538238172778861 + +julia> # We can also access arbitrary varnames pointing to `x`, e.g. + DynamicPPL.getval(vi, @varname(x)) +2-element Vector{Float64}: + -0.6538238172778861 + 0.10742338922309654 + +julia> DynamicPPL.getval(vi, @varname(x[1:2])) +2-element view(::Vector{Float64}, 1:2) with eltype Float64: + -0.6538238172778861 + 0.10742338922309654 + +julia> # The better way to handle sampling of variables involving indexing + # if one does not know the varnames, is to use a `Dict` as the container instead. + # Notice that here the keys are the same as for the `SimpleVarInfo()` scenario, i.e. + # how they appear in the model. + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi +SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.1292246244328437, x[2] => -1.382335836121636), -3.4308773745351453) + +julia> # (✓) Sort of fast, but only possible at runtime. + DynamicPPL.getval(vi, @varname(x[1])) +1.1292246244328437 + +julia> # And as in the `SimpleVarInfo()` case, we cannot access varnames that does + # not directly appear in the model. + DynamicPPL.getval(vi, @varname(x)) +ERROR: KeyError: key x not found +[...] + +julia> julia> DynamicPPL.getval(vi, @varname(x[1:2])) +ERROR: KeyError: key x[1:2] not found +[...] +``` """ struct SimpleVarInfo{NT,T} <: AbstractVarInfo θ::NT @@ -73,7 +150,7 @@ function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} # 2. `x[1]` was not provided by the user, e.g. possibly obtained by # sampling with a `SimpleVarInfo` which then produced the key `var"x[1]"`. return if haskey(vi.θ, sym) - _getvalue(vi.θ, Val{sym}(), vn.indexing) + maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing)) else getproperty(vi.θ, Symbol(vn)) end @@ -117,9 +194,16 @@ function push!!( @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end function push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName, value, dist::Distribution, gidset::Set{Selector} + vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, dist::Distribution, gidset::Set{Selector} ) where {sym} - @set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,))) + # If the key is already there, we try to update in place. + return if haskey(vi.θ, sym) + current = _getvalue(vi.θ, Val{sym}(), vn.indexing) + current .= value + vi + else + @set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,))) + end end # `Dict` From d29dd8f54fe3133e1f7813fff895c8a2fd983f0b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Aug 2021 04:45:36 +0100 Subject: [PATCH 111/216] formatting --- src/simple_varinfo.jl | 16 +++++++++++++--- src/varinfo.jl | 5 ++++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index c01263ae1..0e5d70213 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -189,12 +189,20 @@ end # `NamedTuple` function push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Tuple{}}, value, dist::Distribution, gidset::Set{Selector} + vi::SimpleVarInfo{<:NamedTuple}, + vn::VarName{sym,Tuple{}}, + value, + dist::Distribution, + gidset::Set{Selector}, ) where {sym} @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end function push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, dist::Distribution, gidset::Set{Selector} + vi::SimpleVarInfo{<:NamedTuple}, + vn::VarName{sym}, + value, + dist::Distribution, + gidset::Set{Selector}, ) where {sym} # If the key is already there, we try to update in place. return if haskey(vi.θ, sym) @@ -207,7 +215,9 @@ function push!!( end # `Dict` -function push!!(vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function push!!( + vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector} +) vi.θ[vn] = r return vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index 3628ee199..6b7523fbf 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1511,5 +1511,8 @@ function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names} end function values_from_metadata(md::Metadata) - return (vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for vn in md.vns) + return ( + vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for + vn in md.vns + ) end From a72594f058e2203aab66c39ac532df24eae2bbfb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Aug 2021 05:07:41 +0100 Subject: [PATCH 112/216] dont allow sampling with indexing when using SimpleVarInfo with NamedTuple unless shapes are specified --- src/simple_varinfo.jl | 93 +++++++++++++++++-------------------------- 1 file changed, 36 insertions(+), 57 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 0e5d70213..afc176a5d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -17,8 +17,11 @@ The major differences between this and `TypedVarInfo` are: b) the values have been specified with the corret shapes. # Examples -```jldoctest; setup=:(using Distributions, Random) +```jldoctest; setup=:(using Distributions) +julia> using StableRNGs + julia> @model function demo() + m ~ Normal() x = Vector{Float64}(undef, 2) for i in eachindex(x) x[i] ~ Normal() @@ -29,59 +32,46 @@ demo (generic function with 1 method) julia> m = demo(); -julia> ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()); - -julia> # Notice how the resulting `vi` has keys `(var"x[1]", var"x[2]")` - # and thus accessing these values will be type-unstable and slower. - _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi -SimpleVarInfo{NamedTuple{(Symbol("x[1]"), Symbol("x[2]")), Tuple{Float64, Float64}}, Float64}((x[1] = 0.14447203090358265, x[2] = 0.21780448216717593), -1.8720325464921044) - -julia> # (×) SLOW!!! - DynamicPPL.getval(vi, @varname(x[1])) -0.14447203090358265 - -julia> # In addtion, we can only access varnames as they appear in the model! - DynamicPPL.getval(vi, @varname(x)) -ERROR: type NamedTuple has no field x -[...] +julia> rng = StableRNG(42); -julia> julia> DynamicPPL.getval(vi, @varname(x[1:2])) -ERROR: type NamedTuple has no field x[1:2] -[...] +julia> ### Sampling ### + ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()); -julia> # In contrast, if we provide the container for `x`, the `vi` now only - # has the key `x` and we access parts of it using indices. +julia> # In the `NamedTuple` version we need to provide the place-holder values for + # the variablse which are using "containers", e.g. `Array`. + # In this case, this means that we need to specify `x` but not `m`. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi -SimpleVarInfo{NamedTuple{(:x,), Tuple{Vector{Float64}}}, Float64}((x = [-0.6538238172778861, 0.10742338922309654],), -2.0573897507053474) +SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952), -5.769094411622931) julia> # (✓) Vroom, vroom! FAST!!! DynamicPPL.getval(vi, @varname(x[1])) --0.6538238172778861 +1.6642061055583879 julia> # We can also access arbitrary varnames pointing to `x`, e.g. DynamicPPL.getval(vi, @varname(x)) 2-element Vector{Float64}: - -0.6538238172778861 - 0.10742338922309654 + 1.6642061055583879 + 1.796319600944139 julia> DynamicPPL.getval(vi, @varname(x[1:2])) 2-element view(::Vector{Float64}, 1:2) with eltype Float64: - -0.6538238172778861 - 0.10742338922309654 + 1.6642061055583879 + 1.796319600944139 + +julia> # (×) If we don't provide the container... + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi +ERROR: type NamedTuple has no field x +[...] -julia> # The better way to handle sampling of variables involving indexing - # if one does not know the varnames, is to use a `Dict` as the container instead. - # Notice that here the keys are the same as for the `SimpleVarInfo()` scenario, i.e. - # how they appear in the model. +julia> # If one does not know the varnames, we can use a `Dict` instead. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi -SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.1292246244328437, x[2] => -1.382335836121636), -3.4308773745351453) +SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277, x[2] => 0.4914514300738121, m => 0.25572200616753643), -3.6215377732004237) julia> # (✓) Sort of fast, but only possible at runtime. DynamicPPL.getval(vi, @varname(x[1])) -1.1292246244328437 +1.192696983568277 -julia> # And as in the `SimpleVarInfo()` case, we cannot access varnames that does - # not directly appear in the model. +julia> # In addtion, we can only access varnames as they appear in the model! DynamicPPL.getval(vi, @varname(x)) ERROR: KeyError: key x not found [...] @@ -136,24 +126,15 @@ end function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} # Use `getproperty` instead of `getfield` value = getproperty(nt, sym) + # Note that this will return a `view`, even if the resulting value is 0-dim. + # This makes it possible to call `setindex!` on the result later to update + # in place even in the case where are retrieving a single element, e.g. `x[1]`. return _getindex(value, inds) end # `NamedTuple` -function getval(vi::SimpleVarInfo, vn::VarName{sym}) where {sym} - # If `sym` is found in `vi.θ` we assume it will be of correct - # shape to support `getindex` for `vn.indexing`. - # If `sym` is NOT found in `vi.θ`, we try `Symbol(vn)`. - # This means that we support both the following cases: - # 1. `x[1]` has been provided by the user and can be assumed to be - # of shape that allows us to call `_getvalue` on it. - # 2. `x[1]` was not provided by the user, e.g. possibly obtained by - # sampling with a `SimpleVarInfo` which then produced the key `var"x[1]"`. - return if haskey(vi.θ, sym) - maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing)) - else - getproperty(vi.θ, Symbol(vn)) - end +function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym} + return maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing)) end # `Dict` @@ -204,14 +185,12 @@ function push!!( dist::Distribution, gidset::Set{Selector}, ) where {sym} - # If the key is already there, we try to update in place. - return if haskey(vi.θ, sym) - current = _getvalue(vi.θ, Val{sym}(), vn.indexing) - current .= value - vi - else - @set vi.θ = merge(vi.θ, NamedTuple{(Symbol(vn),)}((value,))) - end + # We update in place. + # We need a view into the array, hence we call `_getvalue` directly + # rather than `getval`. + current = _getvalue(vi.θ, Val{sym}(), vn.indexing) + current .= value + return vi end # `Dict` From be35be0d3bac2f5c1162e7631ba921cc4a4c3c80 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 04:01:07 +0100 Subject: [PATCH 113/216] _setval_kernel and others are only supported by VarInfo atm --- src/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 6b7523fbf..cdfd731e7 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1398,7 +1398,7 @@ function setval!( return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) +function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1479,7 +1479,7 @@ function setval_and_resample!( return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) +function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) From 4d4eeb3fd5ca4ccb0bcc4a608bce96b58964168f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 14 Aug 2021 01:29:43 +0100 Subject: [PATCH 114/216] fixed typo in comment --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index afc176a5d..569c1bc3a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -273,6 +273,6 @@ function dot_assume( return value, lp, vi end -# HACK: Allows us to re-use the impleemntation of `dot_tilde`, etc. for literals. +# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi From b7862a80acf67dcaaf3d3d5445b67759d45f4245 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 14 Aug 2021 02:29:54 +0100 Subject: [PATCH 115/216] added more values_as impls --- src/simple_varinfo.jl | 4 ++++ src/varinfo.jl | 19 ++++++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 569c1bc3a..bf3598162 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -276,3 +276,7 @@ end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi + +values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.θ)) +values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.θ)) +values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.θ diff --git a/src/varinfo.jl b/src/varinfo.jl index cdfd731e7..1ac1a608d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1495,11 +1495,17 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) end """ - values_as(vi::TypedVarInfo, ::Type{NamedTuple}) - values_as(vi::TypedVarInfo, ::Type{Dict}) + values_as(vi::AbstractVarInfo, ::Type{NamedTuple}) + values_as(vi::AbstractVarInfo, ::Type{Dict}) -Return values in `vi` as the specified type, e.g. `NamedTuple` is returned if +Return values in `vi` as the specified type. """ +function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) + iter = values_from_metadata(vi.metadata) + return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) +end +values_as(vi::UntypedVarInfo, ::Type{Dict}) = Dict(values_from_metadata(vi.metadata)) + function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) @@ -1516,3 +1522,10 @@ function values_from_metadata(md::Metadata) vn in md.vns ) end + +function values_from_metadata(md::Metadata) + return ( + vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for + vn in md.vns + ) +end From 39619cdf458ff18839af4c867314c93097eaff36 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 14 Aug 2021 02:57:31 +0100 Subject: [PATCH 116/216] removed redundant values_from_metadata --- src/varinfo.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 1ac1a608d..8894c8cb2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1522,10 +1522,3 @@ function values_from_metadata(md::Metadata) vn in md.vns ) end - -function values_from_metadata(md::Metadata) - return ( - vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for - vn in md.vns - ) -end From e67901baaf1d65e8161b3ddf4f35cfe02292d0c0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 17 Aug 2021 01:58:16 +0100 Subject: [PATCH 117/216] fixed bug in push!! for SimpleVarInfo --- src/simple_varinfo.jl | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index bf3598162..b1f6afae4 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -123,18 +123,29 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) return vi end -function _getvalue(nt::NamedTuple, ::Val{sym}, inds=()) where {sym} +# TODO: Get rid of this once we have lenses. +_getindex_view(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds)) +_getindex_view(x, inds::Tuple{}) = x + +# TODO: Get rid of this once we have lenses. +function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym, Tuple{}}) where {sym} + return merge(nt, NamedTuple{(sym, )}((val, ))) +end +function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym} # Use `getproperty` instead of `getfield` value = getproperty(nt, sym) # Note that this will return a `view`, even if the resulting value is 0-dim. # This makes it possible to call `setindex!` on the result later to update # in place even in the case where are retrieving a single element, e.g. `x[1]`. - return _getindex(value, inds) + dest_view = _getindex_view(value, vn.indexing) + dest_view .= val + + return nt end # `NamedTuple` function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym} - return maybe_unwrap_view(_getvalue(vi.θ, Val{sym}(), vn.indexing)) + return _getvalue(vi.θ, Val{sym}(), vn.indexing) end # `Dict` @@ -188,8 +199,7 @@ function push!!( # We update in place. # We need a view into the array, hence we call `_getvalue` directly # rather than `getval`. - current = _getvalue(vi.θ, Val{sym}(), vn.indexing) - current .= value + _setvalue!!(vi.θ, value, vn) return vi end From d7dad312d201b0223a4258fb638736f6571e55c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 19 Aug 2021 15:51:34 +0100 Subject: [PATCH 118/216] forgot which branch Im on --- Project.toml | 1 + src/DynamicPPL.jl | 1 + src/simple_varinfo.jl | 8 +++----- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index f6991a1e4..6fd73c105 100644 --- a/Project.toml +++ b/Project.toml @@ -22,5 +22,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" +Setfield = "0.7" ZygoteRules = "0.2" julia = "1.3" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b1a70f700..bf7109ad0 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -10,6 +10,7 @@ using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules using BangBang: BangBang +using Setfield: Setfield using Random: Random diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 450d0fdfd..38d05ff85 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -1,5 +1,3 @@ -using Setfield - """ SimpleVarInfo{NT,T} <: AbstractVarInfo @@ -128,7 +126,7 @@ _getindex_view(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(in _getindex_view(x, inds::Tuple{}) = x # TODO: Get rid of this once we have lenses. -function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym, Setfield.IdentityLens}) where {sym} +function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym,Tuple{}}) where {sym} return merge(nt, NamedTuple{(sym, )}((val, ))) end function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym} @@ -182,12 +180,12 @@ end # `NamedTuple` function push!!( vi::SimpleVarInfo{<:NamedTuple}, - vn::VarName{sym,Setfield.IdentityLens}, + vn::VarName{sym,Tuple{}}, value, dist::Distribution, gidset::Set{Selector}, ) where {sym} - @set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) + Setfield.@set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end function push!!( vi::SimpleVarInfo{<:NamedTuple}, From 6ec2d29eebcdec9cd289dfe5d22f2e45440f1fbb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 21 Aug 2021 12:29:54 +0100 Subject: [PATCH 119/216] added handling of short defs in replace_returns and more docstrings --- src/compiler.jl | 69 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7c3ed1f91..a51dba3fd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -464,10 +464,39 @@ function generate_dot_tilde(left, right) end end +""" + isfuncdef(expr) + +Return `true` if `expr` is any form of function definition, and `false` otherwise. +""" +function isfuncdef(e::Expr) + return if Meta.isexpr(e, :function) + # Classic `function f(...)` + true + elseif Meta.isexpr(e, :->) + # Anonymous functions/lambdas, e.g. `do` blocks or `->` defs. + true + elseif Meta.isexpr(e, :=) && Meta.isexpr(e.args[1], :call) + # Short function defs, e.g. `f(args...) = ...`. + true + else + false + end +end + +""" + replace_returns(expr) + +Return `Expr` with all `return ...` statements replaced with +`return ..., DynamicPPL.return_values(__varinfo__)`. + +Note that this method will _not_ replace `return` statements within function +definitions. This is checked using [`isfuncdef`](@ref). +""" replace_returns(e) = e replace_returns(e::Symbol) = e function replace_returns(e::Expr) - if Meta.isexpr(e, :function) || Meta.isexpr(e, :->) + if isfuncdef(e) return e end @@ -487,6 +516,40 @@ function replace_returns(e::Expr) return Expr(e.head, map(replace_returns, e.args)...) end +""" + return_values(retval, varinfo) + +Return `(retval, varinfo)` if `retval` is not a `Tuple` with second +component being a `AbstractVarInfo`. + +Used together with [`replace_returns`](@ref), it handles the following case. + +# Example + +Suppose the following is the return-value: + +```julia +return x ~ Normal() +``` + +Without `return_values`, once expanded in [`generated_mainbody!`](@ref), this would be + +```julia +return (x, __varinfo__ = tilde_assume!!(...)), __varinfo__ +``` + +i.e. the return-value of the model would end up `(x, __varinfo__), __varinfo__` +which in turn would lead to a `(::Model)(args...)` call returning `(x, __varinfo__)`, +breaking with the expectation of the user. + +In such a scenario `return_values` effectively results in the following + +```julia +return x, __varinfo__ = tilde_assume!!(...) +``` + +preserving user expectation, as desired. +""" return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) return_values(retval::Tuple{<:Any,<:AbstractVarInfo}, ::AbstractVarInfo) = retval @@ -534,6 +597,10 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. + # NOTE: We need to replace statements of the form `return ...` with + # `return DynamicPPL.return_values(..., __varinfo__)` to ensure that the second + # element in the returned value is always the most up-to-date `__varinfo__`. + # See the docstrings of `replace_returns` and `return_values` for more info. evaluatordef[:body] = replace_returns(make_returns_explicit!(modelinfo[:body])) ## Build the model function. From dfd9dc571178babb5a6811bb3defdfaec33667c8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 01:50:57 +0100 Subject: [PATCH 120/216] fixed bug in generate_tilde introduced in a merge --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index a51dba3fd..481e61454 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -403,7 +403,7 @@ function generate_tilde(left, right) $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) end - $(DynamicPPL.tilde_observe!!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), @@ -452,7 +452,7 @@ function generate_dot_tilde(left, right) $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) end - $(DynamicPPL.dot_tilde_observe!!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), From 18ed81725fa519e094a1b4ed1f4039f67af319ee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 01:59:08 +0100 Subject: [PATCH 121/216] fixed a bug in isfuncdef --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 481e61454..c446dda6d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -476,7 +476,7 @@ function isfuncdef(e::Expr) elseif Meta.isexpr(e, :->) # Anonymous functions/lambdas, e.g. `do` blocks or `->` defs. true - elseif Meta.isexpr(e, :=) && Meta.isexpr(e.args[1], :call) + elseif Meta.isexpr(e, :(=)) && Meta.isexpr(e.args[1], :call) # Short function defs, e.g. `f(args...) = ...`. true else From b61a9bef3dda04ee6bf140c765e8e3dc30ca34e0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 03:16:58 +0100 Subject: [PATCH 122/216] fixed tests --- src/context_implementations.jl | 12 ++++++------ src/loglikelihoods.jl | 3 +-- src/prob_macro.jl | 2 +- src/simple_varinfo.jl | 22 +++++++++++----------- src/varinfo.jl | 17 ++++++++++++----- test/loglikelihoods.jl | 4 ++-- test/runtests.jl | 2 +- 7 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 125f52580..2a18563c7 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -233,7 +233,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi, + vi::VarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. @@ -468,7 +468,7 @@ function dot_assume( dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, var::AbstractArray, - vi, + vi::VarInfo, ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions @@ -483,7 +483,7 @@ end function get_and_set_val!( rng, - vi, + vi::VarInfo, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -516,7 +516,7 @@ end function get_and_set_val!( rng, - vi, + vi::VarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}}, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -547,7 +547,7 @@ function get_and_set_val!( end function set_val!( - vi, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, val::AbstractMatrix + vi::VarInfo, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, val::AbstractMatrix ) @assert size(val, 2) == length(vns) foreach(enumerate(vns)) do (i, vn) @@ -556,7 +556,7 @@ function set_val!( return val end function set_val!( - vi, + vi::VarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}}, val::AbstractArray, diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 26055bef0..89b3d9e6d 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -86,7 +86,6 @@ function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, v # hence we need the `logp` for each of them. Broadcasting the univariate # `tilde_obseve` does exactly this. logps = _pointwise_tilde_observe(context.context, right, left, vi) - acclogp!(vi, sum(logps)) # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. _, _, vns = unwrap_right_left_vns(right, left, vn) @@ -95,7 +94,7 @@ function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, v push!(context, vn, logp) end - return left, acclogp!!(vi, logp) + return left, acclogp!!(vi, sum(logps)) end # FIXME: This is really not a good approach since it needs to stay in sync with diff --git a/src/prob_macro.jl b/src/prob_macro.jl index d761e9fdc..84497aef0 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -146,7 +146,7 @@ function logprior( foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end - model(vi, SampleFromPrior(), PriorContext(left)) + _, vi = DynamicPPL.evaluate(model, vi, SampleFromPrior(), PriorContext(left)) return getlogp(vi) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 38d05ff85..621f20502 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -33,28 +33,28 @@ julia> m = demo(); julia> rng = StableRNG(42); julia> ### Sampling ### - ctx = SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()); + ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext()); julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variablse which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi -SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [1.6642061055583879, 1.796319600944139], m = -0.16796295277202952), -5.769094411622931) +SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [0.4471218424633827, 1.3736306979834252], m = -0.6702516921145671), -4.024823883230379) julia> # (✓) Vroom, vroom! FAST!!! DynamicPPL.getval(vi, @varname(x[1])) -1.6642061055583879 +0.4471218424633827 julia> # We can also access arbitrary varnames pointing to `x`, e.g. DynamicPPL.getval(vi, @varname(x)) 2-element Vector{Float64}: - 1.6642061055583879 - 1.796319600944139 + 0.4471218424633827 + 1.3736306979834252 julia> DynamicPPL.getval(vi, @varname(x[1:2])) 2-element view(::Vector{Float64}, 1:2) with eltype Float64: - 1.6642061055583879 - 1.796319600944139 + 0.4471218424633827 + 1.3736306979834252 julia> # (×) If we don't provide the container... _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi @@ -63,11 +63,11 @@ ERROR: type NamedTuple has no field x julia> # If one does not know the varnames, we can use a `Dict` instead. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi -SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => 1.192696983568277, x[2] => 0.4914514300738121, m => 0.25572200616753643), -3.6215377732004237) +SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => -1.019202452456547, x[2] => -0.7935128416361353, m => 0.683947930996541), -3.8249261202386906) julia> # (✓) Sort of fast, but only possible at runtime. DynamicPPL.getval(vi, @varname(x[1])) -1.192696983568277 +-1.019202452456547 julia> # In addtion, we can only access varnames as they appear in the model! DynamicPPL.getval(vi, @varname(x)) @@ -142,8 +142,8 @@ function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym} end # `NamedTuple` -function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}) where {sym} - return _getvalue(vi.θ, Val{sym}(), vn.indexing) +function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName) + return _getvalue(vi.θ, vn) end # `Dict` diff --git a/src/varinfo.jl b/src/varinfo.jl index 8894c8cb2..2682cf856 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -369,7 +369,13 @@ getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] Set the `trans` flag value of `vn` in `vi`, mutating if it makes sense. """ function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - return trans ? set_flag!!(vi, vn, "trans") : unset_flag!!(vi, vn, "trans") + if trans + set_flag!!(vi, vn, "trans") + else + unset_flag!!(vi, vn, "trans") + end + + return vi end """ @@ -962,8 +968,8 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, vn::VarName) = setval!(vi, val, vn) -setindex!!(vi::AbstractVarInfo, val, vn::VarName) = setindex!(vi, val, vn) +setindex!(vi::AbstractVarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) +setindex!!(vi::AbstractVarInfo, val, vn::VarName) = (setindex!(vi, val, vn); return vi) """ setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) @@ -978,7 +984,7 @@ function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` ranges = _getranges(vi, spl) _setindex!(vi.metadata, val, ranges) - return val + return vi end # Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. @generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} @@ -1189,7 +1195,8 @@ end Set `vn`'s value for `flag` to `false` in `vi`. """ function unset_flag!!(vi::VarInfo, vn::VarName, flag::String) - return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false + getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false + return vi end """ diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index bbb32d800..568a832ae 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -69,7 +69,7 @@ end @model function gdemo9() # Submodel prior - @submodel m ~ _prior_dot_assume() + @submodel m = _prior_dot_assume() for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end @@ -110,7 +110,7 @@ const gdemo_models = ( ) @testset "loglikelihoods.jl" begin - for m in gdemo_models + @testset "$(m.name)" for m in gdemo_models vi = VarInfo(m) vns = vi.metadata.m.vns diff --git a/test/runtests.jl b/test/runtests.jl index bb2ae579c..64bb84107 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,7 @@ include("test_util.jl") include("threadsafe.jl") - include("serialization.jl") + # include("serialization.jl") include("loglikelihoods.jl") end From 8d2dc71272842cdace065a16ac82fa1548cc3282 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Aug 2021 03:17:49 +0100 Subject: [PATCH 123/216] formatting --- src/context_implementations.jl | 5 ++++- src/simple_varinfo.jl | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 2a18563c7..16e7151f2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -547,7 +547,10 @@ function get_and_set_val!( end function set_val!( - vi::VarInfo, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, val::AbstractMatrix + vi::VarInfo, + vns::AbstractVector{<:VarName}, + dist::MultivariateDistribution, + val::AbstractMatrix, ) @assert size(val, 2) == length(vns) foreach(enumerate(vns)) do (i, vn) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 621f20502..04fd062e0 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -127,7 +127,7 @@ _getindex_view(x, inds::Tuple{}) = x # TODO: Get rid of this once we have lenses. function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym,Tuple{}}) where {sym} - return merge(nt, NamedTuple{(sym, )}((val, ))) + return merge(nt, NamedTuple{(sym,)}((val,))) end function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym} # Use `getproperty` instead of `getfield` From 9cff93febaf704365d08e3005c3fa9e530b11804 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Aug 2021 15:12:55 +0100 Subject: [PATCH 124/216] uncomment mistakenly commented code --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 64bb84107..bb2ae579c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,7 @@ include("test_util.jl") include("threadsafe.jl") - # include("serialization.jl") + include("serialization.jl") include("loglikelihoods.jl") end From af6427ab1b3225ba41606a4acc24743634f54231 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 01:40:45 +0100 Subject: [PATCH 125/216] bumped version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cdff210e4..2e0634861 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.14.2" +version = "0.15.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 7b5dbcebcf16d1db0338d6d093341c569bd40c68 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 02:00:32 +0100 Subject: [PATCH 126/216] updated doctests --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 04fd062e0..5759824d7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -26,7 +26,7 @@ julia> @model function demo() end return x end -demo (generic function with 1 method) +demo (generic function with 2 methods) julia> m = demo(); From a6e2ffb581ea9f142fc3ec9f458d798056a790ee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 16:03:11 +0100 Subject: [PATCH 127/216] dont carry over bang-bang versions that we dont need for general varinfos --- src/DynamicPPL.jl | 10 -------- src/compat/ad.jl | 2 +- src/context_implementations.jl | 36 ++++++++++++++--------------- src/simple_varinfo.jl | 3 +-- src/varinfo.jl | 42 +++++++++++++++++----------------- 5 files changed, 41 insertions(+), 52 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bf7109ad0..5494a6d0d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -172,14 +172,4 @@ include("submodel_macro.jl") @deprecate link!(vi, spl) link!!(vi, spl) @deprecate invlink!(vi, spl) invlink!!(vi, spl) -@deprecate set_flag!(vi, vn, flag) set_flag!!(vi, vn, flag) -@deprecate unset_flag!(vi, vn, flag) unset_flag!!(vi, vn, flag) - -@deprecate settrans!(vi, trans, vn) settrans!!(vi, trans, vn) - -@deprecate setall!(vi, val) setall!!(vi, val) - -@deprecate setgid!(vi, gid, vn) setgid!!(vi, gid, vn) -@deprecate updategid!(vi, vn, spl) updategid!!(vi, vn, spl) - end # module diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 664ce2b33..4fd2830b3 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -3,7 +3,7 @@ ChainRulesCore.@non_differentiable push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) -ChainRulesCore.@non_differentiable updategid!!( +ChainRulesCore.@non_differentiable updategid!( vi::AbstractVarInfo, vn::VarName, spl::Sampler ) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 16e7151f2..73738d459 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -71,7 +71,7 @@ function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) vi = setindex!!( vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn ) - vi = settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, inds, vi) end @@ -88,7 +88,7 @@ function tilde_assume( vi = setindex!!( vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn ) - vi = settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end @@ -98,7 +98,7 @@ function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi = setindex!!( vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn ) - vi = settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, inds, vi) end @@ -115,7 +115,7 @@ function tilde_assume( vi = setindex!!( vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn ) - vi = settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) end @@ -145,7 +145,7 @@ By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates t probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, inds, vi) - value, logp = tilde_assume(context, right, vn, inds, vi) + value, logp, vi = tilde_assume(context, right, vn, inds, vi) return value, acclogp!!(vi, logp) end @@ -238,10 +238,10 @@ function assume( if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!!(vi, vn, "del") + unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) vi[vn] = vectorize(dist, r) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) else r = vi[vn] @@ -249,7 +249,7 @@ function assume( else r = init(rng, dist, sampler) push!!(vi, vn, r, dist, sampler) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) @@ -319,7 +319,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!.(Ref(vi), false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) @@ -339,7 +339,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!.(Ref(vi), false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) @@ -360,7 +360,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!.(Ref(vi), false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) else dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) @@ -380,7 +380,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!!.(Ref(vi), false, _vns) + settrans!.(Ref(vi), false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) @@ -492,12 +492,12 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!!(vi, vns[1], "del") + unset_flag!(vi, vns[1], "del") r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] vi[vn] = vectorize(dist, r[:, i]) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -508,7 +508,7 @@ function get_and_set_val!( for i in 1:n vn = vns[i] push!!(vi, vn, r[:, i], dist, spl) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) end end return r @@ -524,14 +524,14 @@ function get_and_set_val!( if haskey(vi, vns[1]) # Always overwrite the parameters with new ones for `SampleFromUniform`. if spl isa SampleFromUniform || is_flagged(vi, vns[1], "del") - unset_flag!!(vi, vns[1], "del") + unset_flag!(vi, vns[1], "del") f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists vi[vn] = vectorize(dist, r[i]) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -541,7 +541,7 @@ function get_and_set_val!( f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) push!!.(Ref(vi), vns, r, dists, Ref(spl)) - settrans!!.(Ref(vi), false, vns) + settrans!.(Ref(vi), false, vns) end return r end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 5759824d7..f7fa27578 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -229,7 +229,6 @@ function assume( ) value = init(rng, dist, sampler) vi = push!!(vi, vn, value, dist, sampler) - vi = settrans!!(vi, false, vn) return value, Distributions.loglikelihood(dist, value), vi end @@ -283,7 +282,7 @@ end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleVarInfo) = nothing -settrans!!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = vi +settrans!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = nothing values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.θ)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.θ)) diff --git a/src/varinfo.jl b/src/varinfo.jl index db8296e2f..7c4858c6b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -359,18 +359,18 @@ Return the set of sampler selectors associated with `vn` in `vi`. getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ - settrans!!(vi::VarInfo, trans::Bool, vn::VarName) + settrans!(vi::VarInfo, trans::Bool, vn::VarName) Set the `trans` flag value of `vn` in `vi`, mutating if it makes sense. """ -function settrans!!(vi::AbstractVarInfo, trans::Bool, vn::VarName) +function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) if trans - set_flag!!(vi, vn, "trans") + set_flag!(vi, vn, "trans") else - unset_flag!!(vi, vn, "trans") + unset_flag!(vi, vn, "trans") end - return vi + return nothing end """ @@ -513,11 +513,11 @@ end end """ - set_flag!!(vi::VarInfo, vn::VarName, flag::String) + set_flag!(vi::VarInfo, vn::VarName, flag::String) Set `vn`'s value for `flag` to `true` in `vi`. """ -function set_flag!!(vi::VarInfo, vn::VarName, flag::String) +function set_flag!(vi::VarInfo, vn::VarName, flag::String) getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true return vi end @@ -638,11 +638,11 @@ Base.keys(vi::UntypedVarInfo) = keys(vi.metadata.idcs) end """ - setgid!!(vi::VarInfo, gid::Selector, vn::VarName) + setgid!(vi::VarInfo, gid::Selector, vn::VarName) Add `gid` to the set of sampler selectors associated with `vn` in `vi`. """ -function setgid!!(vi::VarInfo, gid::Selector, vn::VarName) +function setgid!(vi::VarInfo, gid::Selector, vn::VarName) return push!(getmetadata(vi, vn).gids[getidx(vi, vn)], gid) end @@ -757,7 +757,7 @@ function link!!(vi::UntypedVarInfo, spl::Sampler) vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!!(vi, true, vn) + settrans!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -793,7 +793,7 @@ end ), vn, ) - settrans!!(vi, true, vn) + settrans!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -824,7 +824,7 @@ function invlink!!(vi::UntypedVarInfo, spl::AbstractSampler) vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn, ) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -862,7 +862,7 @@ end ), vn, ) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1185,11 +1185,11 @@ function is_flagged(vi::VarInfo, vn::VarName, flag::String) end """ - unset_flag!!(vi::VarInfo, vn::VarName, flag::String) + unset_flag!(vi::VarInfo, vn::VarName, flag::String) Set `vn`'s value for `flag` to `false` in `vi`. """ -function unset_flag!!(vi::VarInfo, vn::VarName, flag::String) +function unset_flag!(vi::VarInfo, vn::VarName, flag::String) getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false return vi end @@ -1250,14 +1250,14 @@ end end """ - updategid!!(vi::VarInfo, vn::VarName, spl::Sampler) + updategid!(vi::VarInfo, vn::VarName, spl::Sampler) Set `vn`'s `gid` to `Set([spl.selector])`, if `vn` does not have a sampler selector linked and `vn`'s symbol is in the space of `spl`. """ -function updategid!!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) +function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) if inspace(vn, getspace(spl)) - setgid!!(vi, spl.selector, vn) + setgid!(vi, spl.selector, vn) end end @@ -1405,7 +1405,7 @@ function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return indices @@ -1486,11 +1486,11 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!!(vi, false, vn) + settrans!(vi, false, vn) else # Ensures that we'll resample the variable corresponding to `vn` if we run # the model on `vi` again. - set_flag!!(vi, vn, "del") + set_flag!(vi, vn, "del") end return indices From 49c11579920eb1d84b8675b7e46606ee6957d587 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 16:08:16 +0100 Subject: [PATCH 128/216] Apply suggestions from @phipsgabler Co-authored-by: Philipp Gabler --- src/compiler.jl | 2 +- src/model.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 03a3b573c..2132c2a80 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -486,7 +486,7 @@ function replace_returns(e::Expr) if Meta.isexpr(e, :return) # NOTE: `return` always has an argument. In the case of - # `return`, the parsed expression will be `return nothing`. + # an empty `return`, the lowered expression will be `return nothing`. # Hence we don't need any special handling for empty returns. retval_expr = if length(e.args) > 1 Expr(:tuple, e.args...) diff --git a/src/model.jl b/src/model.jl index e13e9f020..fd9029ae7 100644 --- a/src/model.jl +++ b/src/model.jl @@ -432,7 +432,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ function evaluate_threadunsafe(model, varinfo, context) - resetlogp!!(varinfo) + varinfo = resetlogp!!(varinfo) return _evaluate(model, varinfo, context) end @@ -448,10 +448,10 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ function evaluate_threadsafe(model, varinfo, context) - resetlogp!!(varinfo) + varinfo = resetlogp!!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) result = _evaluate(model, wrapper, context) - setlogp!!(varinfo, getlogp(wrapper)) + varinfo = setlogp!!(varinfo, getlogp(wrapper)) return result end From f601d5da582a7a84a17e0d59ece7611cfb1823b8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 9 Sep 2021 16:10:19 +0100 Subject: [PATCH 129/216] updated tests --- test/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 0446008b3..96e247e1b 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -114,10 +114,10 @@ # del is set by default @test !is_flagged(vi, vn_x, "del") - set_flag!!(vi, vn_x, "del") + set_flag!(vi, vn_x, "del") @test is_flagged(vi, vn_x, "del") - unset_flag!!(vi, vn_x, "del") + unset_flag!(vi, vn_x, "del") @test !is_flagged(vi, vn_x, "del") end vi = VarInfo() From 72d2f53b940c3f0c9a9217288d49e45c09bad3ef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 6 Nov 2021 12:26:02 +0000 Subject: [PATCH 130/216] removed unnecessary BangBang methods --- src/DynamicPPL.jl | 8 ----- src/context_implementations.jl | 53 ++++++++++++++++++---------------- src/loglikelihoods.jl | 12 +++++--- src/threadsafe.jl | 8 ++--- src/utils.jl | 8 +++++ src/varinfo.jl | 20 ++++++------- 6 files changed, 58 insertions(+), 51 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 88a2e196a..1f6dc6601 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -58,17 +58,12 @@ export AbstractVarInfo, set_flag!, unset_flag!, set_flag!!, - unset_flag!!, setgid!, updategid!, - setgid!!, - updategid!!, setorder!, istrans, link!, invlink!, - link!!, - invlink!!, tonamedtuple, # VarName (reexport from AbstractPPL) VarName, @@ -173,7 +168,4 @@ include("test_utils.jl") @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) @deprecate resetlogp!(vi) resetlogp!!(vi) -@deprecate link!(vi, spl) link!!(vi, spl) -@deprecate invlink!(vi, spl) invlink!!(vi, spl) - end # module diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 7f843b2cc..6d5c86816 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -57,7 +57,7 @@ function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) vi = setindex!!( vi, vectorize(right, get(context.vars, vn)), vn ) - vi = settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, vi) end @@ -66,7 +66,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - vi = settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) end @@ -74,7 +74,7 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - vi = settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, vi) end @@ -88,7 +88,7 @@ function tilde_assume( ) if haskey(context.vars, getsym(vn)) vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) - vi = settrans!!(vi, false, vn) + settrans!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) end @@ -141,15 +141,17 @@ function tilde_observe(::IsParent, context::AbstractContext, args...) return tilde_observe(childcontext(context), args...) end -tilde_observe(::PriorContext, right, left, vi) = 0 -tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::PriorContext, right, left, vi) = 0, vi +tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, right, left, vi) - return context.loglike_scalar * tilde_observe(context.context, right, left, vi) + logp, vi = tilde_observe(context.context, right, left, vi) + return context.loglike_scalar * logp, vi end function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.context, sampler, right, left, vi) + logp, vi = tilde_observe(context.context, sampler, right, left, vi) + return context.loglike_scalar * logp, vi end # `PrefixContext` @@ -180,7 +182,7 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the probability of `vi` with the returned value. """ function tilde_observe!!(context, right, left, vi) - logp = tilde_observe(context, right, left, vi) + logp, vi = tilde_observe(context, right, left, vi) return left, acclogp!!(vi, logp) end @@ -195,7 +197,7 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) r = vi[vn] - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end # SampleFromPrior and SampleFromUniform @@ -223,14 +225,14 @@ function assume( settrans!(vi, false, vn) end - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end # default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) observe(sampler::AbstractSampler, right, left, vi) = observe(right, left, vi) function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(right, left) + return Distributions.loglikelihood(right, left), vi end # .~ functions @@ -372,10 +374,10 @@ model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, vi) - value, logp = dot_tilde_assume(context, right, left, vn, vi) + value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) # Mutation of `value` no longer occurs in main body, so we do it here. left .= value - return value, acclogp!!(vi, logp) + return value, acclogp!!(vi, logp), vi end # `dot_assume` @@ -393,7 +395,7 @@ function dot_assume( lp = sum(zip(vns, eachcol(r))) do vn, ri return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end - return r, lp + return r, lp, vi end function dot_assume( @@ -407,7 +409,7 @@ function dot_assume( @assert length(dist) == size(var, 1) r = get_and_set_val!(rng, vi, vns, dist, spl) lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume( @@ -424,7 +426,7 @@ function dot_assume( # in which case `var` will have `undef` elements, even if `m` is present in `vi`. r = reshape(vi[vec(vns)], size(vns)) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume( @@ -438,7 +440,7 @@ function dot_assume( r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) return error( @@ -559,12 +561,13 @@ function dot_tilde_observe(::IsParent, context::AbstractContext, args...) return dot_tilde_observe(childcontext(context), args...) end -dot_tilde_observe(::PriorContext, right, left, vi) = 0 -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +dot_tilde_observe(::PriorContext, right, left, vi) = 0, vi +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0, vi # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, right, left, vi) - return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) + logp, vi = dot_tilde_observe(context.context, right, left, vi) + return context.loglike_scalar * logp, vi end # `PrefixContext` @@ -594,7 +597,7 @@ probability, and return the observed value. Falls back to `dot_tilde_observe(context, right, left, vi)`. """ function dot_tilde_observe!!(context, right, left, vi) - logp = dot_tilde_observe(context, right, left, vi) + logp, vi = dot_tilde_observe(context, right, left, vi) return left, acclogp!!(vi, logp) end @@ -604,13 +607,13 @@ function dot_observe(::AbstractSampler, dist, value, vi) end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(dist, value), vi end function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dists, value) + return Distributions.loglikelihood(dists, value), vi end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) - return sum(Distributions.loglikelihood.(dists, value)) + return sum(Distributions.loglikelihood.(dists, value)), vi end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 15568b255..6d29049e3 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -74,7 +74,7 @@ end function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. - logp = tilde_observe(context.context, right, left, vi) + logp, vi = tilde_observe(context.context, right, left, vi) # Track loglikelihood value. push!(context, vn, logp) @@ -108,13 +108,17 @@ end # FIXME: This is really not a good approach since it needs to stay in sync with # the `dot_assume` implementations, but as things are _right now_ this is the best we can do. function _pointwise_tilde_observe(context, right, left, vi) - return tilde_observe.(Ref(context), right, left, Ref(vi)) + # We need to drop the `vi` returned. + observe_logps(r, l) = first(tilde_observe(context, r, l, vi)) + return observe_logps.(right, left) end function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi + context, right::MultivariateDistribution, left::AbstractMatrix, vi::VarInfo ) - return tilde_observe.(Ref(context), Ref(right), eachcol(left), Ref(vi)) + # We need to drop the `vi` returned. + observe_logps(l) = first(tilde_observe(context, right, l, vi)) + return observe_logps.(eachcol(left)) end """ diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 9c59fa507..efe0debc2 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -46,8 +46,8 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -function setgid!!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) - return setgid!!(vi.varinfo, gid, vn) +function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) + return setgid!(vi.varinfo, gid, vn) end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) @@ -55,8 +55,8 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -link!!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!!(vi.varinfo, spl) -invlink!!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!!(vi.varinfo, spl) +link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) +invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) diff --git a/src/utils.jl b/src/utils.jl index 537fcc90e..527c1220a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -155,3 +155,11 @@ end ####################### collectmaybe(x) = x collectmaybe(x::Base.AbstractSet) = collect(x) + +####################### +# BangBang.jl related # +####################### +function set!!(obj, vn::VarName{sym}, value) where {sym} + lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) + return Setfield.set(obj, lens, value) +end diff --git a/src/varinfo.jl b/src/varinfo.jl index 7c4858c6b..e2e7dabcc 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -738,13 +738,13 @@ end # X -> R for all variables associated with given sampler """ - link!!(vi::VarInfo, spl::Sampler) + link!(vi::VarInfo, spl::Sampler) Transform the values of the random variables sampled by `spl` in `vi` from the support of their distributions to the Euclidean space and set their corresponding `"trans"` flag values to `true`. """ -function link!!(vi::UntypedVarInfo, spl::Sampler) +function link!(vi::UntypedVarInfo, spl::Sampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) @@ -763,10 +763,10 @@ function link!!(vi::UntypedVarInfo, spl::Sampler) @warn("[DynamicPPL] attempt to link a linked vi") end end -function link!!(vi::TypedVarInfo, spl::AbstractSampler) - return link!!(vi, spl, Val(getspace(spl))) +function link!(vi::TypedVarInfo, spl::AbstractSampler) + return link!(vi, spl, Val(getspace(spl))) end -function link!!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _link!(vi.metadata, vi, vns, spaceval) end @@ -807,13 +807,13 @@ end # R -> X for all variables associated with given sampler """ - invlink!!(vi::VarInfo, spl::AbstractSampler) + invlink!(vi::VarInfo, spl::AbstractSampler) Transform the values of the random variables sampled by `spl` in `vi` from the Euclidean space back to the support of their distributions and sets their corresponding `"trans"` flag values to `false`. """ -function invlink!!(vi::UntypedVarInfo, spl::AbstractSampler) +function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns @@ -830,10 +830,10 @@ function invlink!!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -function invlink!!(vi::TypedVarInfo, spl::AbstractSampler) - return invlink!!(vi, spl, Val(getspace(spl))) +function invlink!(vi::TypedVarInfo, spl::AbstractSampler) + return invlink!(vi, spl, Val(getspace(spl))) end -function invlink!!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _invlink!(vi.metadata, vi, vns, spaceval) end From df212b7a61e7f167908824c81930f7161c6b8569 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 6 Nov 2021 12:26:18 +0000 Subject: [PATCH 131/216] fixed zygote rule for dot_observe --- src/compat/ad.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 4fd2830b3..edcac7874 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -16,7 +16,7 @@ ZygoteRules.@adjoint function dot_observe( ) function dot_observe_fallback(spl, dists, value, vi) increment_num_produce!(vi) - return sum(map(Distributions.loglikelihood, dists, value)) + return sum(map(Distributions.loglikelihood, dists, value)), vi end return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi) end From 74bfd4ec71c375997262d4e3ae6c3c5164222d54 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 6 Nov 2021 12:26:31 +0000 Subject: [PATCH 132/216] fixed Setfield.jl + returning VarInfo bug in model-macro --- src/compiler.jl | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6827d8888..6f4c6fc5f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -401,20 +401,26 @@ function generate_tilde(left, right) end function generate_tilde_assume(left, right, vn) - expr = :( - $left, __varinfo__ = $(DynamicPPL.tilde_assume!!)( + # HACK: Because the Setfield.jl macro does not support assignment + # with multiple arguments on the LHS, we need to capture the return-values + # and then update them LHS variables one by one. + tmp = gensym(:tmp) + expr = :($left = ($tmp)[1]) + if left isa Expr + expr = AbstractPPL.drop_escape( + Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) + ) + end + + return quote + $tmp = $(DynamicPPL.tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) - ) - - return if left isa Expr - AbstractPPL.drop_escape( - Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) - ) - else - return expr + $expr + __varinfo__ = $tmp[2] + $left, __varinfo__ end end @@ -456,7 +462,7 @@ function generate_dot_tilde_assume(left, right, vn) # `.=` is always going to be inplace + needs `left` to # be something that supports `.=`. return :( - $left, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( + ($left, __varinfo__) = $(DynamicPPL.dot_tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn From b7595356b930ba644597e88eade42e1024d123e5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 6 Nov 2021 12:27:44 +0000 Subject: [PATCH 133/216] updated tests --- src/simple_varinfo.jl | 48 ++++++++++++------------------------------- src/test_utils.jl | 2 +- test/Project.toml | 2 ++ test/compat/ad.jl | 3 ++- test/varinfo.jl | 4 ++-- 5 files changed, 20 insertions(+), 39 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f7fa27578..df03d9552 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -16,7 +16,7 @@ The major differences between this and `TypedVarInfo` are: # Examples ```jldoctest; setup=:(using Distributions) -julia> using StableRNGs +julia> using StableRNGs, OrderedCollections julia> @model function demo() m ~ Normal() @@ -52,7 +52,7 @@ julia> # We can also access arbitrary varnames pointing to `x`, e.g. 1.3736306979834252 julia> DynamicPPL.getval(vi, @varname(x[1:2])) -2-element view(::Vector{Float64}, 1:2) with eltype Float64: +2-element Vector{Float64}: 0.4471218424633827 1.3736306979834252 @@ -61,9 +61,9 @@ julia> # (×) If we don't provide the container... ERROR: type NamedTuple has no field x [...] -julia> # If one does not know the varnames, we can use a `Dict` instead. - _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); vi -SimpleVarInfo{Dict{Any, Any}, Float64}(Dict{Any, Any}(x[1] => -1.019202452456547, x[2] => -0.7935128416361353, m => 0.683947930996541), -3.8249261202386906) +julia> # If one does not know the varnames, we can use a `OrderedDict` instead. + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); vi +SimpleVarInfo{OrderedDict{Any, Any}, Float64}(OrderedCollections.OrderedDict{Any, Any}(m => 0.683947930996541, x[1] => -1.019202452456547, x[2] => -0.7935128416361353), -3.8249261202386906) julia> # (✓) Sort of fast, but only possible at runtime. DynamicPPL.getval(vi, @varname(x[1])) @@ -121,33 +121,15 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) return vi end -# TODO: Get rid of this once we have lenses. -_getindex_view(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds)) -_getindex_view(x, inds::Tuple{}) = x - -# TODO: Get rid of this once we have lenses. -function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym,Tuple{}}) where {sym} - return merge(nt, NamedTuple{(sym,)}((val,))) -end -function _setvalue!!(nt::NamedTuple, val, vn::VarName{sym}) where {sym} - # Use `getproperty` instead of `getfield` - value = getproperty(nt, sym) - # Note that this will return a `view`, even if the resulting value is 0-dim. - # This makes it possible to call `setindex!` on the result later to update - # in place even in the case where are retrieving a single element, e.g. `x[1]`. - dest_view = _getindex_view(value, vn.indexing) - dest_view .= val - - return nt -end - # `NamedTuple` function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName) - return _getvalue(vi.θ, vn) + return get(vi.θ, vn) end # `Dict` -function getval(vi::SimpleVarInfo{<:Dict}, vn::VarName) +function getval(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName{sym}) where {sym} + # TODO: Should we maybe allow indexing of sub-keys, etc. too? E.g. + # if `x` is present and it has an array, maybe we should allow indexing `x[1]`, etc. return vi.θ[vn] end @@ -180,12 +162,12 @@ end # `NamedTuple` function push!!( vi::SimpleVarInfo{<:NamedTuple}, - vn::VarName{sym,Tuple{}}, + vn::VarName{sym,Setfield.IdentityLens}, value, dist::Distribution, gidset::Set{Selector}, ) where {sym} - Setfield.@set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) + return Setfield.@set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) end function push!!( vi::SimpleVarInfo{<:NamedTuple}, @@ -194,16 +176,12 @@ function push!!( dist::Distribution, gidset::Set{Selector}, ) where {sym} - # We update in place. - # We need a view into the array, hence we call `_getvalue` directly - # rather than `getval`. - _setvalue!!(vi.θ, value, vn) - return vi + return Setfield.@set vi.θ = set!!(vi.θ, vn, value) end # `Dict` function push!!( - vi::SimpleVarInfo{<:Dict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector} + vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) vi.θ[vn] = r return vi diff --git a/src/test_utils.jl b/src/test_utils.jl index d4b5c7206..db2193a73 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -98,7 +98,7 @@ end @model function demo_assume_submodel_observe_index_literal() # Submodel prior - m = @submodel _prior_dot_assume() + @submodel m = _prior_dot_assume() for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end diff --git a/test/Project.toml b/test/Project.toml index 3af6ef22d..a322ade28 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -29,6 +30,7 @@ Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" MCMCChains = "4.0.4, 5" MacroTools = "0.5.5" +OrderedCollections = "1" Setfield = "0.7.1" StableRNGs = "1" Tracker = "0.2.11" diff --git a/test/compat/ad.jl b/test/compat/ad.jl index 3a8058ca9..f76ce6f6e 100644 --- a/test/compat/ad.jl +++ b/test/compat/ad.jl @@ -30,9 +30,10 @@ # https://github.com/TuringLang/Turing.jl/issues/1595 @testset "dot_observe" begin function f_dot_observe(x) - return DynamicPPL.dot_observe( + logp, _ = DynamicPPL.dot_observe( SampleFromPrior(), [Normal(), Normal(-1.0, 2.0)], x, VarInfo() ) + return logp end function f_dot_observe_manual(x) return logpdf(Normal(), x[1]) + logpdf(Normal(-1.0, 2.0), x[2]) diff --git a/test/varinfo.jl b/test/varinfo.jl index 96e247e1b..2f7816024 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -135,14 +135,14 @@ push!!(vi, vn, r, dist, gid1) @test meta.gids[meta.idcs[vn]] == Set([gid1]) - setgid!!(vi, gid2, vn) + setgid!(vi, gid2, vn) @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) vi = empty!!(TypedVarInfo(vi)) meta = vi.metadata push!!(vi, vn, r, dist, gid1) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) - setgid!!(vi, gid2, vn) + setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end @testset "setval! & setval_and_resample!" begin From a61eef26a5b8e9fcfd3a2936ee42ade337a5a047 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 6 Nov 2021 12:49:35 +0000 Subject: [PATCH 134/216] fixed docs --- docs/Project.toml | 2 ++ src/submodel_macro.jl | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 83ce62d5e..e6d89dadd 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,9 +1,11 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] Distributions = "0.25" Documenter = "0.27" +OrderedCollections = "1" StableRNGs = "1" diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 010bacb5f..cd93131e3 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -67,8 +67,8 @@ julia> @model function demo1(x) end; julia> @model function demo2(x, y, z) - a = @submodel sub1 demo1(x) - b = @submodel sub2 demo1(y) + @submodel sub1 a = demo1(x) + @submodel sub2 b = demo1(y) return z ~ Uniform(-a, b) end; ``` From 8bc72ff8309ca2ef432bd9382da10eced6902e1b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 6 Nov 2021 12:53:35 +0000 Subject: [PATCH 135/216] formatting --- src/context_implementations.jl | 4 +--- src/simple_varinfo.jl | 6 +++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6d5c86816..fc6d76871 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -54,9 +54,7 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi = setindex!!( - vi, vectorize(right, get(context.vars, vn)), vn - ) + vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn) settrans!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, vi) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index df03d9552..0c739f0d8 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -181,7 +181,11 @@ end # `Dict` function push!!( - vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, r, dist::Distribution, gidset::Set{Selector} + vi::SimpleVarInfo{<:AbstractDict}, + vn::VarName, + r, + dist::Distribution, + gidset::Set{Selector}, ) vi.θ[vn] = r return vi From cdd88d0cabea41d20edcaac9462a709fff371845 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 6 Nov 2021 23:14:37 +0000 Subject: [PATCH 136/216] fixed issues when using ThreadSafeVarInfo --- src/context_implementations.jl | 12 ++++++------ src/threadsafe.jl | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index fc6d76871..c64d708d9 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -204,7 +204,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi::VarInfo, + vi::AbstractVarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. @@ -433,7 +433,7 @@ function dot_assume( dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, var::AbstractArray, - vi::VarInfo, + vi::AbstractVarInfo, ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions @@ -448,7 +448,7 @@ end function get_and_set_val!( rng, - vi::VarInfo, + vi::AbstractVarInfo, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -481,7 +481,7 @@ end function get_and_set_val!( rng, - vi::VarInfo, + vi::AbstractVarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}}, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -512,7 +512,7 @@ function get_and_set_val!( end function set_val!( - vi::VarInfo, + vi::AbstractVarInfo, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, val::AbstractMatrix, @@ -524,7 +524,7 @@ function set_val!( return val end function set_val!( - vi::VarInfo, + vi::AbstractVarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}}, val::AbstractArray, diff --git a/src/threadsafe.jl b/src/threadsafe.jl index efe0debc2..5d48150b0 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -92,8 +92,8 @@ function push!!( return push!!(vi.varinfo, vn, r, dist, gidset) end -function unset_flag!!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return unset_flag!!(vi.varinfo, vn, flag) +function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) + return unset_flag!(vi.varinfo, vn, flag) end function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) From 3faa883d237e219cc6c814ef260b81554507aa3d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 7 Nov 2021 08:46:40 +0000 Subject: [PATCH 137/216] fixed _pointwise_observe for ThreadSafeVarInfo --- src/loglikelihoods.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6d29049e3..421292b05 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -114,7 +114,7 @@ function _pointwise_tilde_observe(context, right, left, vi) end function _pointwise_tilde_observe( - context, right::MultivariateDistribution, left::AbstractMatrix, vi::VarInfo + context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo ) # We need to drop the `vi` returned. observe_logps(l) = first(tilde_observe(context, right, l, vi)) From e14930ee819b3c6bb2f9a6e1dbd7eed6a359c394 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 7 Nov 2021 08:47:14 +0000 Subject: [PATCH 138/216] updated ThreadSafeVarInfo --- src/model.jl | 6 +++--- src/threadsafe.jl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/model.jl b/src/model.jl index a8aca4eef..7006978ff 100644 --- a/src/model.jl +++ b/src/model.jl @@ -450,9 +450,9 @@ See also: [`evaluate_threadunsafe`](@ref) function evaluate_threadsafe(model, varinfo, context) varinfo = resetlogp!!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(model, wrapper, context) - varinfo = setlogp!!(varinfo, getlogp(wrapper)) - return result + result, wrapper_new = _evaluate(model, wrapper, context) + varinfo = setlogp!!(varinfo, getlogp(wrapper_new)) + return result, varinfo end """ diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 5d48150b0..2ecfaf07d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -30,13 +30,13 @@ function resetlogp!!(vi::ThreadSafeVarInfo) for x in vi.logps x[] = zero(x[]) end - return resetlogp!!(vi.varinfo) + return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps) end function setlogp!!(vi::ThreadSafeVarInfo, logp) for x in vi.logps x[] = zero(x[]) end - return setlogp!!(vi.varinfo, logp) + return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), vi.logps) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) From fc9378289691037c8d3abd41ef2eb23b1a198f5b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 7 Nov 2021 08:47:23 +0000 Subject: [PATCH 139/216] made SimpleVarInfo compat with ThreadSafeVarInfo and added show --- src/DynamicPPL.jl | 2 +- src/simple_varinfo.jl | 46 ++++++++++++++++++++----------------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 4284c60ce..3ca718cec 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -148,8 +148,8 @@ include("varname.jl") include("distribution_wrappers.jl") include("contexts.jl") include("varinfo.jl") -include("simple_varinfo.jl") include("threadsafe.jl") +include("simple_varinfo.jl") include("context_implementations.jl") include("compiler.jl") include("prob_macro.jl") diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 0c739f0d8..7b346cd35 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -16,7 +16,9 @@ The major differences between this and `TypedVarInfo` are: # Examples ```jldoctest; setup=:(using Distributions) -julia> using StableRNGs, OrderedCollections +julia> using StableRNGs + +julia> using OrderedCollections: OrderedDict # ensures consisent output julia> @model function demo() m ~ Normal() @@ -39,7 +41,7 @@ julia> # In the `NamedTuple` version we need to provide the place-holder values # the variablse which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi -SimpleVarInfo{NamedTuple{(:x, :m), Tuple{Vector{Float64}, Float64}}, Float64}((x = [0.4471218424633827, 1.3736306979834252], m = -0.6702516921145671), -4.024823883230379) +SimpleVarInfo((x = [0.4471218424633827, 1.3736306979834252],), -4.024823883230379) julia> # (✓) Vroom, vroom! FAST!!! DynamicPPL.getval(vi, @varname(x[1])) @@ -63,7 +65,7 @@ ERROR: type NamedTuple has no field x julia> # If one does not know the varnames, we can use a `OrderedDict` instead. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); vi -SimpleVarInfo{OrderedDict{Any, Any}, Float64}(OrderedCollections.OrderedDict{Any, Any}(m => 0.683947930996541, x[1] => -1.019202452456547, x[2] => -0.7935128416361353), -3.8249261202386906) +SimpleVarInfo(OrderedCollections.OrderedDict{Any, Any}(m => 0.683947930996541, x[1] => -1.019202452456547, x[2] => -0.7935128416361353), -3.8249261202386906) julia> # (✓) Sort of fast, but only possible at runtime. DynamicPPL.getval(vi, @varname(x[1])) @@ -121,6 +123,14 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) return vi end +function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) + print(io, "SimpleVarInfo(") + print(io, svi.θ) + print(io, ", ") + print(io, svi.logp) + print(io, ")") +end + # `NamedTuple` function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName) return get(vi.θ, vn) @@ -191,13 +201,10 @@ function push!!( return vi end -# Context implementations -function tilde_assume!!(context, right, vn, inds, vi::SimpleVarInfo) - value, logp, vi_new = tilde_assume(context, right, vn, inds, vi) - return value, acclogp!!(vi_new, logp) -end +const SimpleOrThreadSafeSimple{T} = Union{SimpleVarInfo{T},ThreadSafeVarInfo{<:SimpleVarInfo{T}}} -function assume(dist::Distribution, vn::VarName, vi::SimpleVarInfo) +# Context implementations +function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple) left = vi[vn] return left, Distributions.loglikelihood(dist, left), vi end @@ -207,29 +214,18 @@ function assume( sampler::SampleFromPrior, dist::Distribution, vn::VarName, - vi::SimpleVarInfo, + vi::SimpleOrThreadSafeSimple, ) value = init(rng, dist, sampler) vi = push!!(vi, vn, value, dist, sampler) return value, Distributions.loglikelihood(dist, value), vi end -# function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) -# throw(MethodError(dot_tilde_assume!!, (context, right, left, vn, inds, vi))) -# end - -function dot_tilde_assume!!(context, right, left, vn, inds, vi::SimpleVarInfo) - value, logp, vi_new = dot_tilde_assume(context, right, left, vn, inds, vi) - # Mutation of `value` no longer occurs in main body, so we do it here. - left .= value - return value, acclogp!!(vi_new, logp) -end - function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, - vi::SimpleVarInfo, + vi::SimpleOrThreadSafeSimple, ) @assert length(dist) == size(var, 1) # NOTE: We cannot work with `var` here because we might have a model of the form @@ -249,7 +245,7 @@ function dot_assume( dists::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, - vi::SimpleVarInfo{<:NamedTuple}, + vi::SimpleOrThreadSafeSimple, ) # NOTE: We cannot work with `var` here because we might have a model of the form # @@ -263,8 +259,8 @@ function dot_assume( end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. -increment_num_produce!(::SimpleVarInfo) = nothing -settrans!(vi::SimpleVarInfo, trans::Bool, vn::VarName) = nothing +increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing +settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.θ)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.θ)) From b6939f67ad07179ae2784b5c9f5becd4e7e9782f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 7 Nov 2021 08:47:40 +0000 Subject: [PATCH 140/216] added some tests for return-values of models --- test/model.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/test/model.jl b/test/model.jl index 2cdeae5fa..111fa4645 100644 --- a/test/model.jl +++ b/test/model.jl @@ -65,4 +65,20 @@ @test nameof(test1(rand())) == :test1 @test nameof(test2(rand())) == :test2 end + + @testset "Internal methods" begin + model = gdemo_default + + # sample from model and extract variables + vi = VarInfo(model) + + # Second component of return-value of `evaluate` should + # be a `DynamicPPL.AbstractVarInfo`. + evaluate_retval = DynamicPPL.evaluate(model, vi, DefaultContext()) + @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo + + # Should not return `AbstractVarInfo` when we call the model. + call_retval = model() + @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) + end end From 8903523440dc6c1a6d4701f2641ded9274f538a4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 7 Nov 2021 08:51:49 +0000 Subject: [PATCH 141/216] formatting --- src/compiler.jl | 2 +- src/simple_varinfo.jl | 6 ++++-- test/contexts.jl | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 37894a136..6f4c6fc5f 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -34,7 +34,7 @@ function isassumption(expr::Union{Symbol,Expr}) # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) + $(DynamicPPL.inmissings)($vn, __model__) true else $(maybe_view(expr)) === missing diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 7b346cd35..d9201fcde 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -128,7 +128,7 @@ function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) print(io, svi.θ) print(io, ", ") print(io, svi.logp) - print(io, ")") + return print(io, ")") end # `NamedTuple` @@ -201,7 +201,9 @@ function push!!( return vi end -const SimpleOrThreadSafeSimple{T} = Union{SimpleVarInfo{T},ThreadSafeVarInfo{<:SimpleVarInfo{T}}} +const SimpleOrThreadSafeSimple{T} = Union{ + SimpleVarInfo{T},ThreadSafeVarInfo{<:SimpleVarInfo{T}} +} # Context implementations function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple) diff --git a/test/contexts.jl b/test/contexts.jl index f3a1ae800..edf581d4d 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -222,7 +222,7 @@ end @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. @test getvalue_nested(context, vn_child) === - get(val, getlens(vn_child)) + get(val, getlens(vn_child)) end end end From d90a6cdd589d3ece6b2173aed3b11af78d220008 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 8 Nov 2021 04:11:40 +0000 Subject: [PATCH 142/216] fixed doctest for SimpleVarInfo --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d9201fcde..2cf33ec7f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -41,7 +41,7 @@ julia> # In the `NamedTuple` version we need to provide the place-holder values # the variablse which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi -SimpleVarInfo((x = [0.4471218424633827, 1.3736306979834252],), -4.024823883230379) +SimpleVarInfo((x = [0.4471218424633827, 1.3736306979834252], m = -0.6702516921145671), -4.024823883230379) julia> # (✓) Vroom, vroom! FAST!!! DynamicPPL.getval(vi, @varname(x[1])) From 082b2ef6c9490630b68a342e1806f6fde85546d1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 8 Nov 2021 04:14:26 +0000 Subject: [PATCH 143/216] formatting --- src/compiler.jl | 2 +- test/contexts.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6f4c6fc5f..37894a136 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -34,7 +34,7 @@ function isassumption(expr::Union{Symbol,Expr}) # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. if !($(DynamicPPL.inargnames)($vn, __model__)) || - $(DynamicPPL.inmissings)($vn, __model__) + $(DynamicPPL.inmissings)($vn, __model__) true else $(maybe_view(expr)) === missing diff --git a/test/contexts.jl b/test/contexts.jl index edf581d4d..f3a1ae800 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -222,7 +222,7 @@ end @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. @test getvalue_nested(context, vn_child) === - get(val, getlens(vn_child)) + get(val, getlens(vn_child)) end end end From 9ad362d77003b021dbaf69e067132032868d8fe2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 8 Nov 2021 12:57:12 +0000 Subject: [PATCH 144/216] removed comparison of show from doctest for SimpleVarInfo --- src/simple_varinfo.jl | 12 ++++-------- test/Project.toml | 2 -- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2cf33ec7f..b6b20db24 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -18,8 +18,6 @@ The major differences between this and `TypedVarInfo` are: ```jldoctest; setup=:(using Distributions) julia> using StableRNGs -julia> using OrderedCollections: OrderedDict # ensures consisent output - julia> @model function demo() m ~ Normal() x = Vector{Float64}(undef, 2) @@ -38,10 +36,9 @@ julia> ### Sampling ### ctx = SamplingContext(rng, SampleFromPrior(), DefaultContext()); julia> # In the `NamedTuple` version we need to provide the place-holder values for - # the variablse which are using "containers", e.g. `Array`. + # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); vi -SimpleVarInfo((x = [0.4471218424633827, 1.3736306979834252], m = -0.6702516921145671), -4.024823883230379) + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); julia> # (✓) Vroom, vroom! FAST!!! DynamicPPL.getval(vi, @varname(x[1])) @@ -63,9 +60,8 @@ julia> # (×) If we don't provide the container... ERROR: type NamedTuple has no field x [...] -julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(OrderedDict()), ctx); vi -SimpleVarInfo(OrderedCollections.OrderedDict{Any, Any}(m => 0.683947930996541, x[1] => -1.019202452456547, x[2] => -0.7935128416361353), -3.8249261202386906) +julia> # If one does not know the varnames, we can use a `Dict` instead. + _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); julia> # (✓) Sort of fast, but only possible at runtime. DynamicPPL.getval(vi, @varname(x[1])) diff --git a/test/Project.toml b/test/Project.toml index 7d9831fa6..6c0dc994c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -10,7 +10,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -30,7 +29,6 @@ Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" MCMCChains = "4.0.4, 5" MacroTools = "0.5.5" -OrderedCollections = "1" Setfield = "0.7.1, 0.8" StableRNGs = "1" Tracker = "0.2.11" From 71095961844288f7d5759b92dc857e7c54120915 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 8 Nov 2021 18:23:24 +0000 Subject: [PATCH 145/216] Update src/compiler.jl Co-authored-by: David Widmann --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 37894a136..fc83764cd 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -540,7 +540,7 @@ Suppose the following is the return-value: return x ~ Normal() ``` -Without `return_values`, once expanded in [`generated_mainbody!`](@ref), this would be +Without `return_values`, once expanded in `generated_mainbody!`, this would be ```julia return (x, __varinfo__ = tilde_assume!!(...)), __varinfo__ From d5345d9b2bc178e28c0c4a7c2dab27a928999e00 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 8 Nov 2021 23:23:45 +0000 Subject: [PATCH 146/216] Apply suggestions from code review Co-authored-by: David Widmann --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fc83764cd..fe7810b05 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -405,7 +405,7 @@ function generate_tilde_assume(left, right, vn) # with multiple arguments on the LHS, we need to capture the return-values # and then update them LHS variables one by one. tmp = gensym(:tmp) - expr = :($left = ($tmp)[1]) + expr = :($left = $first($tmp)) if left isa Expr expr = AbstractPPL.drop_escape( Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) @@ -419,7 +419,7 @@ function generate_tilde_assume(left, right, vn) __varinfo__, ) $expr - __varinfo__ = $tmp[2] + __varinfo__ = $last($tmp) $left, __varinfo__ end end From 42bf1c50171c979cf6495ede03b84a51e3592238 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 8 Nov 2021 23:24:10 +0000 Subject: [PATCH 147/216] removed OrderedCollections from docs --- docs/Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index e6d89dadd..83ce62d5e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,11 +1,9 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] Distributions = "0.25" Documenter = "0.27" -OrderedCollections = "1" StableRNGs = "1" From e14506fbb325ad156287d9b997664b4e61c31a48 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 9 Nov 2021 13:40:53 +0000 Subject: [PATCH 148/216] some additional fixes --- src/DynamicPPL.jl | 4 ++++ src/sampler.jl | 16 +++++++++++----- src/simple_varinfo.jl | 4 ++-- src/varinfo.jl | 6 ++++++ 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 3ca718cec..c739a859e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -177,4 +177,8 @@ include("test_utils.jl") @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) @deprecate resetlogp!(vi) resetlogp!!(vi) +@deprecate initialize_parameters!(vi, init_params, spl) initialize_parameters!!( + vi, init_params, spl +) + end # module diff --git a/src/sampler.jl b/src/sampler.jl index 664031233..ff1b588c7 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -116,7 +116,7 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). """ initialsampler(spl::Sampler) = SampleFromPrior() -function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler) +function initialize_parameters!!(vi::AbstractVarInfo, init_params, spl::Sampler) @debug "Using passed-in initial variable values" init_params # Flatten parameters. @@ -126,7 +126,10 @@ function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler) # Get all values. linked = islinked(vi, spl) - linked && invlink!(vi, spl) + if linked + # TODO: Make work with immutable `vi`. + invlink!(vi, spl) + end theta = vi[spl] length(theta) == length(init_theta) || error("Provided initial value doesn't match the dimension of the model") @@ -140,10 +143,13 @@ function initialize_parameters!(vi::AbstractVarInfo, init_params, spl::Sampler) end # Update in `vi`. - vi[spl] = theta - linked && link!(vi, spl) + vi = setindex!!(vi, theta, spl) + if linked + # TODO: Make work with immutable `vi`. + link!(vi, spl) + end - return nothing + return vi end """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b6b20db24..b18cc8b28 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -90,7 +90,7 @@ SimpleVarInfo() = SimpleVarInfo{Float64}() # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} - _, svi = DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...) + svi = last(DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...)) return svi end @@ -102,7 +102,7 @@ function SimpleVarInfo{T}( vi::VarInfo{<:NamedTuple{names}}, ::Type{D} ) where {T<:Real,names,D} values = values_as(vi, D) - return SimpleVarInfo{T}(values) + return SimpleVarInfo(values, convert(T, getlogp(vi))) end getlogp(vi::SimpleVarInfo) = vi.logp diff --git a/src/varinfo.jl b/src/varinfo.jl index e2e7dabcc..fe88f9c1a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -979,8 +979,14 @@ function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` ranges = _getranges(vi, spl) _setindex!(vi.metadata, val, ranges) + return nothing +end + +function setindex!!(vi::AbstractVarInfo, val, spl::AbstractSampler) + setindex!(vi, val, spl) return vi end + # Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. @generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} expr = Expr(:block) From 715ef892c50fd1fbaa31b33b0863d09006ff5f80 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 9 Nov 2021 15:05:39 +0000 Subject: [PATCH 149/216] fixed method ambiguity and some ill-defined map --- src/context_implementations.jl | 9 ++++++--- src/simple_varinfo.jl | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index c64d708d9..bda19dd17 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -380,7 +380,10 @@ end # `dot_assume` function dot_assume( - dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + vi::AbstractVarInfo, ) @assert length(dist) == size(var, 1) # NOTE: We cannot work with `var` here because we might have a model of the form @@ -390,7 +393,7 @@ function dot_assume( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. r = vi[vns] - lp = sum(zip(vns, eachcol(r))) do vn, ri + lp = sum(zip(vns, eachcol(r))) do (vn, ri) return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end return r, lp, vi @@ -402,7 +405,7 @@ function dot_assume( dist::MultivariateDistribution, vns::AbstractVector{<:VarName}, var::AbstractMatrix, - vi, + vi::AbstractVarInfo, ) @assert length(dist) == size(var, 1) r = get_and_set_val!(rng, vi, vns, dist, spl) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index b18cc8b28..3f9a26ac0 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -233,7 +233,7 @@ function dot_assume( # # in which case `var` will have `undef` elements, even if `m` is present in `vi`. value = vi[vns] - lp = sum(zip(vns, eachcol(value))) do vn, val + lp = sum(zip(vns, eachcol(value))) do (vn, val) return Distributions.logpdf(dist, val) end return value, lp, vi From 2991db282f06f8734f5e8dc97b3e1d773051a2b5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 11 Nov 2021 01:25:00 +0000 Subject: [PATCH 150/216] renamed evaluate to evaluate!! --- src/model.jl | 56 ++++++++++++++++++++++--------------------- src/prob_macro.jl | 2 +- src/sampler.jl | 2 +- src/simple_varinfo.jl | 8 +++---- src/submodel_macro.jl | 4 ++-- test/model.jl | 4 ++-- test/threadsafe.jl | 12 +++++----- 7 files changed, 45 insertions(+), 43 deletions(-) diff --git a/src/model.jl b/src/model.jl index 7006978ff..606581d51 100644 --- a/src/model.jl +++ b/src/model.jl @@ -374,10 +374,10 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -(model::Model)(args...) = (first ∘ evaluate)(model, args...) +(model::Model)(args...) = (first ∘ evaluate!!)(model, args...) """ - evaluate(model::Model[, rng, varinfo, sampler, context]) + evaluate!!(model::Model[, rng, varinfo, sampler, context]) Sample from the `model` using the `sampler` with random number generator `rng` and the `context`, and store the sample and log joint probability in `varinfo`. @@ -387,57 +387,59 @@ Returns both the return-value of the original model, and the resulting varinfo. The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function evaluate(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) +function evaluate!!(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(model, varinfo, context) + return evaluate_threadunsafe!!(model, varinfo, context) else - return evaluate_threadsafe(model, varinfo, context) + return evaluate_threadsafe!!(model, varinfo, context) end end -function evaluate( +function evaluate!!( model::Model, rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return evaluate(model, varinfo, SamplingContext(rng, sampler, context)) + return evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)) end -evaluate(model::Model, context::AbstractContext) = evaluate(model, VarInfo(), context) +evaluate!!(model::Model, context::AbstractContext) = evaluate!!(model, VarInfo(), context) -function evaluate(model::Model, args...) - return evaluate(model, Random.GLOBAL_RNG, args...) +function evaluate!!(model::Model, args...) + return evaluate!!(model, Random.GLOBAL_RNG, args...) end # without VarInfo -function evaluate(model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args...) - return evaluate(model, rng, VarInfo(), sampler, args...) +function evaluate!!( + model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args... +) + return evaluate!!(model, rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function evaluate(model::Model, rng::Random.AbstractRNG, context::AbstractContext) - return evaluate(model, rng, VarInfo(), SampleFromPrior(), context) +function evaluate!!(model::Model, rng::Random.AbstractRNG, context::AbstractContext) + return evaluate!!(model, rng, VarInfo(), SampleFromPrior(), context) end """ - evaluate_threadunsafe(model, varinfo, context) + evaluate_threadunsafe!!(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. If the `model` makes use of Julia's multithreading this will lead to undefined behaviour. This method is not exposed and supposed to be used only internally in DynamicPPL. -See also: [`evaluate_threadsafe`](@ref) +See also: [`evaluate_threadsafe!!`](@ref) """ -function evaluate_threadunsafe(model, varinfo, context) +function evaluate_threadunsafe!!(model, varinfo, context) varinfo = resetlogp!!(varinfo) - return _evaluate(model, varinfo, context) + return _evaluate!!(model, varinfo, context) end """ - evaluate_threadsafe(model, varinfo, context) + evaluate_threadsafe!!(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -445,22 +447,22 @@ With the wrapper, Julia's multithreading can be used for observe statements in t but parallel sampling will lead to undefined behaviour. This method is not exposed and supposed to be used only internally in DynamicPPL. -See also: [`evaluate_threadunsafe`](@ref) +See also: [`evaluate_threadunsafe!!`](@ref) """ -function evaluate_threadsafe(model, varinfo, context) +function evaluate_threadsafe!!(model, varinfo, context) varinfo = resetlogp!!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result, wrapper_new = _evaluate(model, wrapper, context) + result, wrapper_new = _evaluate!!(model, wrapper, context) varinfo = setlogp!!(varinfo, getlogp(wrapper_new)) return result, varinfo end """ - _evaluate(model::Model, varinfo, context) + _evaluate!!(model::Model, varinfo, context) Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ -@generated function _evaluate( +@generated function _evaluate!!( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [ @@ -510,7 +512,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - _, varinfo_new = evaluate(model, varinfo, DefaultContext()) + _, varinfo_new = evaluate!!(model, varinfo, DefaultContext()) return getlogp(varinfo_new) end @@ -522,7 +524,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - _, varinfo_new = evaluate(model, varinfo, PriorContext()) + _, varinfo_new = evaluate!!(model, varinfo, PriorContext()) return getlogp(varinfo_new) end @@ -534,7 +536,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - _, varinfo_new = evaluate(model, varinfo, LikelihoodContext()) + _, varinfo_new = evaluate!!(model, varinfo, LikelihoodContext()) return getlogp(varinfo_new) end diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 84497aef0..1eca69d43 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -146,7 +146,7 @@ function logprior( foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end - _, vi = DynamicPPL.evaluate(model, vi, SampleFromPrior(), PriorContext(left)) + _, vi = DynamicPPL.evaluate!!(model, vi, SampleFromPrior(), PriorContext(left)) return getlogp(vi) end diff --git a/src/sampler.jl b/src/sampler.jl index ff1b588c7..bb3ab7633 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -82,7 +82,7 @@ function AbstractMCMC.step( # Update the parameters if provided. if haskey(kwargs, :init_params) - initialize_parameters!(vi, kwargs[:init_params], spl) + vi = initialize_parameters!!(vi, kwargs[:init_params], spl) # Update joint log probability. # TODO: fix properly by using sampler and evaluation contexts diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 3f9a26ac0..928786f53 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -38,7 +38,7 @@ julia> ### Sampling ### julia> # In the `NamedTuple` version we need to provide the place-holder values for # the variables which are using "containers", e.g. `Array`. # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.evaluate(m, SimpleVarInfo((x = ones(2), )), ctx); + _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo((x = ones(2), )), ctx); julia> # (✓) Vroom, vroom! FAST!!! DynamicPPL.getval(vi, @varname(x[1])) @@ -56,12 +56,12 @@ julia> DynamicPPL.getval(vi, @varname(x[1:2])) 1.3736306979834252 julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.evaluate(m, SimpleVarInfo(), ctx); vi + _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); vi ERROR: type NamedTuple has no field x [...] julia> # If one does not know the varnames, we can use a `Dict` instead. - _, vi = DynamicPPL.evaluate(m, SimpleVarInfo{Float64}(Dict()), ctx); + _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict()), ctx); julia> # (✓) Sort of fast, but only possible at runtime. DynamicPPL.getval(vi, @varname(x[1])) @@ -90,7 +90,7 @@ SimpleVarInfo() = SimpleVarInfo{Float64}() # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} - svi = last(DynamicPPL.evaluate(model, SimpleVarInfo{T}(), args...)) + svi = last(DynamicPPL.evaluate!!(model, SimpleVarInfo{T}(), args...)) return svi end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index cd93131e3..a399fc4c9 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -120,7 +120,7 @@ function submodel(expr, ctx=esc(:__context__)) return if args_assign === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate( + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate!!( $(esc(expr)), $(esc(:__varinfo__)), $(ctx) ) end @@ -129,7 +129,7 @@ function submodel(expr, ctx=esc(:__context__)) # TODO: Should we prefix by `L` by default? L, R = args_assign quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate( + $(esc(L)), $(esc(:__varinfo__)) = _evaluate!!( $(esc(R)), $(esc(:__varinfo__)), $(ctx) ) end diff --git a/test/model.jl b/test/model.jl index 111fa4645..466a7d1f4 100644 --- a/test/model.jl +++ b/test/model.jl @@ -72,9 +72,9 @@ # sample from model and extract variables vi = VarInfo(model) - # Second component of return-value of `evaluate` should + # Second component of return-value of `evaluate!!` should # be a `DynamicPPL.AbstractVarInfo`. - evaluate_retval = DynamicPPL.evaluate(model, vi, DefaultContext()) + evaluate_retval = DynamicPPL.evaluate!!(model, vi, DefaultContext()) @test evaluate_retval[2] isa DynamicPPL.AbstractVarInfo # Should not return `AbstractVarInfo` when we call the model. diff --git a/test/threadsafe.jl b/test/threadsafe.jl index bd1f4f154..460d68ca3 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -60,7 +60,7 @@ @time wthreads(x)(vi) # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. - DynamicPPL.evaluate_threadsafe( + DynamicPPL.evaluate_threadsafe!!( wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), @@ -68,8 +68,8 @@ @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo - println(" evaluate_threadsafe:") - @time DynamicPPL.evaluate_threadsafe( + println(" evaluate_threadsafe!!:") + @time DynamicPPL.evaluate_threadsafe!!( wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), @@ -99,7 +99,7 @@ @test lp_w_threads ≈ lp_wo_threads # Ensure that we use `VarInfo`. - DynamicPPL.evaluate_threadunsafe( + DynamicPPL.evaluate_threadunsafe!!( wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), @@ -107,8 +107,8 @@ @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo - println(" evaluate_threadunsafe:") - @time DynamicPPL.evaluate_threadunsafe( + println(" evaluate_threadunsafe!!:") + @time DynamicPPL.evaluate_threadunsafe!!( wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), From 2fbbd5e19b07d2da51d23776d82cdee099050c30 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 14:44:18 +0000 Subject: [PATCH 151/216] added implementations of haskey, getindex and setindex!! for SimpleVarInfo --- src/simple_varinfo.jl | 161 +++++++++++++++++++++++++++++++++--------- src/utils.jl | 150 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 279 insertions(+), 32 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 928786f53..1b1458494 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -15,6 +15,7 @@ The major differences between this and `TypedVarInfo` are: b) the values have been specified with the corret shapes. # Examples +## General usage ```jldoctest; setup=:(using Distributions) julia> using StableRNGs @@ -76,16 +77,67 @@ julia> julia> DynamicPPL.getval(vi, @varname(x[1:2])) ERROR: KeyError: key x[1:2] not found [...] ``` + +## Indexing +Using `NamedTuple` as underlying storage. + +```jldoctest +julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )) +SimpleVarInfo((m = (a = [1.0],),), 0.0) + +julia> svi_nt[@varname(m)] +(a = [1.0],) + +julia> svi_nt[@varname(m.a)] +1-element Vector{Float64}: + 1.0 + +julia> svi_nt[@varname(m.a[1])] +1.0 + +julia> svi_nt[@varname(m.a[2])] +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] + +julia> svi_nt[@varname(m.b)] +ERROR: type NamedTuple has no field b +[...] +``` + +Using `Dict` as underlying storage. +```jldoctest +julia> svi_dict = SimpleVarInfo(Dict(@varname(m) => (a = [1.0], ))) +SimpleVarInfo(Dict{VarName{:m, Setfield.IdentityLens}, NamedTuple{(:a,), Tuple{Vector{Float64}}}}(m => (a = [1.0],)), 0.0) + +julia> svi_dict[@varname(m)] +(a = [1.0],) + +julia> svi_dict[@varname(m.a)] +1-element Vector{Float64}: + 1.0 + +julia> svi_dict[@varname(m.a[1])] +1.0 + +julia> svi_dict[@varname(m.a[2])] +ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] +[...] + +julia> svi_dict[@varname(m.b)] +ERROR: type NamedTuple has no field b +[...] +``` """ struct SimpleVarInfo{NT,T} <: AbstractVarInfo - θ::NT + values::NT logp::T end SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) -SimpleVarInfo(θ) = SimpleVarInfo{eltype(first(θ))}(θ) -SimpleVarInfo{T}() where {T<:Real} = SimpleVarInfo{T}(NamedTuple()) -SimpleVarInfo() = SimpleVarInfo{Float64}() +SimpleVarInfo{T}(; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs)) +SimpleVarInfo(; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs)) +SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) +SimpleVarInfo(θ::NamedTuple) = SimpleVarInfo{Float64}(θ) # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) @@ -106,8 +158,18 @@ function SimpleVarInfo{T}( end getlogp(vi::SimpleVarInfo) = vi.logp -setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, logp) -acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.θ, getlogp(vi) + logp) +setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, logp) +acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp) + +""" + keys(vi::SimpleVarInfo) + +Return an iterator of keys present in `vi`. +""" +Base.keys(vi::SimpleVarInfo) = keys(vi.values) +# TODO: Is this really the "right" thing to do? +# Is there a better function name we can use? +Base.values(vi::SimpleVarInfo) = vi.values function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp @@ -121,42 +183,77 @@ end function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) print(io, "SimpleVarInfo(") - print(io, svi.θ) + print(io, svi.values) print(io, ", ") print(io, svi.logp) return print(io, ")") end # `NamedTuple` -function getval(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName) - return get(vi.θ, vn) +function getindex(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName) + return get(vi.values, vn) end # `Dict` -function getval(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName{sym}) where {sym} - # TODO: Should we maybe allow indexing of sub-keys, etc. too? E.g. - # if `x` is present and it has an array, maybe we should allow indexing `x[1]`, etc. - return vi.θ[vn] +function getindex(vi::SimpleVarInfo, vn::VarName) + if haskey(vi.values, vn) + return vi.values[vn] + end + + # Split the lens into the key / `parent` and the + # extraction lens / `child`. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(vi.values, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + # If we found a valid split, then we can extract the value. + # TODO: Should we also check that we `canview` the extracted `value`? + if issuccess + value = vi.values[VarName(vn, keylens)] + return get(value, child) + end + + # At this point we just throw an error since the key could not be found. + throw(KeyError(vn)) end # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than # just `Vector`. -getval(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(vn -> getval(vi, vn), vns) -# To disambiguiate. -getval(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getval(vi, vn), vns) +function getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) + return map(vn -> getindex(vi, vn), vns) +end +# HACK: Needed to disambiguiate. +getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getindex(vi, vn), vns) -haskey(vi::SimpleVarInfo, vn) = haskey(vi.θ, getsym(vn)) +getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values +getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values +# TODO: Should we do better? +getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values -istrans(::SimpleVarInfo, vn::VarName) = false +haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) -getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.θ -getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.θ -# TODO: Should we do better? -getindex(vi::SimpleVarInfo, spl::Sampler) = vi.θ -getindex(vi::SimpleVarInfo, vn::VarName) = getval(vi, vn) -getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = getval(vi, vns) -# HACK: Need to disambiguiate. -getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = getval(vi, vns) +# TODO: Is `hasvalue` really the right function here? +function hasvalue(nt::NamedTuple, vn::VarName) + # LHS: Ensure that `nt` indeed has the property we want. + # RHS: Ensure that the lens can view into `nt`. + sym = getsym(vn) + return haskey(nt, sym) && canview(getlens(vn), getindex(nt, sym)) +end + +hasvalue(dictlike, vn::VarName) = haskey(dictlike, vn) || hasvalue(dictlike, parent(vn)) +hasvalue(dictlike, vn::VarName{<:Any, Setfield.IdentityLens}) = haskey(dictlike, vn) + +function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, val, vn::VarName) + return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) +end +function setindex!!(vi::SimpleVarInfo, val, vn::VarName) + return SimpleVarInfo(setindex!!(vi.values, val, vn), vi.logp) +end + +istrans(::SimpleVarInfo, vn::VarName) = false # Necessary for `matchingvalue` to work properly. function Base.eltype( @@ -173,7 +270,7 @@ function push!!( dist::Distribution, gidset::Set{Selector}, ) where {sym} - return Setfield.@set vi.θ = merge(vi.θ, NamedTuple{(sym,)}((value,))) + return Setfield.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) end function push!!( vi::SimpleVarInfo{<:NamedTuple}, @@ -182,7 +279,7 @@ function push!!( dist::Distribution, gidset::Set{Selector}, ) where {sym} - return Setfield.@set vi.θ = set!!(vi.θ, vn, value) + return Setfield.@set vi.values = set!!(vi.values, vn, value) end # `Dict` @@ -193,7 +290,7 @@ function push!!( dist::Distribution, gidset::Set{Selector}, ) - vi.θ[vn] = r + vi.values[vn] = r return vi end @@ -260,6 +357,6 @@ end increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing -values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.θ)) -values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.θ)) -values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.θ +values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) +values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) +values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values diff --git a/src/utils.jl b/src/utils.jl index 527c1220a..af5af0410 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -163,3 +163,153 @@ function set!!(obj, vn::VarName{sym}, value) where {sym} lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) return Setfield.set(obj, lens, value) end + +function set!!(obj, vn::VarName{sym}, value) where {sym} + lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) + return Setfield.set(obj, lens, value) +end + +############################# +# AbstractPPL.jl extensions # +############################# +# This is preferable to `haskey` because the order of arguments is different, and +# we're more likely to specialize on the key in these settings rather than the container. +# TODO: I'm not sure about this name. +""" + canview(lens, container) + +Return `true` if `lens` can be used to view `container`, and `false` otherwise. + +# Examples +```jldoctest; setup=:(using Setfield) +julia> canview(@lens(_.a), (a = 1.0, )) +true + +julia> canview(@lens(_.a), (b = 1.0, )) # property `a` does not exist +false + +julia> canview(@lens(_.a[1]), (a = [1.0, 2.0], )) +true + +julia> canview(@lens(_.a[3]), (a = [1.0, 2.0], )) # out of bounds +false +``` +""" +canview(lens, container) = false +canview(::Setfield.IdentityLens, _) = true +function canview(lens::Setfield.PropertyLens{field}, x) where {field} + return haskey(x, field) +end +# `IndexLens`: only relevant if `x` supports indexing. +canview(lens::Setfield.IndexLens, x) = false +canview(lens::Setfield.IndexLens, x::AbstractArray) = checkbounds(Bool, x, lens.indices...) + +# `ComposedLens`: check that we can view `.outer` and `.inner`, but using +# value extracted using `.outer`. +function canview(lens::Setfield.ComposedLens, x) + return canview(lens.outer, x) && canview(lens.inner, get(x, lens.outer)) +end + +""" + parent(vn::VarName) + +Return the parent `VarName`. + +# Examples +```julia-repl +julia> parent(@varname(x.a[1])) +x.a + +julia> (parent ∘ parent)(@varname(x.a[1])) +x + +julia> (parent ∘ parent ∘ parent)(@varname(x.a[1])) +x +``` +""" +function parent(vn::VarName) + p = parent(getlens(vn)) + return p === nothing ? VarName(vn, Setfield.IdentityLens()) : VarName(vn, p) +end + +""" + parent(lens::Setfield.Lens) + +Return the parent lens. + +See also: [`parent_and_child`]. + +# Examples +```jldoctest; setup=:(using Setfield; using DynamicPPL: parent) +julia> parent(@lens(_.a[1])) +(@lens _.a) + +julia> (parent ∘ parent)(@lens(_.a[1])) +(@lens _) + +julia> # parent of `IdentityLens` is `IdentityLens` + (parent ∘ parent ∘ parent)(@lens(_.a[1])) +(@lens _) +``` +""" +parent(lens::Setfield.Lens) = first(parent_and_child(lens)) + +""" + parent(lens::Setfield.Lens) + +Return a 2-tuple of lenses `(parent, child)` where + +See also: [`parent`]. + +# Examples +```jldoctest; setup=:(using Setfield; using DynamicPPL: parent_and_child) +julia> parent_and_child(@lens(_.a[1])) +((@lens _.a), (@lens _[1])) + +julia> parent_and_child(@lens(_.a)) +(nothing, (@lens _.a)) +``` +""" +parent_and_child(lens::Setfield.Lens) = (nothing, lens) +function parent_and_child(lens::Setfield.ComposedLens) + p, child = parent_and_child(lens.inner) + parent = p === nothing ? lens.outer : lens.outer ∘ p + return parent, child +end + +""" + splitlens(condition, lens) + +Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, +`parent` is a lens such that `condition(parent)` is `true` and `parent ∘ child == lens`. + +If `issuccess` is `false`, then no such split could be found. + +# Examples +```jldoctest; setup=:(using Setfield; using DynamicPPL: splitlens) +julia> p, c, issucesss = splitlens(@lens(_.a[1])) do parent + # Succeeds! + parent == @lens(_.a) + end +((@lens _.a), (@lens _[1]), true) + +julia> p ∘ c +(@lens _.a[1]) + +julia> splitlens(@lens(_.a[1])) do parent + # Fails! + parent == @lens(_.b) + end +(nothing, (@lens _.a[1]), false) +``` +""" +function splitlens(condition, lens) + current_parent, current_child = parent_and_child(lens) + # We stop if either a) `condition` is satisfied, or b) we reached the root. + while !condition(current_parent) && current_parent !== nothing + current_parent, c = parent_and_child(current_parent) + current_child = c ∘ current_child + end + + return current_parent, current_child, condition(current_parent) +end From 016f4855fd79deb547cb7151ad816d8f0171c543 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 14:45:48 +0000 Subject: [PATCH 152/216] formatting --- src/simple_varinfo.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 1b1458494..cc4a76b27 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -211,13 +211,13 @@ function getindex(vi::SimpleVarInfo, vn::VarName) # If we found a valid split, then we can extract the value. # TODO: Should we also check that we `canview` the extracted `value`? - if issuccess - value = vi.values[VarName(vn, keylens)] - return get(value, child) + if !issuccess + # At this point we just throw an error since the key could not be found. + throw(KeyError(vn)) end - # At this point we just throw an error since the key could not be found. - throw(KeyError(vn)) + value = vi.values[VarName(vn, keylens)] + return get(value, child) end # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than @@ -244,7 +244,7 @@ function hasvalue(nt::NamedTuple, vn::VarName) end hasvalue(dictlike, vn::VarName) = haskey(dictlike, vn) || hasvalue(dictlike, parent(vn)) -hasvalue(dictlike, vn::VarName{<:Any, Setfield.IdentityLens}) = haskey(dictlike, vn) +hasvalue(dictlike, vn::VarName{<:Any,Setfield.IdentityLens}) = haskey(dictlike, vn) function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, val, vn::VarName) return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) From 7a8d9a12d1dca1549a694b31ba23579d7d302be8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 14:49:22 +0000 Subject: [PATCH 153/216] dropped redundant definition --- src/simple_varinfo.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index cc4a76b27..0fa245f23 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -137,7 +137,6 @@ SimpleVarInfo{T}(θ) where {T<:Real} = SimpleVarInfo{typeof(θ),T}(θ, zero(T)) SimpleVarInfo{T}(; kwargs...) where {T<:Real} = SimpleVarInfo{T}(NamedTuple(kwargs)) SimpleVarInfo(; kwargs...) = SimpleVarInfo{Float64}(NamedTuple(kwargs)) SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) -SimpleVarInfo(θ::NamedTuple) = SimpleVarInfo{Float64}(θ) # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) From b76b1d190feedcbb58322beefa95a5e56ee636f2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 14:50:36 +0000 Subject: [PATCH 154/216] use getproperty instead of getindex --- src/simple_varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 0fa245f23..2e2033e8d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -239,7 +239,7 @@ function hasvalue(nt::NamedTuple, vn::VarName) # LHS: Ensure that `nt` indeed has the property we want. # RHS: Ensure that the lens can view into `nt`. sym = getsym(vn) - return haskey(nt, sym) && canview(getlens(vn), getindex(nt, sym)) + return haskey(nt, sym) && canview(getlens(vn), getproperty(nt, sym)) end hasvalue(dictlike, vn::VarName) = haskey(dictlike, vn) || hasvalue(dictlike, parent(vn)) From bbdabd1f8bf97290eaab80b6dcc6eb627c4a0bf4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 14:58:49 +0000 Subject: [PATCH 155/216] fixed method-ambiguity and added some comments --- src/simple_varinfo.jl | 13 +++++++++++-- src/utils.jl | 5 ----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 2e2033e8d..d9530e9aa 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -242,13 +242,17 @@ function hasvalue(nt::NamedTuple, vn::VarName) return haskey(nt, sym) && canview(getlens(vn), getproperty(nt, sym)) end -hasvalue(dictlike, vn::VarName) = haskey(dictlike, vn) || hasvalue(dictlike, parent(vn)) -hasvalue(dictlike, vn::VarName{<:Any,Setfield.IdentityLens}) = haskey(dictlike, vn) +# For `dictlike` we need to check wether `vn` is "immediately" present, or +# if some ancestor of `vn` is present in `dictlike`. +hasvalue(dict::AbstractDict, vn::VarName) = haskey(dict, vn) || hasvalue(dict, parent(vn)) +hasvalue(dict::AbstractDict, vn::VarName{<:Any,Setfield.IdentityLens}) = haskey(dict, vn) function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, val, vn::VarName) + # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) end function setindex!!(vi::SimpleVarInfo, val, vn::VarName) + # For dictlike objects, we treat the entire `vn` as a _key_ to set. return SimpleVarInfo(setindex!!(vi.values, val, vn), vi.logp) end @@ -356,6 +360,11 @@ end increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing +""" + values(varinfo, Type) + +Return the values/realizations in `varinfo` as `Type`, if implemented. +""" values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values diff --git a/src/utils.jl b/src/utils.jl index af5af0410..7abf153aa 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -164,11 +164,6 @@ function set!!(obj, vn::VarName{sym}, value) where {sym} return Setfield.set(obj, lens, value) end -function set!!(obj, vn::VarName{sym}, value) where {sym} - lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) - return Setfield.set(obj, lens, value) -end - ############################# # AbstractPPL.jl extensions # ############################# From 5a3ada521c9d666188e29d3fa190a09453288e9d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 15:04:22 +0000 Subject: [PATCH 156/216] fixed docstring of SimpleVarInfo --- src/simple_varinfo.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d9530e9aa..320e19bf2 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -42,16 +42,16 @@ julia> # In the `NamedTuple` version we need to provide the place-holder values _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo((x = ones(2), )), ctx); julia> # (✓) Vroom, vroom! FAST!!! - DynamicPPL.getval(vi, @varname(x[1])) + vi[@varname(x[1])] 0.4471218424633827 julia> # We can also access arbitrary varnames pointing to `x`, e.g. - DynamicPPL.getval(vi, @varname(x)) + vi[@varname(x)] 2-element Vector{Float64}: 0.4471218424633827 1.3736306979834252 -julia> DynamicPPL.getval(vi, @varname(x[1:2])) +julia> vi[@varname(x[1:2])] 2-element Vector{Float64}: 0.4471218424633827 1.3736306979834252 @@ -65,15 +65,15 @@ julia> # If one does not know the varnames, we can use a `Dict` instead. _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict()), ctx); julia> # (✓) Sort of fast, but only possible at runtime. - DynamicPPL.getval(vi, @varname(x[1])) + vi[@varname(x[1])] -1.019202452456547 julia> # In addtion, we can only access varnames as they appear in the model! - DynamicPPL.getval(vi, @varname(x)) + vi[@varname(x)] ERROR: KeyError: key x not found [...] -julia> julia> DynamicPPL.getval(vi, @varname(x[1:2])) +julia> vi[@varname(x[1:2])] ERROR: KeyError: key x[1:2] not found [...] ``` From ee6c11198cbbcb37698e2058c915f64ce12e6cf7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 15:45:19 +0000 Subject: [PATCH 157/216] fixed docstrings --- docs/Project.toml | 4 +++- src/utils.jl | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 83ce62d5e..d53b58243 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,9 +1,11 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] Distributions = "0.25" Documenter = "0.27" -StableRNGs = "1" +Setfield = "0.7.1, 0.8" +StableRNGs = "1" \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 7abf153aa..9e65f3684 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -176,7 +176,9 @@ end Return `true` if `lens` can be used to view `container`, and `false` otherwise. # Examples -```jldoctest; setup=:(using Setfield) +```jldoctest +julia> using Setfield + julia> canview(@lens(_.a), (a = 1.0, )) true From 7fa379d17a89f9f2b7a43cad4cf312852e6816fa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 15:55:29 +0000 Subject: [PATCH 158/216] fixed Project.toml for docs --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index d53b58243..aa1315f41 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,4 +8,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Distributions = "0.25" Documenter = "0.27" Setfield = "0.7.1, 0.8" -StableRNGs = "1" \ No newline at end of file +StableRNGs = "1" From 0ea28b78a58c4bcc6d85a19780e26de69391b813 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 16:03:58 +0000 Subject: [PATCH 159/216] fixed docstring of canview --- src/utils.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 9e65f3684..bcbb5df93 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -176,9 +176,7 @@ end Return `true` if `lens` can be used to view `container`, and `false` otherwise. # Examples -```jldoctest -julia> using Setfield - +```jldoctest; setup=:(using Setfield; using DynamicPPL: canview) julia> canview(@lens(_.a), (a = 1.0, )) true From 59d61f627fd7127e59c88e6872a8877d144dafa8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 18:40:51 +0000 Subject: [PATCH 160/216] fixed docstrings --- src/utils.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index bcbb5df93..df615479b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -211,7 +211,7 @@ end Return the parent `VarName`. # Examples -```julia-repl +```julia-repl; setup=:(using DynamicPPL: parent) julia> parent(@varname(x.a[1])) x.a @@ -230,7 +230,8 @@ end """ parent(lens::Setfield.Lens) -Return the parent lens. +Return the parent lens. If `lens` doesn't have a parent, +`nothing` is returned. See also: [`parent_and_child`]. @@ -239,20 +240,19 @@ See also: [`parent_and_child`]. julia> parent(@lens(_.a[1])) (@lens _.a) -julia> (parent ∘ parent)(@lens(_.a[1])) -(@lens _) - -julia> # parent of `IdentityLens` is `IdentityLens` - (parent ∘ parent ∘ parent)(@lens(_.a[1])) -(@lens _) +julia> # Parent of lens without parents results in `nothing`. + (parent ∘ parent)(@lens(_.a[1])) === nothing ``` """ parent(lens::Setfield.Lens) = first(parent_and_child(lens)) """ - parent(lens::Setfield.Lens) + parent_and_child(lens::Setfield.Lens) + +Return a 2-tuple of lenses `(parent, child)` where `parent` is the +parent lens of `lens` and `child` is the child lens of `lens`. -Return a 2-tuple of lenses `(parent, child)` where +If `lens` does not have a parent, we return `(nothing, lens)`. See also: [`parent`]. From ab06c9cdf3088f0774b00e13ca392668c857d5e5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 12 Nov 2021 18:43:53 +0000 Subject: [PATCH 161/216] another attempt at fixing docstrings --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index df615479b..9bf52efc1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -242,6 +242,7 @@ julia> parent(@lens(_.a[1])) julia> # Parent of lens without parents results in `nothing`. (parent ∘ parent)(@lens(_.a[1])) === nothing +true ``` """ parent(lens::Setfield.Lens) = first(parent_and_child(lens)) From b5e687642d7306f49ffe4b6124942696bd851953 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 00:01:24 +0000 Subject: [PATCH 162/216] added a TODO comment --- src/context_implementations.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index bda19dd17..ce2a61a67 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -508,6 +508,11 @@ function get_and_set_val!( else f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) + # TODO: This will inefficient since it will allocate an entire vector. + # We could either: + # 1. Figure out the broadcast size and use a `foreach`. + # 2. Define a anonynous function which returns `nothing`, which + # we then broadcast. This will allocate a vector of `nothing` though. push!!.(Ref(vi), vns, r, dists, Ref(spl)) settrans!.(Ref(vi), false, vns) end From a36c085221c178edf531a0c6beabc26d9612d533 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 00:05:22 +0000 Subject: [PATCH 163/216] remove some output from docstring of SimpleVarInfo --- src/simple_varinfo.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 320e19bf2..3a84e83be 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -82,8 +82,7 @@ ERROR: KeyError: key x[1:2] not found Using `NamedTuple` as underlying storage. ```jldoctest -julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )) -SimpleVarInfo((m = (a = [1.0],),), 0.0) +julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )); julia> svi_nt[@varname(m)] (a = [1.0],) @@ -106,8 +105,7 @@ ERROR: type NamedTuple has no field b Using `Dict` as underlying storage. ```jldoctest -julia> svi_dict = SimpleVarInfo(Dict(@varname(m) => (a = [1.0], ))) -SimpleVarInfo(Dict{VarName{:m, Setfield.IdentityLens}, NamedTuple{(:a,), Tuple{Vector{Float64}}}}(m => (a = [1.0],)), 0.0) +julia> svi_dict = SimpleVarInfo(Dict(@varname(m) => (a = [1.0], ))); julia> svi_dict[@varname(m)] (a = [1.0],) From 74a978117c0dc2a251a4f68f41b9b57af425a217 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 01:06:37 +0000 Subject: [PATCH 164/216] fixed haskey and hasvalue for AbstractDict --- src/simple_varinfo.jl | 41 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 3a84e83be..4a347c96f 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -242,18 +242,55 @@ end # For `dictlike` we need to check wether `vn` is "immediately" present, or # if some ancestor of `vn` is present in `dictlike`. -hasvalue(dict::AbstractDict, vn::VarName) = haskey(dict, vn) || hasvalue(dict, parent(vn)) -hasvalue(dict::AbstractDict, vn::VarName{<:Any,Setfield.IdentityLens}) = haskey(dict, vn) +function hasvalue(dict::AbstractDict, vn::VarName) + # First we check if `vn` is present as is. + haskey(dict, vn) && return true + + # If `vn` is not present, we check any parent-varnames by attempting + # to split the lens into the key / `parent` and the extraction lens / `child`. + # If `issuccess` is `true`, we found such a split, and hence `vn` is present. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(dict, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + # Return early if no such split could be found. + issuccess || return false + + # At this point we just need to check that we `canview` the value. + value = dict[VarName(vn, keylens)] + + return canview(getlens(vn), value) +end function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) end + +# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with +# same symbol and same type of, say, `IndexLens`, for improvemed `.~` performance. +function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, vals, vns::AbstractVector{<:VarName}) + for (vn, val) in zip(vns, vals) + vi = setindex!!(vi, val, vn) + end + return vi +end + function setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For dictlike objects, we treat the entire `vn` as a _key_ to set. return SimpleVarInfo(setindex!!(vi.values, val, vn), vi.logp) end +function setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) + for (vn, val) in zip(vns, vals) + vi = setindex!!(vi, val, vn) + end + return vi +end + istrans(::SimpleVarInfo, vn::VarName) = false # Necessary for `matchingvalue` to work properly. From dc28ceb14109335b2b08d1f9fc78ec4ec97b4ebb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 01:06:58 +0000 Subject: [PATCH 165/216] updated some comments --- src/simple_varinfo.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 4a347c96f..d2591998c 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -197,8 +197,7 @@ function getindex(vi::SimpleVarInfo, vn::VarName) return vi.values[vn] end - # Split the lens into the key / `parent` and the - # extraction lens / `child`. + # Split the lens into the key / `parent` and the extraction lens / `child`. parent, child, issuccess = splitlens(getlens(vn)) do lens l = lens === nothing ? Setfield.IdentityLens() : lens haskey(vi.values, VarName(vn, l)) @@ -207,12 +206,13 @@ function getindex(vi::SimpleVarInfo, vn::VarName) keylens = parent === nothing ? Setfield.IdentityLens() : parent # If we found a valid split, then we can extract the value. - # TODO: Should we also check that we `canview` the extracted `value`? if !issuccess # At this point we just throw an error since the key could not be found. throw(KeyError(vn)) end + # TODO: Should we also check that we `canview` the extracted `value` + # rather than just let it fail upon `get` call? value = vi.values[VarName(vn, keylens)] return get(value, child) end From 4c12498cfc5be36aef86f1f3dee17d4902e6be26 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 01:07:46 +0000 Subject: [PATCH 166/216] updated some errors --- src/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index fe88f9c1a..bd4cbd7db 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1140,9 +1140,9 @@ function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Se end function push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" end val = vectorize(dist, r) From 37a1e2e69b9edbe06b09dfff664b5ad9f8bf6502 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 01:08:05 +0000 Subject: [PATCH 167/216] added sampling dot_assume for SimpleVarInfo --- src/simple_varinfo.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index d2591998c..3a68c69df 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -391,6 +391,21 @@ function dot_assume( return value, lp, vi end +function dot_assume( + rng, + spl::Union{SampleFromPrior,SampleFromUniform}, + dists::Union{Distribution,AbstractArray{<:Distribution}}, + vns::AbstractArray{<:VarName}, + var::AbstractArray, + vi::SimpleOrThreadSafeSimple, +) + f = (vn, dist) -> init(rng, dist, spl) + value = f.(vns, dists) + vi = setindex!!(vi, value, vns) + lp = sum(Distributions.logpdf.(dists, value)) + return value, lp, vi +end + # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing From 41c2d80c90e1d11e4d842532da567f708d2b2937 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 01:08:26 +0000 Subject: [PATCH 168/216] added true versions of density computations to TestUtils --- src/test_utils.jl | 115 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 113 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index db2193a73..e9b0e6a7d 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,6 +6,49 @@ using LinearAlgebra using Distributions using Test +""" + logprior_true(model, θ) + +Return the `logprior` of `model` for `θ`. + +This should generally be implemented by hand for every specific `model`. + +See also: [`logjoint_true`](@ref), [`loglikelihood_true`](@ref). +""" +function logprior_true end + +""" + loglikelihood_true(model, θ) + +Return the `loglikelihood` of `model` for `θ`. + +This should generally be implemented by hand for every specific `model`. + +See also: [`logjoint_true`](@ref), [`logprior_true`](@ref). +""" +function loglikelihood_true end + +""" + logjoint_true(model, θ) + +Return the `logjoint` of `model` for `θ`. + +Defaults to `logprior_true(model, θ) + loglikelihood_true(model, θ)`. + +This should generally be implemented by hand for every specific `model` +so that the returned value can be used as a ground-truth for testing things like: + +1. Validity of evaluation of `model` using a particular implementation of `AbstractVarInfo`. +2. Validity of a sampler when combined with DynamicPPL by running the sampler twice: once targeting ground-truth functions, e.g. `logjoint_true`, and once targeting `model`. + +And more. + +See also: [`logprior_true`](@ref), [`loglikelihood_true`](@ref). +""" +function logjoint_true(model::Model, args...) + return logprior_true(model, args...) + loglikelihood_true(model, args...) +end + # A collection of models for which the mean-of-means for the posterior should # be same. @model function demo_dot_assume_dot_observe( @@ -17,6 +60,12 @@ using Test x ~ MvNormal(m, 0.25 * I) return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) + return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +end @model function demo_assume_index_observe( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -30,14 +79,26 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_index_observe)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) + return logpdf(MvNormal(m, 0.25 * I), model.args.x) +end -@model function demo_assume_multivariate_observe_index(x=[10.0, 10.0]) +@model function demo_assume_multivariate_observe(x=[10.0, 10.0]) # Multivariate `assume` and `observe` m ~ MvNormal(zero(x), I) x ~ MvNormal(m, 0.25 * I) return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) + return logpdf(MvNormal(zero(model.args.x), I), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) + return logpdf(MvNormal(m, 0.25 * I), model.args.x) +end @model function demo_dot_assume_observe_index( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -51,6 +112,12 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) + return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. @@ -61,6 +128,12 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, m) + return logpdf(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) + return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -69,6 +142,12 @@ end return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, m) + return logpdf(MvNormal(zeros(2), I), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) + return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) +end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -80,6 +159,12 @@ end return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) + return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) +end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -88,6 +173,12 @@ end return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) + return logpdf(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) + return logpdf(Normal(m, 0.5), 10.0) +end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) @@ -105,6 +196,14 @@ end return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m +) + return sum(logpdf.(Normal.(m, 0.5), 10.0)) +end @model function _likelihood_dot_observe(m, x) return x ~ MvNormal(m, 0.25 * I) @@ -121,6 +220,12 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) + return logpdf(MvNormal(m, 0.25 * I), model.args.x) +end @model function demo_dot_assume_dot_observe_matrix( x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} @@ -133,11 +238,17 @@ end return (; m=m, x=x, logp=getlogp(__varinfo__)) end +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) + return loglikelihood(Normal(), m) +end +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) + return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +end const DEMO_MODELS = ( demo_dot_assume_dot_observe(), demo_assume_index_observe(), - demo_assume_multivariate_observe_index(), + demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), demo_assume_dot_observe(), demo_assume_observe_literal(), From 35445ba9a4d57c0237fb5f0b1d3e3d4768d11c3a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 01:09:07 +0000 Subject: [PATCH 169/216] added tests specific for SimpleVarInfo --- test/runtests.jl | 2 ++ test/simple_varinfo.jl | 70 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 test/simple_varinfo.jl diff --git a/test/runtests.jl b/test/runtests.jl index bb2ae579c..3372b6371 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using MacroTools using MCMCChains using Tracker using Zygote +using Setfield using Distributed using LinearAlgebra @@ -34,6 +35,7 @@ include("test_util.jl") include("utils.jl") include("compiler.jl") include("varinfo.jl") + include("simple_varinfo.jl") include("model.jl") include("sampler.jl") include("prob_macro.jl") diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl new file mode 100644 index 000000000..580d2903d --- /dev/null +++ b/test/simple_varinfo.jl @@ -0,0 +1,70 @@ +@testset "simple_varinfo.jl" begin + @testset "constructor & indexing" begin + @testset "NamedTuple" begin + svi = SimpleVarInfo(; m=1.0) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test !haskey(svi, @varname(m[1])) + + svi = SimpleVarInfo(; m=[1.0]) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m[1])) + @test !haskey(svi, @varname(m[2])) + @test svi[@varname(m)][1] == svi[@varname(m[1])] + + svi = SimpleVarInfo{Float32}(; m=1.0) + @test getlogp(svi) isa Float32 + + svi = SimpleVarInfo((m=1.0,), 1.0) + @test getlogp(svi) == 1.0 + end + + @testset "Dict" begin + svi = SimpleVarInfo(Dict(@varname(m) => 1.0)) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test !haskey(svi, @varname(m[1])) + + svi = SimpleVarInfo(Dict(@varname(m) => [1.0])) + @test getlogp(svi) == 0.0 + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m[1])) + @test !haskey(svi, @varname(m[2])) + @test svi[@varname(m)][1] == svi[@varname(m[1])] + end + end + + @testset "SimpleVarInfo on $(model.name)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # We might need to pre-allocate for the variable `m`, so we need + # to see whether this is the case. + m = model().m + svi = if m isa AbstractArray + SimpleVarInfo((m=similar(m),)) + else + SimpleVarInfo() + end + + # Sample a new varinfo! + _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) + # Type of realization for `m` should be unchanged. + @test typeof(svi_new[@varname(m)]) === typeof(m) + # Realization for `m` should be different wp. 1. + @test svi_new[@varname(m)] != m + # Logjoint should be non-zero wp. 1. + @test getlogp(svi_new) != 0 + + # Evaluation. + m_eval = if m isa AbstractArray + randn!(similar(m)) + else + randn(eltype(m)) + end + svi_eval = @set svi_new.values.m = m_eval + svi_eval = DynamicPPL.resetlogp!!(svi_eval) + + logπ = logjoint(model, svi_eval) + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, svi_eval.values.m) + @test logπ ≈ logπ_true + end +end From acdedc78cf526408b1b624e5a4760860e8c48042 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 02:58:05 +0000 Subject: [PATCH 170/216] also document TestUtils --- docs/src/index.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/src/index.md b/docs/src/index.md index f7054b28b..596cdd977 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -3,3 +3,9 @@ ```@autodocs Modules = [DynamicPPL] ``` + +# DynamicPPL.TestUtils + +```@autodocs +Modules = [DynamicPPL.TestUtils] +``` From 95f67c0f9b94e234634f0ab49fc1ec2390a5c6ce Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 03:02:58 +0000 Subject: [PATCH 171/216] added TestUtils to docs --- docs/make.jl | 2 +- docs/src/index.md | 6 ------ docs/src/test_utils.md | 5 +++++ 3 files changed, 6 insertions(+), 7 deletions(-) create mode 100644 docs/src/test_utils.md diff --git a/docs/make.jl b/docs/make.jl index 7c4735ff2..3076dc4ff 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,7 +8,7 @@ makedocs(; sitename="DynamicPPL", format=Documenter.HTML(), modules=[DynamicPPL], - pages=["Home" => "index.md"], + pages=["Home" => "index.md", "TestUtils" => "test_utils.md"], strict=true, checkdocs=:exports, doctestfilters=[ diff --git a/docs/src/index.md b/docs/src/index.md index 596cdd977..f7054b28b 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -3,9 +3,3 @@ ```@autodocs Modules = [DynamicPPL] ``` - -# DynamicPPL.TestUtils - -```@autodocs -Modules = [DynamicPPL.TestUtils] -``` diff --git a/docs/src/test_utils.md b/docs/src/test_utils.md new file mode 100644 index 000000000..912bd7c51 --- /dev/null +++ b/docs/src/test_utils.md @@ -0,0 +1,5 @@ +# DynamicPPL.TestUtils + +```@autodocs +Modules = [DynamicPPL.TestUtils] +``` From 8b5729cf6ec46de4cbb0450fee4f7e558bde8e9f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 03:59:47 +0000 Subject: [PATCH 172/216] fixed setindex!! for SimpleVarInfo using AbstractDict --- src/simple_varinfo.jl | 27 ++++++++++++++++++++++++--- src/utils.jl | 7 ++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 3a68c69df..a1643fbb7 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -262,7 +262,7 @@ function hasvalue(dict::AbstractDict, vn::VarName) # At this point we just need to check that we `canview` the value. value = dict[VarName(vn, keylens)] - return canview(getlens(vn), value) + return canview(child, value) end function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, val, vn::VarName) @@ -279,9 +279,30 @@ function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, vals, vns::AbstractVector{< return vi end -function setindex!!(vi::SimpleVarInfo, val, vn::VarName) +function setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) # For dictlike objects, we treat the entire `vn` as a _key_ to set. - return SimpleVarInfo(setindex!!(vi.values, val, vn), vi.logp) + dict = values(vi) + # Attempt to split into `parent` and `child` lenses. + parent, child, issuccess = splitlens(getlens(vn)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + haskey(dict, VarName(vn, l)) + end + # When combined with `VarInfo`, `nothing` is equivalent to `IdentityLens`. + keylens = parent === nothing ? Setfield.IdentityLens() : parent + + dict_new = if !issuccess + # Split doesn't exist ⟹ we're working with a new key. + setindex!!(dict, val, vn) + else + # Split exists ⟹ trying to set an existing key. + vn_key = VarName(vn, keylens) + setindex!!( + dict, + set!!(dict[vn_key], child, val), + vn_key + ) + end + return SimpleVarInfo(dict_new, vi.logp) end function setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) diff --git a/src/utils.jl b/src/utils.jl index 9bf52efc1..ffdc21070 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -159,6 +159,10 @@ collectmaybe(x::Base.AbstractSet) = collect(x) ####################### # BangBang.jl related # ####################### +function set!!(obj, lens::Setfield.Lens, value) + lensmut = BangBang.prefermutation(lens) + return Setfield.set(obj, lensmut, value) +end function set!!(obj, vn::VarName{sym}, value) where {sym} lens = BangBang.prefermutation(Setfield.PropertyLens{sym}() ∘ AbstractPPL.getlens(vn)) return Setfield.set(obj, lens, value) @@ -193,8 +197,9 @@ false canview(lens, container) = false canview(::Setfield.IdentityLens, _) = true function canview(lens::Setfield.PropertyLens{field}, x) where {field} - return haskey(x, field) + return hasproperty(x, field) end + # `IndexLens`: only relevant if `x` supports indexing. canview(lens::Setfield.IndexLens, x) = false canview(lens::Setfield.IndexLens, x::AbstractArray) = checkbounds(Bool, x, lens.indices...) From 04dab05f8838b2c4addb9ed7d03952503919e69b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 04:00:07 +0000 Subject: [PATCH 173/216] added more tests --- test/simple_varinfo.jl | 100 ++++++++++++++++++++++++++++++++--------- 1 file changed, 79 insertions(+), 21 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 580d2903d..9e596af5f 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -13,6 +13,13 @@ @test !haskey(svi, @varname(m[2])) @test svi[@varname(m)][1] == svi[@varname(m[1])] + svi = SimpleVarInfo(; m=(a=[1.0],)) + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) + svi = SimpleVarInfo{Float32}(; m=1.0) @test getlogp(svi) isa Float32 @@ -32,6 +39,22 @@ @test haskey(svi, @varname(m[1])) @test !haskey(svi, @varname(m[2])) @test svi[@varname(m)][1] == svi[@varname(m[1])] + + svi = SimpleVarInfo(Dict(@varname(m) => (a=[1.0],))) + @test haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) + + svi = SimpleVarInfo(Dict(@varname(m.a) => [1.0])) + # Now we only have a variable `m.a` which is subsumed by `m`, + # but we can't guarantee that we have the "entire" `m`. + @test !haskey(svi, @varname(m)) + @test haskey(svi, @varname(m.a)) + @test haskey(svi, @varname(m.a[1])) + @test !haskey(svi, @varname(m.a[2])) + @test !haskey(svi, @varname(m.a.b)) end end @@ -39,32 +62,67 @@ # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. m = model().m - svi = if m isa AbstractArray + svi_nt = if m isa AbstractArray SimpleVarInfo((m=similar(m),)) else SimpleVarInfo() end + svi_dict = SimpleVarInfo(VarInfo(model), Dict) - # Sample a new varinfo! - _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) - # Type of realization for `m` should be unchanged. - @test typeof(svi_new[@varname(m)]) === typeof(m) - # Realization for `m` should be different wp. 1. - @test svi_new[@varname(m)] != m - # Logjoint should be non-zero wp. 1. - @test getlogp(svi_new) != 0 - - # Evaluation. - m_eval = if m isa AbstractArray - randn!(similar(m)) - else - randn(eltype(m)) - end - svi_eval = @set svi_new.values.m = m_eval - svi_eval = DynamicPPL.resetlogp!!(svi_eval) + @testset "$(nameof(typeof(svi.values)))" for svi in (svi_nt, svi_dict) + # Random seed is set in each `@testset`, so we need to sample + # a new realization for `m` here. + m = model().m - logπ = logjoint(model, svi_eval) - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, svi_eval.values.m) - @test logπ ≈ logπ_true + ### Sampling ### + # Sample a new varinfo! + _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) + + # If the `m[1]` varname doesn't exist, this is a univariate model. + # TODO: Find a better way of dealing with this that is not dependent + # on knowledge of internals of `model`. + isunivariate = !haskey(svi_new, @varname(m[1])) + + # Realization for `m` should be different wp. 1. + if isunivariate + @test svi_new[@varname(m)] != m + else + @test svi_new[@varname(m[1])] != m[1] + @test svi_new[@varname(m[2])] != m[2] + end + # Logjoint should be non-zero wp. 1. + @test getlogp(svi_new) != 0 + + ### Evaluation ### + # Sample some random testing values. + m_eval = if m isa AbstractArray + randn!(similar(m)) + else + randn(eltype(m)) + end + + # Update the realizations in `svi_new`. + svi_eval = if isunivariate + DynamicPPL.setindex!!(svi_new, m_eval, @varname(m)) + else + DynamicPPL.setindex!!(svi_new, m_eval, [@varname(m[1]), @varname(m[2])]) + end + # Reset the logp field. + svi_eval = DynamicPPL.resetlogp!!(svi_eval) + + # Compute `logjoint` using the varinfo. + logπ = logjoint(model, svi_eval) + # Extract the parameters from `svi_eval`. + m_vi = if isunivariate + svi_eval[@varname(m)] + else + svi_eval[[@varname(m[1]), @varname(m[2])]] + end + # These should not have changed. + @test m_vi == m_eval + # Compute the true `logjoint` and compare. + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_vi) + @test logπ ≈ logπ_true + end end end From 251eb80ed1d1821307dccaaecd5ecb160fd0a6c9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 04:00:36 +0000 Subject: [PATCH 174/216] formatting --- src/simple_varinfo.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a1643fbb7..197f3f642 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -296,11 +296,7 @@ function setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) else # Split exists ⟹ trying to set an existing key. vn_key = VarName(vn, keylens) - setindex!!( - dict, - set!!(dict[vn_key], child, val), - vn_key - ) + setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end return SimpleVarInfo(dict_new, vi.logp) end From 142d93b2dd960d45a429e5577343bd98da5aac99 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 04:15:29 +0000 Subject: [PATCH 175/216] dont use BangBang for setall! --- src/varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index bd4cbd7db..7ac335776 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -330,15 +330,15 @@ getall(vi::TypedVarInfo) = vcat(_getall(vi.metadata)...) end """ - setall!!(vi::VarInfo, val) + setall!(vi::VarInfo, val) Set the values of all the variables in `vi` to `val`, mutating if it makese sense. The values may or may not be transformed to Euclidean space. """ -setall!!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val -setall!!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) +setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val +setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) @generated function _setall!(metadata::NamedTuple{names}, val, start=0) where {names} expr = Expr(:block) start = :(1) @@ -973,7 +973,7 @@ Set the current value(s) of the random variables sampled by `spl` in `vi` to `va The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!!(vi, val) +setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!(vi, val) setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` From 3c91d2fcf677ef5d0356937853c12ec6a276ce1f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 04:16:23 +0000 Subject: [PATCH 176/216] revert unnecessary changes to settrans! --- src/varinfo.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 7ac335776..a5b02f0d3 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -364,13 +364,7 @@ getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] Set the `trans` flag value of `vn` in `vi`, mutating if it makes sense. """ function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - if trans - set_flag!(vi, vn, "trans") - else - unset_flag!(vi, vn, "trans") - end - - return nothing + return trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") end """ From 1ffb83cf0235ec49dfb4d6a3e0eb7dcb2681ff26 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 04:16:54 +0000 Subject: [PATCH 177/216] revert unnecessary changes to set_flag! --- src/varinfo.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index a5b02f0d3..0c267db02 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -512,8 +512,7 @@ end Set `vn`'s value for `flag` to `true` in `vi`. """ function set_flag!(vi::VarInfo, vn::VarName, flag::String) - getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true - return vi + return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true end #### From 871b8cdd067bfc2ecb03912a3933be6470927e39 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 04:18:45 +0000 Subject: [PATCH 178/216] revert some changes to docstrings --- src/varinfo.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0c267db02..0c76a754d 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -361,7 +361,7 @@ getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ settrans!(vi::VarInfo, trans::Bool, vn::VarName) -Set the `trans` flag value of `vn` in `vi`, mutating if it makes sense. +Set the `trans` flag value of `vn` in `vi`. """ function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) return trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") @@ -592,7 +592,7 @@ TypedVarInfo(vi::TypedVarInfo) = vi empty!!(vi::VarInfo) Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to -zeros, mutating if it makes sense. +zeros. This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. """ From df4e514b6993a2b3fc11933a167c64a443e0867c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 10:30:12 +0000 Subject: [PATCH 179/216] fixed some comments and docstrings --- src/simple_varinfo.jl | 2 +- src/varinfo.jl | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 197f3f642..83d561d74 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -271,7 +271,7 @@ function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, val, vn::VarName) end # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with -# same symbol and same type of, say, `IndexLens`, for improvemed `.~` performance. +# same symbol and same type of, say, `IndexLens`, for improved `.~` performance. function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, vals, vns::AbstractVector{<:VarName}) for (vn, val) in zip(vns, vals) vi = setindex!!(vi, val, vn) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0c76a754d..b94998e9e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -332,8 +332,7 @@ end """ setall!(vi::VarInfo, val) -Set the values of all the variables in `vi` to `val`, -mutating if it makese sense. +Set the values of all the variables in `vi` to `val`. The values may or may not be transformed to Euclidean space. """ From 037a8c9048f5975c507be6008c43d5507ff20083 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 10:59:52 +0000 Subject: [PATCH 180/216] added more convenient logjoint, logprior, and loglikelihood methods --- src/simple_varinfo.jl | 102 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 83d561d74..26ae92d6a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -435,3 +435,105 @@ Return the values/realizations in `varinfo` as `Type`, if implemented. values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values + +""" + logjoint(model::Model, θ) + +Return the log joint probability of variables `θ` for the probabilistic `model`. + +See [`logjoint`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logjoint(demo([1.0]), (m = 100.0, )) +-9902.33787706641 + +julia> # Using a `Dict`. + logjoint(demo([1.0]), Dict(@varname(m) => 100.0)) +-9902.33787706641 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) +-9902.33787706641 +``` +""" +function logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) + return logjoint(model, SimpleVarInfo(θ)) +end + +""" + logprior(model::Model, θ) + +Return the log prior probability of variables `θ` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logprior(demo([1.0]), (m = 100.0, )) +-5000.918938533205 + +julia> # Using a `Dict`. + logprior(demo([1.0]), Dict(@varname(m) => 100.0)) +-5000.918938533205 + +julia> # Truth. + logpdf(Normal(), 100.0) +-5000.918938533205 +``` +""" +function logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) + return logprior(model, SimpleVarInfo(θ)) +end + +""" + loglikelihood(model::Model, θ) + +Return the log likelihood of variables `θ` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`logprior`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + loglikelihood(demo([1.0]), (m = 100.0, )) +-4901.418938533205 + +julia> # Using a `Dict`. + loglikelihood(demo([1.0]), Dict(@varname(m) => 100.0)) +-4901.418938533205 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) +-4901.418938533205 +``` +""" +function Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) + return Distributions.loglikelihood(model, SimpleVarInfo(θ)) +end From 7a9af3c8fb45fa14be979537c4c0c41f2b8292f2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 11:26:14 +0000 Subject: [PATCH 181/216] removed unnecessary export --- src/DynamicPPL.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c739a859e..312ec2f4b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -57,7 +57,7 @@ export AbstractVarInfo, is_flagged, set_flag!, unset_flag!, - set_flag!!, + set_flag!, setgid!, updategid!, setorder!, From 71bd8bcb924e5458b8c189dccf39d894db77d87f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Nov 2021 11:26:32 +0000 Subject: [PATCH 182/216] fixed export --- src/DynamicPPL.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 312ec2f4b..c054ea747 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -57,7 +57,6 @@ export AbstractVarInfo, is_flagged, set_flag!, unset_flag!, - set_flag!, setgid!, updategid!, setorder!, From 969bb6598fc58c06110838ec1b0f953ceed3a3d1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 20 Nov 2021 03:34:51 +0000 Subject: [PATCH 183/216] use the Setfield impl of getindex, etc. as default and specialize on AbstractDict --- src/simple_varinfo.jl | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 26ae92d6a..44deaab52 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -187,12 +187,12 @@ function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) end # `NamedTuple` -function getindex(vi::SimpleVarInfo{<:NamedTuple}, vn::VarName) +function getindex(vi::SimpleVarInfo, vn::VarName) return get(vi.values, vn) end # `Dict` -function getindex(vi::SimpleVarInfo, vn::VarName) +function getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) if haskey(vi.values, vn) return vi.values[vn] end @@ -265,14 +265,14 @@ function hasvalue(dict::AbstractDict, vn::VarName) return canview(child, value) end -function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, val, vn::VarName) +function setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) end # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with # same symbol and same type of, say, `IndexLens`, for improved `.~` performance. -function setindex!!(vi::SimpleVarInfo{<:NamedTuple}, vals, vns::AbstractVector{<:VarName}) +function setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) for (vn, val) in zip(vns, vals) vi = setindex!!(vi, val, vn) end @@ -301,15 +301,6 @@ function setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) return SimpleVarInfo(dict_new, vi.logp) end -function setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) - for (vn, val) in zip(vns, vals) - vi = setindex!!(vi, val, vn) - end - return vi -end - -istrans(::SimpleVarInfo, vn::VarName) = false - # Necessary for `matchingvalue` to work properly. function Base.eltype( vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} @@ -426,6 +417,7 @@ end # HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing +istrans(::SimpleVarInfo, vn::VarName) = false """ values(varinfo, Type) From 92dd5b8b7418413695ee25fe09ed0c00bd05328f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 22 Nov 2021 14:11:11 +0000 Subject: [PATCH 184/216] fixed docstrings of logjoint, etc. --- src/simple_varinfo.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 44deaab52..60a6e6889 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -458,7 +458,7 @@ julia> # Truth. -9902.33787706641 ``` """ -function logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) +function logjoint(model::Model, θ) return logjoint(model, SimpleVarInfo(θ)) end @@ -492,7 +492,7 @@ julia> # Truth. -5000.918938533205 ``` """ -function logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) +function logprior(model::Model, θ) return logprior(model, SimpleVarInfo(θ)) end @@ -526,6 +526,6 @@ julia> # Truth. -4901.418938533205 ``` """ -function Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) +function Distributions.loglikelihood(model::Model, θ) return Distributions.loglikelihood(model, SimpleVarInfo(θ)) end From de936a5bb6176692d4e0ce0759c228bfb081c930 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 23 Nov 2021 13:31:15 +0000 Subject: [PATCH 185/216] Apply suggestions from code review Co-authored-by: Philipp Gabler --- src/compiler.jl | 2 +- src/simple_varinfo.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fe7810b05..1a89dd18e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -540,7 +540,7 @@ Suppose the following is the return-value: return x ~ Normal() ``` -Without `return_values`, once expanded in `generated_mainbody!`, this would be +Without `return_values`, once expanded in `generate_mainbody!`, this would be ```julia return (x, __varinfo__ = tilde_assume!!(...)), __varinfo__ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 60a6e6889..78f9d8f42 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -12,7 +12,7 @@ The major differences between this and `TypedVarInfo` are: 2. `SimpleVarInfo` can use more efficient bijectors. 3. `SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either a) no indexing is used in tilde-statements, or - b) the values have been specified with the corret shapes. + b) the values have been specified with the correct shapes. # Examples ## General usage From 147e9f5739347f950fed974ef8617bc5f277ef1e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Nov 2021 15:32:34 +0000 Subject: [PATCH 186/216] fixed docstring for model --- src/model.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index 606581d51..521952b49 100644 --- a/src/model.jl +++ b/src/model.jl @@ -195,7 +195,7 @@ julia> @model demo_inner() = m ~ Normal() demo_inner (generic function with 2 methods) julia> @model function demo_outer() - m = @submodel demo_inner() + @submodel m = demo_inner() return m end demo_outer (generic function with 2 methods) @@ -215,7 +215,7 @@ But one needs to be careful when prefixing variables in the nested models: ```jldoctest condition julia> @model function demo_outer_prefix() - m = @submodel inner demo_inner() + @submodel inner m = demo_inner() return m end demo_outer_prefix (generic function with 2 methods) From aeb4fa1f32913687fc6e50728eb08174731f3e99 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Nov 2021 15:37:24 +0000 Subject: [PATCH 187/216] replaced return_values by capturing return-value from tilde-statements instead --- src/compiler.jl | 77 ++++++++++------------------------ src/context_implementations.jl | 2 - 2 files changed, 23 insertions(+), 56 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 1a89dd18e..6bae085c2 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -355,10 +355,12 @@ end function generate_tilde_literal(left, right) # If the LHS is a literal, it is always an observation + value = gensym(:value) return quote - _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( + $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) + $value end end @@ -373,7 +375,7 @@ function generate_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption + @gensym vn isassumption value # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact # that in DynamicPPL we the entire function body. Instead we should be @@ -389,13 +391,14 @@ function generate_tilde(left, right) $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) end - _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( + $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, __varinfo__, ) + $value end end end @@ -404,8 +407,8 @@ function generate_tilde_assume(left, right, vn) # HACK: Because the Setfield.jl macro does not support assignment # with multiple arguments on the LHS, we need to capture the return-values # and then update them LHS variables one by one. - tmp = gensym(:tmp) - expr = :($left = $first($tmp)) + value = gensym(:value) + expr = :($left = $value) if left isa Expr expr = AbstractPPL.drop_escape( Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) @@ -413,14 +416,13 @@ function generate_tilde_assume(left, right, vn) end return quote - $tmp = $(DynamicPPL.tilde_assume!!)( + ($value, __varinfo__) = $(DynamicPPL.tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, ) $expr - __varinfo__ = $last($tmp) - $left, __varinfo__ + $value end end @@ -434,7 +436,7 @@ function generate_dot_tilde(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn isassumption + @gensym vn isassumption value return quote $vn = $(AbstractPPL.drop_escape(varname(left))) $isassumption = $(DynamicPPL.isassumption(left)) @@ -446,13 +448,14 @@ function generate_dot_tilde(left, right) $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) end - _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( + $value, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, __varinfo__, ) + $value end end end @@ -461,15 +464,18 @@ function generate_dot_tilde_assume(left, right, vn) # We don't need to use `Setfield.@set` here since # `.=` is always going to be inplace + needs `left` to # be something that supports `.=`. - return :( - ($left, __varinfo__) = $(DynamicPPL.dot_tilde_assume!!)( + value = gensym(:value) + return quote + ($value, __varinfo__) = $(DynamicPPL.dot_tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn )..., __varinfo__, ) - ) + $left .= $value + $value + end end """ @@ -518,49 +524,12 @@ function replace_returns(e::Expr) e.args[1] end - return :(return $(DynamicPPL.return_values)($retval_expr, __varinfo__)) + return :(return ($retval_expr, __varinfo__)) end return Expr(e.head, map(replace_returns, e.args)...) end -""" - return_values(retval, varinfo) - -Return `(retval, varinfo)` if `retval` is not a `Tuple` with second -component being a `AbstractVarInfo`. - -Used together with [`replace_returns`](@ref), it handles the following case. - -# Example - -Suppose the following is the return-value: - -```julia -return x ~ Normal() -``` - -Without `return_values`, once expanded in `generate_mainbody!`, this would be - -```julia -return (x, __varinfo__ = tilde_assume!!(...)), __varinfo__ -``` - -i.e. the return-value of the model would end up `(x, __varinfo__), __varinfo__` -which in turn would lead to a `(::Model)(args...)` call returning `(x, __varinfo__)`, -breaking with the expectation of the user. - -In such a scenario `return_values` effectively results in the following - -```julia -return x, __varinfo__ = tilde_assume!!(...) -``` - -preserving user expectation, as desired. -""" -return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) -return_values(retval::Tuple{<:Any,<:AbstractVarInfo}, ::AbstractVarInfo) = retval - # If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. make_returns_explicit!(body) = Expr(:return, body) function make_returns_explicit!(body::Expr) @@ -604,11 +573,11 @@ function build_output(modelinfo, linenumbernode) # Replace the user-provided function body with the version created by DynamicPPL. # We use `MacroTools.@q begin ... end` instead of regular `quote ... end` to ensure # that no new `LineNumberNode`s are added apart from the reference `linenumbernode` - # to the call site + # to the call site. # NOTE: We need to replace statements of the form `return ...` with - # `return DynamicPPL.return_values(..., __varinfo__)` to ensure that the second + # `return (..., __varinfo__)` to ensure that the second # element in the returned value is always the most up-to-date `__varinfo__`. - # See the docstrings of `replace_returns` and `return_values` for more info. + # See the docstrings of `replace_returns` for more info. evaluatordef[:body] = MacroTools.@q begin $(linenumbernode) $(replace_returns(make_returns_explicit!(modelinfo[:body]))) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ce2a61a67..7ded8596d 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -373,8 +373,6 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, vi) value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - # Mutation of `value` no longer occurs in main body, so we do it here. - left .= value return value, acclogp!!(vi, logp), vi end From abb07b3dc3dbf9a49b21f02aecec5c659d990de3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 26 Nov 2021 15:38:10 +0000 Subject: [PATCH 188/216] added some tests for return-value of model --- test/compiler.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/compiler.jl b/test/compiler.jl index ef6d8f36f..80fab9669 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -544,4 +544,26 @@ end f(::Model{typeof(demo),()}) = false @test !f(demo()) end + + @testset "return value" begin + # Even if the return-value is `AbstractVarInfo`, we should return + # a `Tuple` with `AbstractVarInfo` in the second component too. + @model demo() = return __varinfo__ + retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + @test svi == SimpleVarInfo() + @test retval == svi + + # We should not be altering return-values other than at top-level. + @model function demo() + # If we also replaced this `return` inside of `f`, then the + # final `return` would be include `__varinfo__`. + f(x) = return x^2 + return f(1.0) + end + retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + @test retval isa Float64 + + @model demo() = x ~ Normal() + retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) + end end From 25cacd32fab09cabb6342a251706428e595a8a63 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 2 Dec 2021 20:30:47 +0000 Subject: [PATCH 189/216] added broadcast_foreach --- src/utils.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index ffdc21070..77e94c053 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -314,3 +314,12 @@ function splitlens(condition, lens) return current_parent, current_child, condition(current_parent) end + +################################## +### Generally useful functions ### +################################## +function broadcast_foreach(op, args...) + bc = Base.broadcasted(op, args...) + foreach(identity, bc) + return nothing +end From 87ce03b6ebdde98e356a1f4ae4978492472f9907 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 02:19:11 +0000 Subject: [PATCH 190/216] Apply suggestions from @devmotion Co-authored-by: David Widmann --- src/compiler.jl | 17 ++++++--------- src/context_implementations.jl | 2 +- src/loglikelihoods.jl | 10 +++++---- src/model.jl | 22 ++++++------------- src/prob_macro.jl | 3 +-- src/simple_varinfo.jl | 40 +++++++++++----------------------- src/submodel_macro.jl | 4 ++-- 7 files changed, 37 insertions(+), 61 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6bae085c2..4146b3f3d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -355,7 +355,7 @@ end function generate_tilde_literal(left, right) # If the LHS is a literal, it is always an observation - value = gensym(:value) + @gensym value return quote $value, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ @@ -406,8 +406,8 @@ end function generate_tilde_assume(left, right, vn) # HACK: Because the Setfield.jl macro does not support assignment # with multiple arguments on the LHS, we need to capture the return-values - # and then update them LHS variables one by one. - value = gensym(:value) + # and then update the LHS variables one by one. + @gensym value expr = :($left = $value) if left isa Expr expr = AbstractPPL.drop_escape( @@ -464,7 +464,7 @@ function generate_dot_tilde_assume(left, right, vn) # We don't need to use `Setfield.@set` here since # `.=` is always going to be inplace + needs `left` to # be something that supports `.=`. - value = gensym(:value) + @gensym value return quote ($value, __varinfo__) = $(DynamicPPL.dot_tilde_assume!!)( __context__, @@ -508,7 +508,6 @@ Note that this method will _not_ replace `return` statements within function definitions. This is checked using [`isfuncdef`](@ref). """ replace_returns(e) = e -replace_returns(e::Symbol) = e function replace_returns(e::Expr) if isfuncdef(e) return e @@ -534,12 +533,10 @@ end make_returns_explicit!(body) = Expr(:return, body) function make_returns_explicit!(body::Expr) # If the last statement is a return-statement, we don't do anything. - if Meta.isexpr(body.args[end], :return) - return body - end - # Otherwise we replace the last statement with a `return` statement. - body.args[end] = Expr(:return, body.args[end]) + if !Meta.isexpr(body.args[end], :return) + body.args[end] = Expr(:return, body.args[end]) + end return body end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 7ded8596d..3419af957 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -509,7 +509,7 @@ function get_and_set_val!( # TODO: This will inefficient since it will allocate an entire vector. # We could either: # 1. Figure out the broadcast size and use a `foreach`. - # 2. Define a anonynous function which returns `nothing`, which + # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. push!!.(Ref(vi), vns, r, dists, Ref(spl)) settrans!.(Ref(vi), false, vns) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 421292b05..daf05eedd 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -109,16 +109,18 @@ end # the `dot_assume` implementations, but as things are _right now_ this is the best we can do. function _pointwise_tilde_observe(context, right, left, vi) # We need to drop the `vi` returned. - observe_logps(r, l) = first(tilde_observe(context, r, l, vi)) - return observe_logps.(right, left) + return broadcast(right, left) do r, l + return first(tilde_observe(context, r, l, vi)) + end end function _pointwise_tilde_observe( context, right::MultivariateDistribution, left::AbstractMatrix, vi::AbstractVarInfo ) # We need to drop the `vi` returned. - observe_logps(l) = first(tilde_observe(context, right, l, vi)) - return observe_logps.(eachcol(left)) + return map(eachcol(left)) do l + return first(tilde_observe(context, right, l, vi)) + end end """ diff --git a/src/model.jl b/src/model.jl index 521952b49..d3eb421a1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -374,7 +374,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -(model::Model)(args...) = (first ∘ evaluate!!)(model, args...) +(model::Model)(args...) = first(evaluate!!(model, args...)) """ evaluate!!(model::Model[, rng, varinfo, sampler, context]) @@ -433,10 +433,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe!!`](@ref) """ -function evaluate_threadunsafe!!(model, varinfo, context) - varinfo = resetlogp!!(varinfo) - return _evaluate!!(model, varinfo, context) -end +evaluate_threadunsafe!!(model, varinfo, context) = _evaluate!!(model, resetlogp!!(varinfo), context) """ evaluate_threadsafe!!(model, varinfo, context) @@ -450,11 +447,9 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe!!`](@ref) """ function evaluate_threadsafe!!(model, varinfo, context) - varinfo = resetlogp!!(varinfo) - wrapper = ThreadSafeVarInfo(varinfo) + wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper, context) - varinfo = setlogp!!(varinfo, getlogp(wrapper_new)) - return result, varinfo + return result, setlogp!!(varinfo, getlogp(wrapper_new)) end """ @@ -512,8 +507,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - _, varinfo_new = evaluate!!(model, varinfo, DefaultContext()) - return getlogp(varinfo_new) + return getlogp(last(evaluate!!(model, varinfo, DefaultContext()))) end """ @@ -524,8 +518,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - _, varinfo_new = evaluate!!(model, varinfo, PriorContext()) - return getlogp(varinfo_new) + return getlogp(last(evaluate!!(model, varinfo, PriorContext()))) end """ @@ -536,8 +529,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - _, varinfo_new = evaluate!!(model, varinfo, LikelihoodContext()) - return getlogp(varinfo_new) + return getlogp(last(evaluate!!(model, varinfo, LikelihoodContext()))) end """ diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 1eca69d43..21a674fc9 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -146,8 +146,7 @@ function logprior( foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end - _, vi = DynamicPPL.evaluate!!(model, vi, SampleFromPrior(), PriorContext(left)) - return getlogp(vi) + return getlogp(last(DynamicPPL.evaluate!!(model, vi, SampleFromPrior(), PriorContext(left)))) end @generated function make_prior_model( diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 78f9d8f42..f6434c11d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -139,8 +139,8 @@ SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ) # Constructor from `Model`. SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} - svi = last(DynamicPPL.evaluate!!(model, SimpleVarInfo{T}(), args...)) - return svi + return last(evaluate!!(model, SimpleVarInfo{T}(), args...)) +end end # Constructor from `VarInfo`. @@ -179,20 +179,14 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) end function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo) - print(io, "SimpleVarInfo(") - print(io, svi.values) - print(io, ", ") - print(io, svi.logp) - return print(io, ")") + return print(io, "SimpleVarInfo(", svi.values, ", ", svi.logp, ")") end # `NamedTuple` -function getindex(vi::SimpleVarInfo, vn::VarName) - return get(vi.values, vn) -end +Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) # `Dict` -function getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) +function Base.getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName) if haskey(vi.values, vn) return vi.values[vn] end @@ -219,16 +213,14 @@ end # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than # just `Vector`. -function getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) - return map(vn -> getindex(vi, vn), vns) -end +Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(Base.Fix1(getindex, vi), vns) # HACK: Needed to disambiguiate. -getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(vn -> getindex(vi, vn), vns) +Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) -getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values -getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values +Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values +Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values # TODO: Should we do better? -getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values +Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) @@ -458,9 +450,7 @@ julia> # Truth. -9902.33787706641 ``` """ -function logjoint(model::Model, θ) - return logjoint(model, SimpleVarInfo(θ)) -end +logjoint(model::Model, θ) = logjoint(model, SimpleVarInfo(θ)) """ logprior(model::Model, θ) @@ -492,9 +482,7 @@ julia> # Truth. -5000.918938533205 ``` """ -function logprior(model::Model, θ) - return logprior(model, SimpleVarInfo(θ)) -end +logprior(model::Model, θ) = logprior(model, SimpleVarInfo(θ)) """ loglikelihood(model::Model, θ) @@ -526,6 +514,4 @@ julia> # Truth. -4901.418938533205 ``` """ -function Distributions.loglikelihood(model::Model, θ) - return Distributions.loglikelihood(model, SimpleVarInfo(θ)) -end +Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ)) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index a399fc4c9..a18738882 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -120,9 +120,9 @@ function submodel(expr, ctx=esc(:__context__)) return if args_assign === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:_)), $(esc(:__varinfo__)) = _evaluate!!( + $(esc(:__varinfo__)) = last(_evaluate!!( $(esc(expr)), $(esc(:__varinfo__)), $(ctx) - ) + )) end else # Here we also want the return-variable. From 4bdeca3ad853272d273e226f9e633274cd8aed10 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 02:20:19 +0000 Subject: [PATCH 191/216] remove broadcast_foreach for now --- src/utils.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 77e94c053..ffdc21070 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -314,12 +314,3 @@ function splitlens(condition, lens) return current_parent, current_child, condition(current_parent) end - -################################## -### Generally useful functions ### -################################## -function broadcast_foreach(op, args...) - bc = Base.broadcasted(op, args...) - foreach(identity, bc) - return nothing -end From 80ee8f4363d1bae91e9b8a22d1feba238b7abcec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 02:34:51 +0000 Subject: [PATCH 192/216] some fixes to ThreadSafeVarInfo --- src/threadsafe.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 2ecfaf07d..f0b361cc6 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -81,7 +81,7 @@ end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) function empty!!(vi::ThreadSafeVarInfo) - empty!!(vi.varinfo) + Setfield.@set! vi.varinfo = empty!!(vi.varinfo) fill!(vi.logps, zero(getlogp(vi))) return vi end @@ -89,7 +89,7 @@ end function push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - return push!!(vi.varinfo, vn, r, dist, gidset) + return @set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) end function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) From 496ec3d279c3462a9050fc3a8b945abe9f9278be Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 02:35:52 +0000 Subject: [PATCH 193/216] Apply suggestions from code review Co-authored-by: David Widmann --- src/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 4146b3f3d..f39049d6e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -416,7 +416,7 @@ function generate_tilde_assume(left, right, vn) end return quote - ($value, __varinfo__) = $(DynamicPPL.tilde_assume!!)( + $value, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., __varinfo__, @@ -466,7 +466,7 @@ function generate_dot_tilde_assume(left, right, vn) # be something that supports `.=`. @gensym value return quote - ($value, __varinfo__) = $(DynamicPPL.dot_tilde_assume!!)( + $value, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn From 9c33e67e3c8b4545ccc6b72ad1abd56433821380 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 02:49:22 +0000 Subject: [PATCH 194/216] fixed docstrings --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3419af957..a033b352e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -108,7 +108,7 @@ end tilde_assume!!(context, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value. +accumulate the log probability, and return the sampled value and the update `vi`. By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. @@ -161,7 +161,7 @@ end tilde_observe!!(context, right, left, vname, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value. +accumulate the log probability, and return the observed value and the updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. From 6ad0e43d1cdc506f96dd79b0c0c4725986167d90 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 02:49:28 +0000 Subject: [PATCH 195/216] forgot qualification for set --- src/threadsafe.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f0b361cc6..c0e8e64c4 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -89,7 +89,7 @@ end function push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - return @set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) + return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) end function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) From 0d217c47ccd0fa7f1b0b4914f8e4802bc23c87e5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 02:50:02 +0000 Subject: [PATCH 196/216] formatting --- src/model.jl | 4 +++- src/prob_macro.jl | 4 +++- src/simple_varinfo.jl | 5 +++-- src/submodel_macro.jl | 6 +++--- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/model.jl b/src/model.jl index d3eb421a1..db7e1e265 100644 --- a/src/model.jl +++ b/src/model.jl @@ -433,7 +433,9 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe!!`](@ref) """ -evaluate_threadunsafe!!(model, varinfo, context) = _evaluate!!(model, resetlogp!!(varinfo), context) +function evaluate_threadunsafe!!(model, varinfo, context) + return _evaluate!!(model, resetlogp!!(varinfo), context) +end """ evaluate_threadsafe!!(model, varinfo, context) diff --git a/src/prob_macro.jl b/src/prob_macro.jl index 21a674fc9..c87e365ea 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -146,7 +146,9 @@ function logprior( foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end - return getlogp(last(DynamicPPL.evaluate!!(model, vi, SampleFromPrior(), PriorContext(left)))) + return getlogp( + last(DynamicPPL.evaluate!!(model, vi, SampleFromPrior(), PriorContext(left))) + ) end @generated function make_prior_model( diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f6434c11d..0cd86a7e6 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -141,7 +141,6 @@ SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...) function SimpleVarInfo{T}(model::Model, args...) where {T<:Real} return last(evaluate!!(model, SimpleVarInfo{T}(), args...)) end -end # Constructor from `VarInfo`. function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D} @@ -213,7 +212,9 @@ end # `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than # just `Vector`. -Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) = map(Base.Fix1(getindex, vi), vns) +function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) + return map(Base.Fix1(getindex, vi), vns) +end # HACK: Needed to disambiguiate. Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index a18738882..54477a05d 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -120,9 +120,9 @@ function submodel(expr, ctx=esc(:__context__)) return if args_assign === nothing # In this case we only want to get the `__varinfo__`. quote - $(esc(:__varinfo__)) = last(_evaluate!!( - $(esc(expr)), $(esc(:__varinfo__)), $(ctx) - )) + $(esc(:__varinfo__)) = last( + _evaluate!!($(esc(expr)), $(esc(:__varinfo__)), $(ctx)) + ) end else # Here we also want the return-variable. From 57dda81d92672572a1b0b672856b0291769bcdf0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 12:33:00 +0000 Subject: [PATCH 197/216] added comment about why we cant use MacroTools.isdef --- src/compiler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index f39049d6e..f973a3bb6 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -478,6 +478,8 @@ function generate_dot_tilde_assume(left, right, vn) end end +# Note that we cannot use `MacroTools.isdef` because +# of https://github.com/FluxML/MacroTools.jl/issues/154. """ isfuncdef(expr) From c4c6412a571b33fee41db6008aad3b8c8aba428d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 12:34:06 +0000 Subject: [PATCH 198/216] remove unnecessary deprecation --- src/DynamicPPL.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c054ea747..fec6d1d2a 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -176,8 +176,4 @@ include("test_utils.jl") @deprecate acclogp!(vi, logp) acclogp!!(vi, logp) @deprecate resetlogp!(vi) resetlogp!!(vi) -@deprecate initialize_parameters!(vi, init_params, spl) initialize_parameters!!( - vi, init_params, spl -) - end # module From 791c181d6e2c3bdaa7087f6b9c4ee6298cccd108 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 12:36:03 +0000 Subject: [PATCH 199/216] udpated some docstrings --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a033b352e..915aba9d2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -367,7 +367,7 @@ end dot_tilde_assume!!(context, right, left, vn, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value. +model inputs), accumulate the log probability, and return the sampled value and updated `vi`. Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ @@ -583,7 +583,7 @@ end dot_tilde_observe!!(context, right, left, vname, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value. +accumulate the log probability, and return the observed value and updated `vi`. Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. From 49165b65292aec2dd3a269e5c3a294fbe9511508 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 12:37:39 +0000 Subject: [PATCH 200/216] fixed more docstrings --- src/context_implementations.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 915aba9d2..20c4af446 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -108,7 +108,7 @@ end tilde_assume!!(context, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value and the update `vi`. +accumulate the log probability, and return the sampled value and updated `vi`. By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. @@ -161,7 +161,7 @@ end tilde_observe!!(context, right, left, vname, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the observed value and the updated `vi`. +accumulate the log probability, and return the observed value and updated `vi`. Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. @@ -596,7 +596,7 @@ end dot_tilde_observe!!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value. +probability, and return the observed value and updated `vi`. Falls back to `dot_tilde_observe(context, right, left, vi)`. """ From 987d7ea8435a06f4f223876b977cfc5f96db63b0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 22:01:59 +0000 Subject: [PATCH 201/216] make overloads of BangBang methods qualified --- src/DynamicPPL.jl | 2 +- src/simple_varinfo.jl | 28 ++++++++++++++++------------ src/threadsafe.jl | 20 +++++++++----------- src/varinfo.jl | 32 ++++++++++++++++++++------------ 4 files changed, 46 insertions(+), 36 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index fec6d1d2a..82df0f008 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -32,7 +32,7 @@ import Base: keys, haskey -import BangBang: push!!, empty!!, setindex!! +using BangBang: push!!, empty!!, setindex!! # VarInfo export AbstractVarInfo, diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 0cd86a7e6..52e45cdde 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -153,6 +153,10 @@ function SimpleVarInfo{T}( return SimpleVarInfo(values, convert(T, getlogp(vi))) end +function BangBang.empty!!(vi::SimpleVarInfo) + Setfield.@set resetlogp!!(vi).values = empty!!(vi.values) +end + getlogp(vi::SimpleVarInfo) = vi.logp setlogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, logp) acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp) @@ -258,23 +262,23 @@ function hasvalue(dict::AbstractDict, vn::VarName) return canview(child, value) end -function setindex!!(vi::SimpleVarInfo, val, vn::VarName) +function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. return SimpleVarInfo(set!!(vi.values, vn, val), vi.logp) end # TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with # same symbol and same type of, say, `IndexLens`, for improved `.~` performance. -function setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) +function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) for (vn, val) in zip(vns, vals) - vi = setindex!!(vi, val, vn) + vi = BangBang.setindex!!(vi, val, vn) end return vi end -function setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) +function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) # For dictlike objects, we treat the entire `vn` as a _key_ to set. - dict = values(vi) + dict = values_as(vi) # Attempt to split into `parent` and `child` lenses. parent, child, issuccess = splitlens(getlens(vn)) do lens l = lens === nothing ? Setfield.IdentityLens() : lens @@ -285,11 +289,11 @@ function setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) dict_new = if !issuccess # Split doesn't exist ⟹ we're working with a new key. - setindex!!(dict, val, vn) + BangBang.setindex!!(dict, val, vn) else # Split exists ⟹ trying to set an existing key. vn_key = VarName(vn, keylens) - setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) + BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) end return SimpleVarInfo(dict_new, vi.logp) end @@ -302,7 +306,7 @@ function Base.eltype( end # `NamedTuple` -function push!!( +function BangBang.push!!( vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym,Setfield.IdentityLens}, value, @@ -311,7 +315,7 @@ function push!!( ) where {sym} return Setfield.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) end -function push!!( +function BangBang.push!!( vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, @@ -322,7 +326,7 @@ function push!!( end # `Dict` -function push!!( +function BangBang.push!!( vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, r, @@ -351,7 +355,7 @@ function assume( vi::SimpleOrThreadSafeSimple, ) value = init(rng, dist, sampler) - vi = push!!(vi, vn, value, dist, sampler) + vi = BangBang.push!!(vi, vn, value, dist, sampler) return value, Distributions.loglikelihood(dist, value), vi end @@ -402,7 +406,7 @@ function dot_assume( ) f = (vn, dist) -> init(rng, dist, spl) value = f.(vns, dists) - vi = setindex!!(vi, value, vns) + vi = BangBang.setindex!!(vi, value, vns) lp = sum(Distributions.logpdf.(dists, value)) return value, lp, vi end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c0e8e64c4..c21f03e7d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -65,14 +65,14 @@ getindex(vi::ThreadSafeVarInfo, spl::SampleFromUniform) = getindex(vi.varinfo, s getindex(vi::ThreadSafeVarInfo, vn::VarName) = getindex(vi.varinfo, vn) getindex(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) = getindex(vi.varinfo, vns) -function setindex!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) - return setindex!(vi.varinfo, val, spl) +function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::AbstractSampler) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end -function setindex!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) - return setindex!(vi.varinfo, val, spl) +function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromPrior) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end -function setindex!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) - return setindex!(vi.varinfo, val, spl) +function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) @@ -80,13 +80,11 @@ function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) -function empty!!(vi::ThreadSafeVarInfo) - Setfield.@set! vi.varinfo = empty!!(vi.varinfo) - fill!(vi.logps, zero(getlogp(vi))) - return vi +function BangBang.empty!!(vi::ThreadSafeVarInfo) + return resetlogp!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo))) end -function push!!( +function BangBang.push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) return Setfield.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist, gidset) diff --git a/src/varinfo.jl b/src/varinfo.jl index b94998e9e..9f0caa239 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -595,7 +595,7 @@ zeros. This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. """ -function empty!!(vi::VarInfo) +function BangBang.empty!!(vi::VarInfo) _empty!(vi.metadata) resetlogp!!(vi) reset_num_produce!(vi) @@ -956,7 +956,9 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ setindex!(vi::AbstractVarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) -setindex!!(vi::AbstractVarInfo, val, vn::VarName) = (setindex!(vi, val, vn); return vi) +function BangBang.setindex!!(vi::AbstractVarInfo, val, vn::VarName) + return (setindex!(vi, val, vn); return vi) +end """ setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) @@ -974,7 +976,7 @@ function setindex!(vi::TypedVarInfo, val, spl::Sampler) return nothing end -function setindex!!(vi::AbstractVarInfo, val, spl::AbstractSampler) +function BangBang.setindex!!(vi::AbstractVarInfo, val, spl::AbstractSampler) setindex!(vi, val, spl) return vi end @@ -1100,8 +1102,8 @@ end Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to the `VarInfo` `vi`, mutating if it makes sense. """ -function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return push!!(vi, vn, r, dist, Set{Selector}([])) +function BangBang.push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + return BangBang.push!!(vi, vn, r, dist, Set{Selector}([])) end """ @@ -1112,13 +1114,15 @@ from a distribution `dist` to `VarInfo` `vi`, if it makes sense. The sampler is passed here to invalidate its cache where defined. """ -function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) - return push!!(vi, vn, r, dist, spl.selector) +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler +) + return BangBang.push!!(vi, vn, r, dist, spl.selector) end -function push!!( +function BangBang.push!!( vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler ) - return push!!(vi, vn, r, dist) + return BangBang.push!!(vi, vn, r, dist) end """ @@ -1127,10 +1131,14 @@ end Push a new random variable `vn` with a sampled value `r` sampled with a sampler of selector `gid` from a distribution `dist` to `VarInfo` `vi`. """ -function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - return push!!(vi, vn, r, dist, Set([gid])) +function BangBang.push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector +) + return BangBang.push!!(vi, vn, r, dist, Set([gid])) end -function push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function BangBang.push!!( + vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} +) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo From d514b9902ee9732ef4d7366f3d4951c660338705 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 22:02:13 +0000 Subject: [PATCH 202/216] remove overloading of values and instead use values_as without the type specified --- src/simple_varinfo.jl | 8 ++++---- src/varinfo.jl | 5 +++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 52e45cdde..795a57791 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -167,9 +167,6 @@ acclogp!!(vi::SimpleVarInfo, logp) = SimpleVarInfo(vi.values, getlogp(vi) + logp Return an iterator of keys present in `vi`. """ Base.keys(vi::SimpleVarInfo) = keys(vi.values) -# TODO: Is this really the "right" thing to do? -# Is there a better function name we can use? -Base.values(vi::SimpleVarInfo) = vi.values function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp) vi.logp[] = logp @@ -417,10 +414,13 @@ settrans!(vi::SimpleOrThreadSafeSimple, trans::Bool, vn::VarName) = nothing istrans(::SimpleVarInfo, vn::VarName) = false """ - values(varinfo, Type) + values_as(varinfo[, Type]) Return the values/realizations in `varinfo` as `Type`, if implemented. + +If no `Type` is provided, return values as stored in `varinfo`. """ +values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values diff --git a/src/varinfo.jl b/src/varinfo.jl index 9f0caa239..9ce0414d6 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1502,6 +1502,11 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) return indices end +""" + values_as(vi::AbstractVarInfo) +""" +values_as(vi::VarInfo) = vi.metadata + """ values_as(vi::AbstractVarInfo, ::Type{NamedTuple}) values_as(vi::AbstractVarInfo, ::Type{Dict}) From 0eb27b6278eca761f4d01f81546366261d7fadcc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 22:27:45 +0000 Subject: [PATCH 203/216] Apply suggestions from code review Co-authored-by: David Widmann --- src/test_utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index e9b0e6a7d..9ffbb230e 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -116,7 +116,7 @@ function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) return loglikelihood(Normal(), m) end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) + return logpdf(MvNormal(m, 0.25 * I), model.args.x) end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -132,7 +132,7 @@ function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, m) return logpdf(Normal(), m) end function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) + return logpdf(MvNormal(m, 0.25 * I), model.args.x) end @model function demo_assume_observe_literal() @@ -163,7 +163,7 @@ function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal return loglikelihood(Normal(), m) end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) + return logpdf(MvNormal(m, 0.25 * I), fill(10.0, length(m))) end @model function demo_assume_literal_dot_observe() From 0bab3e6809853f5d287312f9e9bc60f550a80d0c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 22:28:24 +0000 Subject: [PATCH 204/216] renamed hasvalue for SimpleVarInfo to _haskey --- src/simple_varinfo.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 795a57791..281a638b9 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -224,10 +224,8 @@ Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values # TODO: Should we do better? Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values -haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) - -# TODO: Is `hasvalue` really the right function here? -function hasvalue(nt::NamedTuple, vn::VarName) +Base.haskey(vi::SimpleVarInfo, vn::VarName) = _haskey(vi.values, vn) +function _haskey(nt::NamedTuple, vn::VarName) # LHS: Ensure that `nt` indeed has the property we want. # RHS: Ensure that the lens can view into `nt`. sym = getsym(vn) @@ -236,7 +234,7 @@ end # For `dictlike` we need to check wether `vn` is "immediately" present, or # if some ancestor of `vn` is present in `dictlike`. -function hasvalue(dict::AbstractDict, vn::VarName) +function _haskey(dict::AbstractDict, vn::VarName) # First we check if `vn` is present as is. haskey(dict, vn) && return true From fa3f430f0f3fa71700467d41d70c80ce4562b827 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 6 Dec 2021 22:32:52 +0000 Subject: [PATCH 205/216] revert changes from previous commit --- src/test_utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 9ffbb230e..e9b0e6a7d 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -116,7 +116,7 @@ function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) return loglikelihood(Normal(), m) end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) + return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end # Using vector of `length` 1 here so the posterior of `m` is the same @@ -132,7 +132,7 @@ function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, m) return logpdf(Normal(), m) end function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) + return sum(logpdf.(Normal.(m, 0.5), model.args.x)) end @model function demo_assume_observe_literal() @@ -163,7 +163,7 @@ function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal return loglikelihood(Normal(), m) end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return logpdf(MvNormal(m, 0.25 * I), fill(10.0, length(m))) + return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) end @model function demo_assume_literal_dot_observe() From b9a987af2f01672b42089c09d8a2192485abdfe1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 7 Dec 2021 23:49:41 +0000 Subject: [PATCH 206/216] minor version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 20985b109..844289af5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.16.2" +version = "0.17.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 56a5106e09002d33e84fd090215175149445692c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 8 Dec 2021 01:14:39 +0000 Subject: [PATCH 207/216] fixed sampling with ThreadSafeVarInfo --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index db7e1e265..682814ad1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -451,7 +451,7 @@ See also: [`evaluate_threadunsafe!!`](@ref) function evaluate_threadsafe!!(model, varinfo, context) wrapper = ThreadSafeVarInfo(resetlogp!!(varinfo)) result, wrapper_new = _evaluate!!(model, wrapper, context) - return result, setlogp!!(varinfo, getlogp(wrapper_new)) + return result, setlogp!!(wrapper_new.varinfo, getlogp(wrapper_new)) end """ From e00498a6c1b4ee10d340690b99ddec35a7cfcadf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 8 Dec 2021 01:14:53 +0000 Subject: [PATCH 208/216] fixed setindex!! for ThreadSafeVarInfo --- src/threadsafe.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c21f03e7d..6f020a352 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -75,6 +75,13 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, val, spl::SampleFromUniform) return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, val, spl) end +function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) +end +function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<:VarName}) + return Setfield.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) +end + function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) return set_retained_vns_del_by_spl!(vi.varinfo, spl) end From 64945e64383bdb85db8207fb34c38c857b5b3986 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 8 Dec 2021 01:15:03 +0000 Subject: [PATCH 209/216] fixed eltype for ThreadSafeVarInfo wrapping a SimpleVarInfo --- src/simple_varinfo.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 281a638b9..5cecda4b2 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -293,13 +293,6 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName return SimpleVarInfo(dict_new, vi.logp) end -# Necessary for `matchingvalue` to work properly. -function Base.eltype( - vi::SimpleVarInfo{<:Any,T}, spl::Union{AbstractSampler,SampleFromPrior} -) where {T} - return T -end - # `NamedTuple` function BangBang.push!!( vi::SimpleVarInfo{<:NamedTuple}, @@ -332,10 +325,17 @@ function BangBang.push!!( return vi end -const SimpleOrThreadSafeSimple{T} = Union{ - SimpleVarInfo{T},ThreadSafeVarInfo{<:SimpleVarInfo{T}} +const SimpleOrThreadSafeSimple{T,V} = Union{ + SimpleVarInfo{T,V},ThreadSafeVarInfo{<:SimpleVarInfo{T,V}} } +# Necessary for `matchingvalue` to work properly. +function Base.eltype( + vi::SimpleOrThreadSafeSimple{<:Any,V}, spl::Union{AbstractSampler,SampleFromPrior} +) where {V} + return V +end + # Context implementations function assume(dist::Distribution, vn::VarName, vi::SimpleOrThreadSafeSimple) left = vi[vn] From 640551b183fc244c2b225487451578985ec8589a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 8 Dec 2021 01:15:19 +0000 Subject: [PATCH 210/216] fixed a test --- test/compiler.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/compiler.jl b/test/compiler.jl index 80fab9669..1a073edb9 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -551,7 +551,12 @@ end @model demo() = return __varinfo__ retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext()) @test svi == SimpleVarInfo() - @test retval == svi + if Threads.nthreads() > 1 + @test retval isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo} + @test retval.varinfo == svi + else + @test retval == svi + end # We should not be altering return-values other than at top-level. @model function demo() From 00dfdbdbd9e245398b6f571e35a64b5f008bf7d4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 8 Dec 2021 03:37:17 +0000 Subject: [PATCH 211/216] relax atol in serialization tests a bit --- test/serialization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/serialization.jl b/test/serialization.jl index 2f2bf2a2b..7ea81e410 100644 --- a/test/serialization.jl +++ b/test/serialization.jl @@ -10,7 +10,7 @@ samples_s = first.(samples) samples_m = last.(samples) - @test mean(samples_s) ≈ 3 atol = 0.1 + @test mean(samples_s) ≈ 3 atol = 0.2 @test mean(samples_m) ≈ 0 atol = 0.1 end @testset "pmap" begin From db1b03336630614fca6e5932721ea691fa00b32d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 8 Dec 2021 03:53:14 +0000 Subject: [PATCH 212/216] temporarily disable Julia 1.3 --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2cdde32d0..84799298f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,7 +16,7 @@ jobs: strategy: matrix: version: - - '1.3' # minimum supported version + # - '1.3' # minimum supported version - '1' # current stable version os: - ubuntu-latest From 6a63e152535143bb34c0f57354cf396b1042648f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 8 Dec 2021 04:09:38 +0000 Subject: [PATCH 213/216] relax atol for a prior check --- test/sampler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sampler.jl b/test/sampler.jl index b32ba29d6..ffd0182e8 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -18,7 +18,7 @@ @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. - @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.1 + @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 chains = sample(model, SampleFromUniform(), N; progress=false) @test chains isa Vector{<:VarInfo} From d65fc7445fb251fc8c769a8d106ec7e4f54200dd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 8 Dec 2021 19:31:26 +0000 Subject: [PATCH 214/216] Improvements to `@submodel` in #309 (#348) * added prefix keyword argument to submodel-macro * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * converted example in docs into test * fixed docstring * Apply suggestions from code review Co-authored-by: Philipp Gabler * removed redundant prefix_submodel_context def and added another example to docstring * fixed doctests * attempt at fixing doctests * another attempt at fixing doctests * had a typo in docstring Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Philipp Gabler --- src/model.jl | 2 +- src/submodel_macro.jl | 139 ++++++++++++++++++++++++++++++++++++++---- test/compiler.jl | 2 +- 3 files changed, 128 insertions(+), 15 deletions(-) diff --git a/src/model.jl b/src/model.jl index 682814ad1..702d76a17 100644 --- a/src/model.jl +++ b/src/model.jl @@ -215,7 +215,7 @@ But one needs to be careful when prefixing variables in the nested models: ```jldoctest condition julia> @model function demo_outer_prefix() - @submodel inner m = demo_inner() + @submodel prefix="inner" m = demo_inner() return m end demo_outer_prefix (generic function with 2 methods) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 54477a05d..009a12c7b 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,5 +1,6 @@ """ @submodel model + @submodel ... = model Run a Turing `model` nested inside of a Turing model. @@ -44,22 +45,32 @@ true ``` """ macro submodel(expr) - return submodel(expr) + return submodel(:(prefix = false), expr) end """ - @submodel prefix model + @submodel prefix=... model + @submodel prefix=... ... = model Run a Turing `model` nested inside of a Turing model and add "`prefix`." as a prefix to all random variables inside of the `model`. +Valid expressions for `prefix=...` are: +- `prefix=false`: no prefix is used. +- `prefix=true`: _attempt_ to automatically determine the prefix from the left-hand side + `... = model` by first converting into a `VarName`, and then calling `Symbol` on this. +- `prefix="my prefix"`: prefix is taken to be the static string "my prefix". +- `prefix=expression`: `expression` is evaluated at runtime, resulting in + the prefix `Symbol(expression)`. Note that this also includes string-interpolation, + e.g. `prefix="x[\$i]"` as it requires runtime information. + The prefix makes it possible to run the same Turing model multiple times while keeping track of all random variables correctly. The return value can be assigned to a variable. # Examples - +## Example models ```jldoctest submodelprefix; setup=:(using Distributions) julia> @model function demo1(x) x ~ Normal() @@ -67,8 +78,8 @@ julia> @model function demo1(x) end; julia> @model function demo2(x, y, z) - @submodel sub1 a = demo1(x) - @submodel sub2 b = demo1(y) + @submodel prefix="sub1" a = demo1(x) + @submodel prefix="sub2" b = demo1(y) return z ~ Uniform(-a, b) end; ``` @@ -109,27 +120,129 @@ julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); julia> getlogp(vi) ≈ logprior + loglikelihood true ``` + +## Different ways of setting the prefix +```jldoctest submodel-prefix-alternatives; setup=:(using DynamicPPL, Distributions) +julia> @model inner() = x ~ Normal() +inner (generic function with 2 methods) + +julia> # When `prefix` is unspecified, no prefix is used. + @model outer() = @submodel a = inner() +outer (generic function with 2 methods) + +julia> @varname(x) in keys(VarInfo(outer())) +true + +julia> # Explicitely don't use any prefix. + @model outer() = @submodel prefix=false a = inner() +outer (generic function with 2 methods) + +julia> @varname(x) in keys(VarInfo(outer())) +true + +julia> # Automatically determined from `a`. + @model outer() = @submodel prefix=true a = inner() +outer (generic function with 2 methods) + +julia> @varname(var"a.x") in keys(VarInfo(outer())) +true + +julia> # Using a static string. + @model outer() = @submodel prefix="my prefix" a = inner() +outer (generic function with 2 methods) + +julia> @varname(var"my prefix.x") in keys(VarInfo(outer())) +true + +julia> # Using string interpolation. + @model outer() = @submodel prefix="\$(inner().name)" a = inner() +outer (generic function with 2 methods) + +julia> @varname(var"inner.x") in keys(VarInfo(outer())) +true + +julia> # Or using some arbitrary expression. + @model outer() = @submodel prefix=1 + 2 a = inner() +outer (generic function with 2 methods) + +julia> @varname(var"3.x") in keys(VarInfo(outer())) +true + +julia> # (×) Automatic prefixing without a left-hand side expression does not work! + @model outer() = @submodel prefix=true inner() +ERROR: LoadError: cannot automatically prefix with no left-hand side +[...] +``` + +# Notes +- The choice `prefix=expression` means that the prefixing will incur a runtime cost. + This is also the case for `prefix=true`, depending on whether the expression on the + the right-hand side of `... = model` requires runtime-information or not, e.g. + `x = model` will result in the _static_ prefix `x`, while `x[i] = model` will be + resolved at runtime. """ -macro submodel(prefix, expr) - ctx = :(PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))) - return submodel(expr, ctx) +macro submodel(prefix_expr, expr) + return submodel(prefix_expr, expr, esc(:__context__)) +end + +# Automatic prefixing. +function prefix_submodel_context(prefix::Bool, left::Symbol, ctx) + return prefix ? prefix_submodel_context(left, ctx) : ctx +end + +function prefix_submodel_context(prefix::Bool, left::Expr, ctx) + return prefix ? prefix_submodel_context(varname(left), ctx) : ctx +end + +# Manual prefixing. +prefix_submodel_context(prefix, left, ctx) = prefix_submodel_context(prefix, ctx) +function prefix_submodel_context(prefix, ctx) + # E.g. `prefix="asd[$i]"` or `prefix=asd` with `asd` to be evaluated. + return :($(DynamicPPL.PrefixContext){$(Symbol)($(esc(prefix)))}($ctx)) end -function submodel(expr, ctx=esc(:__context__)) +function prefix_submodel_context(prefix::Union{AbstractString,Symbol}, ctx) + # E.g. `prefix="asd"`. + return :($(DynamicPPL.PrefixContext){$(esc(Meta.quot(Symbol(prefix))))}($ctx)) +end + +function prefix_submodel_context(prefix::Bool, ctx) + if prefix + error("cannot automatically prefix with no left-hand side") + end + + return ctx +end + +function submodel(prefix_expr, expr, ctx=esc(:__context__)) + prefix_left, prefix = getargs_assignment(prefix_expr) + if prefix_left !== :prefix + error("$(prefix_left) is not a valid kwarg") + end + # `prefix=false` => don't prefix, i.e. do nothing to `ctx`. + # `prefix=true` => automatically determine prefix. + # `prefix=...` => use it. args_assign = getargs_assignment(expr) return if args_assign === nothing + ctx = prefix_submodel_context(prefix, ctx) # In this case we only want to get the `__varinfo__`. quote $(esc(:__varinfo__)) = last( - _evaluate!!($(esc(expr)), $(esc(:__varinfo__)), $(ctx)) + $(DynamicPPL._evaluate!!)($(esc(expr)), $(esc(:__varinfo__)), $(ctx)) ) end else - # Here we also want the return-variable. - # TODO: Should we prefix by `L` by default? L, R = args_assign + # Now that we have `L` and `R`, we can prefix automagically. + try + ctx = prefix_submodel_context(prefix, L, ctx) + catch e + error( + "failed to determine prefix from $(L); please specify prefix using the `@submodel prefix=\"your prefix\" ...` syntax", + ) + end quote - $(esc(L)), $(esc(:__varinfo__)) = _evaluate!!( + $(esc(L)), $(esc(:__varinfo__)) = $(DynamicPPL._evaluate!!)( $(esc(R)), $(esc(:__varinfo__)), $(ctx) ) end diff --git a/test/compiler.jl b/test/compiler.jl index 1a073edb9..7860aa2ca 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -455,7 +455,7 @@ end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - @submodel $(Symbol("ar1_$i")) x = AR1(num_steps, α, μ, σ) + @submodel prefix = "ar1_$i" x = AR1(num_steps, α, μ, σ) y[i] ~ MvNormal(x, 0.1) end end From a8c368c330dc456b1666b1fe11a1853c07b958b3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 11 Dec 2021 05:14:13 +0000 Subject: [PATCH 215/216] fixed a test case using submodel --- test/compiler.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 7860aa2ca..4c76cf1ab 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -420,8 +420,8 @@ end end @model function demo_useval(x, y) - @submodel sub1 x1 = demo_return(x) - @submodel sub2 x2 = demo_return(y) + @submodel prefix = "sub1" x1 = demo_return(x) + @submodel prefix = "sub2" x2 = demo_return(y) return z ~ Normal(x1 + x2 + 100, 1.0) end From 64819a1597c984a53cf6c4d2de0a092b753e9942 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 11 Dec 2021 05:34:55 +0000 Subject: [PATCH 216/216] improved docstring according to comments by @devmotion --- src/submodel_macro.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 009a12c7b..5ffed3c42 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -4,8 +4,6 @@ Run a Turing `model` nested inside of a Turing model. -The return value can be assigned to a variable. - # Examples ```jldoctest submodel; setup=:(using Distributions) @@ -59,16 +57,11 @@ Valid expressions for `prefix=...` are: - `prefix=false`: no prefix is used. - `prefix=true`: _attempt_ to automatically determine the prefix from the left-hand side `... = model` by first converting into a `VarName`, and then calling `Symbol` on this. -- `prefix="my prefix"`: prefix is taken to be the static string "my prefix". -- `prefix=expression`: `expression` is evaluated at runtime, resulting in - the prefix `Symbol(expression)`. Note that this also includes string-interpolation, - e.g. `prefix="x[\$i]"` as it requires runtime information. +- `prefix=expression`: results in the prefix `Symbol(expression)`. The prefix makes it possible to run the same Turing model multiple times while keeping track of all random variables correctly. -The return value can be assigned to a variable. - # Examples ## Example models ```jldoctest submodelprefix; setup=:(using Distributions)