Skip to content

Commit

Permalink
Proper support for distributions with embedded support (#462)
Browse files Browse the repository at this point in the history
* compat with new Bijectors.jl

* bump compat bounds for Bijectors and make it a breaking change

* remove mentioning of Exp and Identity in test_utils.jl

* added mistakenly commented out tests

* fixed test_utils

* bump bijectors version

* added no-op impls for reconstruct

* added a bunch of convenience methods for working with Metadata instead
of VarInfo

* added usage of _inner_transform! in link, in addition to additional
methods for linking and invlinking

* updated getall to not assume we want all the values in metadata

* added FIXME comment

* fixed typo in comment

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* sligh simplification of the linking stuff

* formatting

* lower bound test compat entry for Tracker

* move link-related functions to abstract_varinfo.jl and renamed methods
to be more descriptive

* fixed invlink!! for VarInfo

* fixed link and invlink tests

* added specialized mapreduce for (named)tuples to improve type-inference

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added missing docstring

* added minor TODO comment for the future

* added `link_transform` and `invlink_transform`, basically equivalent
to `bijector`  but allows us to separate the choices made in DPPL from
those in Bijectors

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/utils.jl

* added some docstrings

* renamed link_and_reconstruct to the more accurate reconstruct_and_link

* removed unnecessary definition of inlink_transform

* fixed bug in newmetadata

* removed mapreduce_tuple in favor of reduce and map

* Update src/utils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* introduce _logpdf_with_trans as a placeholder while we migrate away
from the usage of this function and into invlink_and_pdf

* reconstruct now takes into account the transformation to be used

* replaced more references to bijector with link_transform

* added docstring for invlink_with_logpdf

* fixed bug in assume introduced hacky getval for SimpleVarInfo

* added tests for linking

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* rename maybe_link_and_reconstruct to maybe_reconstruct_and_link

* added reconstruct to the API docs

* Update docs/src/api.md

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* removed unnecessary comment

* removed _logpdf_with_trans in favour of just using Bijectors.jl's for now

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added warning regarding overloading to link_transform and invlink_transform

* added missing getval for ThreadSafeVarInfo

* added a minor additional test to linking

* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>

* reverted chagnes from previous commit

* fixed usage of deprecated link

* Update test/linking.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/DynamicPPL.jl

* fixed tests

* added copy to tonamedtuple to avoid mutating chain samples

* improved testing for setval! and generated_quantities

* bumped the version in turing tests

* Apply suggestions from code review

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
4 people authored Jun 15, 2023
1 parent 5a729e2 commit 7b01d25
Show file tree
Hide file tree
Showing 13 changed files with 331 additions and 175 deletions.
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.maybe_invlink_before_eval!!
```
DynamicPPL.reconstruct
```

#### Utils

Expand Down
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ export AbstractVarInfo,
push!!,
empty!!,
getlogp,
resetlogp!,
setlogp!!,
acclogp!!,
resetlogp!!,
Expand Down
93 changes: 93 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,99 @@ variables `x` would return
"""
function tonamedtuple end

# TODO: Clean up all this linking stuff once and for all!
"""
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)
Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting
value is reconstructed to the correct type and shape according to `dist`.
"""
function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
x_recon = reconstruct(f, dist, x)
return with_logabsdet_jacobian(f, x_recon)
end

# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
"""
reconstruct_and_link(dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)
Return linked `val` but reconstruct before linking, if necessary.
Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily
return a reconstructed value, i.e. a value of the same type and shape as expected
by `dist`.
See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref).
"""
reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val))
reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val)
function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val)
return reconstruct_and_link(dist, val)
end

"""
invlink_and_reconstruct(dist, val)
invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return invlinked and reconstructed `val`.
See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref).
"""
invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val))
function invlink_and_reconstruct(dist, val)
return invlink_and_reconstruct(invlink_transform(dist), dist, val)
end
function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val)
return invlink_and_reconstruct(dist, val)
end

"""
maybe_link_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
"""
function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
reconstruct_and_link(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
"""
function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
invlink_and_reconstruct(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x])
Invlink `x` and compute the logpdf under `dist` including correction from
the invlink-transformation.
If `x` is not provided, `getval(vi, vn)` will be used.
"""
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist)
return invlink_with_logpdf(vi, vn, dist, getval(vi, vn))
end
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
# NOTE: Will this cause type-instabilities or will union-splitting save us?
f = istrans(vi, vn) ? invlink_transform(dist) : identity
x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y)
return x, logpdf(dist, x) + logjac
end

# Legacy code that is currently overloaded for the sake of simplicity.
# TODO: Remove when possible.
increment_num_produce!(::AbstractVarInfo) = nothing
Expand Down
30 changes: 21 additions & 9 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
r = vi[vn, dist]
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
r, logp = invlink_with_logpdf(vi, vn, dist)
return r, logp, vi
end

# SampleFromPrior and SampleFromUniform
Expand All @@ -211,7 +211,9 @@ function assume(
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn)
BangBang.setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn
)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
Expand All @@ -220,15 +222,17 @@ function assume(
else
r = init(rng, dist, sampler)
if istrans(vi)
push!!(vi, vn, link(dist, r), dist, sampler)
push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
end

return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
return r, logpdf(dist, r) - logjac, vi
end

# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
Expand Down Expand Up @@ -470,7 +474,11 @@ function get_and_set_val!(
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn)
setindex!!(
vi,
vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])),
vn,
)
setorder!(vi, vn, get_num_produce(vi))
end
else
Expand Down Expand Up @@ -508,13 +516,17 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn)
setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn
)
setorder!(vi, vn, get_num_produce(vi))
end
else
# r = reshape(vi[vec(vns)], size(vns))
# FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)`
# and fix the lines below.
r_raw = getindex_raw(vi, vec(vns))
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns)))
end
else
f = (vn, dist) -> init(rng, dist, spl)
Expand All @@ -525,7 +537,7 @@ function get_and_set_val!(
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
if istrans(vi)
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
Expand Down
17 changes: 10 additions & 7 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ end

# `NamedTuple`
function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
return maybe_invlink(vi, vn, dist, getindex(vi, vn))
return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn))
end
function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
vals_linked = mapreduce(vcat, vns) do vn
Expand Down Expand Up @@ -329,6 +329,9 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut
return reconstruct(dist, vals, length(vns))
end

# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`.
getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn)

Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn)

function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
Expand Down Expand Up @@ -426,7 +429,7 @@ function assume(
)
value = init(rng, dist, sampler)
# Transform if we're working in unconstrained space.
value_raw = maybe_link(vi, vn, dist, value)
value_raw = maybe_reconstruct_and_link(vi, vn, dist, value)
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
end
Expand All @@ -444,9 +447,9 @@ function dot_assume(

# Transform if we're working in transformed space.
value_raw = if dists isa Distribution
maybe_link.((vi,), vns, (dists,), value)
maybe_reconstruct_and_link.((vi,), vns, (dists,), value)
else
maybe_link.((vi,), vns, dists, value)
maybe_reconstruct_and_link.((vi,), vns, dists, value)
end

# Update `vi`
Expand All @@ -473,7 +476,7 @@ function dot_assume(

# Update `vi`.
for (vn, val) in zip(vns, eachcol(value))
val_linked = maybe_link(vi, vn, dist, val)
val_linked = maybe_reconstruct_and_link(vi, vn, dist, val)
vi = BangBang.setindex!!(vi, val_linked, vn)
end

Expand All @@ -488,7 +491,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {
nt_vals = map(keys(vi)) do vn
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vals = map(Base.Fix1(getindex, vi), vns)
vals = map(copy Base.Fix1(getindex, vi), vns)
(vals, map(string, vns))
end

Expand All @@ -501,7 +504,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
# Extract the leaf varnames and values.
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vals = map(Base.Fix1(getindex, vi), vns)
vals = map(copy Base.Fix1(getindex, vi), vns)

# Determine the corresponding symbol.
sym = only(unique(map(getsym, vns)))
Expand Down
2 changes: 2 additions & 0 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,5 @@ end

istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)

getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
6 changes: 3 additions & 3 deletions src/transforming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function tilde_assume(

# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
r_transformed = isinverse ? r : bijector(right)(r)
r_transformed = isinverse ? r : link_transform(right)(r)
return r, lp, setindex!!(vi, r_transformed, vn)
end

Expand All @@ -27,7 +27,7 @@ function dot_tilde_assume(
vi,
) where {isinverse}
r = getindex.((vi,), vns, (dist,))
b = bijector(dist)
b = link_transform(dist)

is_trans_uniques = unique(istrans.((vi,), vns))
@assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables"
Expand Down Expand Up @@ -70,7 +70,7 @@ function dot_tilde_assume(
@assert !isinverse "Trying to invlink non-transformed variables"
end

b = bijector(dist)
b = link_transform(dist)
for (vn, ri) in zip(vns, eachcol(r))
# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
Expand Down
47 changes: 46 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,39 @@ function to_namedtuple_expr(syms, vals)
return :(NamedTuple{$names_expr}($vals_expr))
end

"""
link_transform(dist)
Return the constrained-to-unconstrained bijector for distribution `dist`.
By default, this is just `Bijectors.bijector(dist)`.
!!! warning
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
hence that needs to be overloaded separately if the intention is
to change behavior of an existing distribution.
"""
link_transform(dist) = bijector(dist)

"""
invlink_transform(dist)
Return the unconstrained-to-constrained bijector for distribution `dist`.
By default, this is just `inverse(link_transform(dist))`.
!!! warning
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
hence that needs to be overloaded separately if the intention is
to change behavior of an existing distribution.
"""
invlink_transform(dist) = inverse(link_transform(dist))

#####################################################
# Helper functions for vectorize/reconstruct values #
#####################################################

vectorize(d, r) = vec(r)
vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
Expand All @@ -191,7 +220,23 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
# otherwise we will have error for MatrixDistribution.
# Note this is not the case for MultivariateDistribution so I guess this might be lack of
# support for some types related to matrices (like PDMat).
reconstruct(d::UnivariateDistribution, val::Real) = val

"""
reconstruct([f, ]dist, val)
Reconstruct `val` so that it's compatible with `dist`.
If `f` is also provided, the reconstruct value will be
such that `f(reconstruct_val)` is compatible with `dist`.
"""
reconstruct(f, dist, val) = reconstruct(dist, val)

# No-op versions.
reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
# TODO: Implement no-op `reconstruct` for general array variates.

reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
reconstruct(::Tuple{}, val::AbstractVector) = val[1]
reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val)
Expand Down
Loading

0 comments on commit 7b01d25

Please sign in to comment.