Skip to content

Commit

Permalink
Proper support for distributions with embedded support (#462)
Browse files Browse the repository at this point in the history
Attempt at addressing #461.

I think the approach here is somewhat correct, but it's currently very dirty because of the intricacies of `VarInfo` implementation. This can be cleaned up, but will take some effort. Until I've done this, I leave this as a draft PR.

We also most certainly need to do some benchmarking before merging this as it could lead to some additional overhead.

NOTE: This is based on #457 

Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
torfjelde and yebai committed Apr 26, 2023
1 parent 4d6b177 commit f4cc443
Show file tree
Hide file tree
Showing 12 changed files with 309 additions and 71 deletions.
3 changes: 2 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.maybe_invlink_before_eval!!
```
DynamicPPL.reconstruct
```

#### Utils

Expand Down
2 changes: 1 addition & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export AbstractVarInfo,
push!!,
empty!!,
getlogp,
resetlogp!,
resetlogp!!,
setlogp!!,
acclogp!!,
resetlogp!!,
Expand Down
93 changes: 93 additions & 0 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,99 @@ variables `x` would return
"""
function tonamedtuple end

# TODO: Clean up all this linking stuff once and for all!
"""
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)
Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting
value is reconstructed to the correct type and shape according to `dist`.
"""
function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
x_recon = reconstruct(f, dist, x)
return with_logabsdet_jacobian(f, x_recon)
end

# TODO: Once we `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
"""
reconstruct_and_link(dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)
Return linked `val` but reconstruct before linking, if necessary.
Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily
return a reconstructed value, i.e. a value of the same type and shape as expected
by `dist`.
See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref).
"""
reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val))
reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val)
function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val)
return reconstruct_and_link(dist, val)
end

"""
invlink_and_reconstruct(dist, val)
invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return invlinked and reconstructed `val`.
See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref).
"""
invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val))
function invlink_and_reconstruct(dist, val)
return invlink_and_reconstruct(invlink_transform(dist), dist, val)
end
function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val)
return invlink_and_reconstruct(dist, val)
end

"""
maybe_link_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
"""
function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
reconstruct_and_link(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
"""
function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
invlink_and_reconstruct(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x])
Invlink `x` and compute the logpdf under `dist` including correction from
the invlink-transformation.
If `x` is not provided, `getval(vi, vn)` will be used.
"""
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist)
return invlink_with_logpdf(vi, vn, dist, getval(vi, vn))
end
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
# NOTE: Will this cause type-instabilities or will union-splitting save us?
f = istrans(vi, vn) ? invlink_transform(dist) : identity
x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y)
return x, logpdf(dist, x) + logjac
end

# Legacy code that is currently overloaded for the sake of simplicity.
# TODO: Remove when possible.
increment_num_produce!(::AbstractVarInfo) = nothing
Expand Down
30 changes: 21 additions & 9 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
r = vi[vn, dist]
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
r, logp = invlink_with_logpdf(vi, vn, dist)
return r, logp, vi
end

# SampleFromPrior and SampleFromUniform
Expand All @@ -211,7 +211,9 @@ function assume(
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn)
BangBang.setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn
)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
Expand All @@ -220,15 +222,17 @@ function assume(
else
r = init(rng, dist, sampler)
if istrans(vi)
push!!(vi, vn, link(dist, r), dist, sampler)
push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
end

return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
return r, logpdf(dist, r) - logjac, vi
end

# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
Expand Down Expand Up @@ -470,7 +474,11 @@ function get_and_set_val!(
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn)
setindex!!(
vi,
vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])),
vn,
)
setorder!(vi, vn, get_num_produce(vi))
end
else
Expand Down Expand Up @@ -508,13 +516,17 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn)
setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn
)
setorder!(vi, vn, get_num_produce(vi))
end
else
# r = reshape(vi[vec(vns)], size(vns))
# FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)`
# and fix the lines below.
r_raw = getindex_raw(vi, vec(vns))
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns)))
end
else
f = (vn, dist) -> init(rng, dist, spl)
Expand All @@ -525,7 +537,7 @@ function get_and_set_val!(
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
if istrans(vi)
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
Expand Down
13 changes: 8 additions & 5 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ end

# `NamedTuple`
function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
return maybe_invlink(vi, vn, dist, getindex(vi, vn))
return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn))
end
function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
vals_linked = mapreduce(vcat, vns) do vn
Expand Down Expand Up @@ -329,6 +329,9 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut
return reconstruct(dist, vals, length(vns))
end

# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`.
getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn)

Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn)

function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
Expand Down Expand Up @@ -426,7 +429,7 @@ function assume(
)
value = init(rng, dist, sampler)
# Transform if we're working in unconstrained space.
value_raw = maybe_link(vi, vn, dist, value)
value_raw = maybe_reconstruct_and_link(vi, vn, dist, value)
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
end
Expand All @@ -444,9 +447,9 @@ function dot_assume(

# Transform if we're working in transformed space.
value_raw = if dists isa Distribution
maybe_link.((vi,), vns, (dists,), value)
maybe_reconstruct_and_link.((vi,), vns, (dists,), value)
else
maybe_link.((vi,), vns, dists, value)
maybe_reconstruct_and_link.((vi,), vns, dists, value)
end

# Update `vi`
Expand All @@ -473,7 +476,7 @@ function dot_assume(

# Update `vi`.
for (vn, val) in zip(vns, eachcol(value))
val_linked = maybe_link(vi, vn, dist, val)
val_linked = maybe_reconstruct_and_link(vi, vn, dist, val)
vi = BangBang.setindex!!(vi, val_linked, vn)
end

Expand Down
2 changes: 2 additions & 0 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,5 @@ end

istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)

getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
6 changes: 3 additions & 3 deletions src/transforming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function tilde_assume(

# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
r_transformed = isinverse ? r : bijector(right)(r)
r_transformed = isinverse ? r : link_transform(right)(r)
return r, lp, setindex!!(vi, r_transformed, vn)
end

Expand All @@ -27,7 +27,7 @@ function dot_tilde_assume(
vi,
) where {isinverse}
r = getindex.((vi,), vns, (dist,))
b = bijector(dist)
b = link_transform(dist)

is_trans_uniques = unique(istrans.((vi,), vns))
@assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables"
Expand Down Expand Up @@ -70,7 +70,7 @@ function dot_tilde_assume(
@assert !isinverse "Trying to invlink non-transformed variables"
end

b = bijector(dist)
b = link_transform(dist)
for (vn, ri) in zip(vns, eachcol(r))
# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
Expand Down
47 changes: 46 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,39 @@ function to_namedtuple_expr(syms, vals=syms)
return :(NamedTuple{$names_expr}($vals_expr))
end

"""
link_transform(dist)
Returns the constrained-to-unconstrained bijector for distribution `dist`.
By default, this is just `Bijectors.bijector(dist)`.
!!! warning
Note that this is not current used by `Bijectors.logpdf_with_trans`,
hence that needs to be overloaded separately if the intention is
to change behavior of an existing distribution.
"""
link_transform(dist) = bijector(dist)

"""
invlink_transform(dist)
Returns the unconstrained-to-constrained bijector for distribution `dist`.
By default, this is just `inverse(link_transform(dist))`.
!!! warning
Note that this is not current used by `Bijectors.logpdf_with_trans`,
hence that needs to be overloaded separately if the intention is
to change behavior of an existing distribution.
"""
invlink_transform(dist) = inverse(link_transform(dist))

#####################################################
# Helper functions for vectorize/reconstruct values #
#####################################################

vectorize(d, r) = vec(r)
vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
Expand All @@ -183,7 +212,23 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
# otherwise we will have error for MatrixDistribution.
# Note this is not the case for MultivariateDistribution so I guess this might be lack of
# support for some types related to matrices (like PDMat).
reconstruct(d::UnivariateDistribution, val::Real) = val

"""
reconstruct([f, ]dist, val)
Reconstruct `val` so that it's compatible with `dist`.
If `f` is also provided, the reconstruct value will be
such that `f(reconstruct_val)` is compatible with `dist`.
"""
reconstruct(f, dist, val) = reconstruct(dist, val)

# No-op versions.
reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = val
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = val
# TODO: Implement no-op `reconstruct` for general array variates.

reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
reconstruct(::Tuple{}, val::AbstractVector) = val[1]
reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val)
Expand Down
Loading

0 comments on commit f4cc443

Please sign in to comment.