Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vectorize method for LKJCholesky #485

Merged
merged 30 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3cccca7
using `LinearAlgebra.Cholesky`
harisorgn Jun 16, 2023
1c0e8c4
add `vectorize` for `LKJCholesky`
harisorgn Jun 16, 2023
e9df34b
add `vectorize` test
harisorgn Jun 16, 2023
0a4a2ae
add forgotten `end`
harisorgn Jun 16, 2023
0be59eb
Update test/utils.jl
harisorgn Jun 16, 2023
044130c
fix typo
harisorgn Jun 19, 2023
20b8d19
add `reconstruct` methods for LKJ/LKJCholesky inv bijectors
harisorgn Jun 19, 2023
391c8a1
bump patch
harisorgn Jun 19, 2023
9fec706
bump Bijectors compat
harisorgn Jun 19, 2023
181d37e
Update src/utils.jl
harisorgn Jun 19, 2023
580a57a
Merge branch 'master' into ho/vec_lkjcholesky
harisorgn Jun 20, 2023
9085c64
add Bijectors v0.13 compat
harisorgn Jun 20, 2023
385d640
Merge branch 'master' into ho/vec_lkjcholesky
harisorgn Jun 21, 2023
b67a9aa
Merge branch 'master' into ho/vec_lkjcholesky
harisorgn Jun 21, 2023
397c8dd
add `inittrans` method for `CholeskyVariate`
harisorgn Jun 22, 2023
0ab6212
add `LKJ`/`LKJCholesky` tests
harisorgn Jun 22, 2023
c1f30dc
include tests
harisorgn Jun 22, 2023
94b6f10
Update test/lkj.jl
harisorgn Jun 22, 2023
38f9412
Update test/lkj.jl
harisorgn Jun 22, 2023
faaca3b
make tests more accurate
harisorgn Jun 23, 2023
a914a78
Update test/lkj.jl
harisorgn Jun 23, 2023
980141e
Update test/lkj.jl
harisorgn Jun 23, 2023
1919465
Update test/lkj.jl
harisorgn Jun 23, 2023
6470f58
Update test/lkj.jl
harisorgn Jun 23, 2023
f55e7d1
Update test/lkj.jl
harisorgn Jun 23, 2023
bf38f59
Update test/lkj.jl
harisorgn Jun 23, 2023
5056342
Update test/lkj.jl
harisorgn Jun 23, 2023
f5a8113
test `LKJCholesky` for both `'U'` and `'L'`
harisorgn Jun 23, 2023
5e0a696
remove unnecessary `float` wrap
harisorgn Jun 23, 2023
c5b9ad9
Update test/lkj.jl
harisorgn Jun 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.0"
version = "0.23.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ using Setfield: Setfield
using ZygoteRules: ZygoteRules
using LogDensityProblems: LogDensityProblems

using LinearAlgebra: Cholesky

using DocStringExtensions

using Random: Random
Expand Down
2 changes: 1 addition & 1 deletion src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ end
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
"""
reconstruct_and_link(dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)

Return linked `val` but reconstruct before linking, if necessary.

Expand Down
15 changes: 14 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ vectorize(d, r) = vec(r)
vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL))

# NOTE:
# We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real.
Expand All @@ -235,6 +236,13 @@ reconstruct(f, dist, val) = reconstruct(dist, val)
reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)
function reconstruct(
::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector
)
return copy(val)
end

# TODO: Implement no-op `reconstruct` for general array variates.

reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
Expand Down Expand Up @@ -294,7 +302,12 @@ function inittrans(rng, dist::MatrixDistribution)
sz = Bijectors.output_size(b, size(dist))
return Bijectors.invlink(dist, randrealuni(rng, sz...))
end

function inittrans(rng, dist::Distribution{CholeskyVariate})
# Get the size of the unconstrained vector
b = link_transform(dist)
sz = Bijectors.output_size(b, size(dist))
return Bijectors.invlink(dist, randrealuni(rng, sz...))
end
################################
# Multi-sample initialisations #
################################
Expand Down
55 changes: 55 additions & 0 deletions test/lkj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using Bijectors: pd_from_upper, pd_from_lower

function pd_from_triangular(X::AbstractMatrix, uplo::Char)
return uplo == 'U' ? pd_from_upper(X) : pd_from_lower(X)
end

@model lkj_prior_demo() = x ~ LKJ(2, 1)
@model lkj_chol_prior_demo(uplo) = x ~ LKJCholesky(2, 1, uplo)

# Same for both distributions
target_mean = vec(Matrix{Float64}(I, 2, 2))

_lkj_atol = 0.05

@testset "Sample from x ~ LKJ(2, 1)" begin
model = lkj_prior_demo()
# `SampleFromPrior` will sample in constrained space.
@testset "SampleFromPrior" begin
samples = sample(model, SampleFromPrior(), 1_000)
@test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol =
_lkj_atol
end

# `SampleFromUniform` will sample in unconstrained space.
@testset "SampleFromUniform" begin
samples = sample(model, SampleFromUniform(), 1_000)
@test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol =
_lkj_atol
end
end

@testset "Sample from x ~ LKJCholesky(2, 1, $(uplo))" for uplo in ['U', 'L']
model = lkj_chol_prior_demo(uplo)
# `SampleFromPrior` will sample in unconstrained space.
@testset "SampleFromPrior" begin
samples = sample(model, SampleFromPrior(), 1_000)
# Build correlation matrix from factor
corr_matrices = map(samples) do s
M = reshape(s.metadata.vals, (2, 2))
pd_from_triangular(M, uplo)
end
@test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol
end

# `SampleFromUniform` will sample in unconstrained space.
@testset "SampleFromUniform" begin
samples = sample(model, SampleFromUniform(), 1_000)
# Build correlation matrix from factor
corr_matrices = map(samples) do s
M = reshape(s.metadata.vals, (2, 2))
pd_from_triangular(M, uplo)
end
@test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ include("test_util.jl")
include("serialization.jl")

include("loglikelihoods.jl")

include("lkj.jl")
end

@testset "compat" begin
Expand Down
6 changes: 6 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,10 @@
@test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
@test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
end

@testset "vectorize" begin
dist = LKJCholesky(2, 1)
x = rand(dist)
@test vectorize(dist, x) == vec(x.UL)
end
end