diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 92e00b92..605fe002 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -33,7 +33,7 @@ using Reexport, Requires using LinearAlgebra using MappedArrays using Base.Iterators: drop -using LinearAlgebra: AbstractTriangular +using LinearAlgebra: AbstractTriangular, Hermitian using InverseFunctions: InverseFunctions @@ -145,6 +145,9 @@ function _logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractM return logabsdetjac.((bijector(d),), x) end +_logabsdetjac_dist(d::LKJCholesky, x::Cholesky) = logabsdetjac(bijector(d), x) +_logabsdetjac_dist(d::LKJCholesky, x::AbstractVector) = logabsdetjac.((bijector(d),), x) + function logpdf_with_trans(d::Distribution, x, transform::Bool) if ispd(d) return pd_logpdf_with_trans(d, x, transform) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 995708dd..9367b0cd 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -65,34 +65,20 @@ struct CorrBijector <: Bijector end with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x) -function transform(b::CorrBijector, x::AbstractMatrix{<:Real}) - w = cholesky(x).U # keep LowerTriangular until here can avoid some computation +function transform(b::CorrBijector, X::AbstractMatrix{<:Real}) + w = upper_triangular(parent(cholesky(X).U)) # keep LowerTriangular until here can avoid some computation r = _link_chol_lkj(w) - return r + zero(x) + return r + zero(X) # This dense format itself is required by a test, though I can't get the point. # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67 end function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) w = _inv_link_chol_lkj(y) - return w' * w + return pd_from_upper(w) end -function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) - K = LinearAlgebra.checksquare(y) - - result = float(zero(eltype(y))) - for j in 2:K, i in 1:(j - 1) - @inbounds abs_y_i_j = abs(y[i, j]) - result += - (K - i + 1) * ( - IrrationalConstants.logtwo - - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) - ) - end - - return result -end +logabsdetjac(::Inverse{CorrBijector}, Y::AbstractMatrix{<:Real}) = _logabsdetjac_inv_corr(Y) function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) #= It may be more efficient if we can use un-contraint value to prevent call of b @@ -103,25 +89,232 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) return -logabsdetjac(inverse(b), (b(X))) end -function _inv_link_chol_lkj(y) - K = LinearAlgebra.checksquare(y) +""" + triu_mask(X::AbstractMatrix, k::Int) - w = similar(y) +Return a mask for elements of `X` above the `k`th diagonal. +""" +function triu_mask(X::AbstractMatrix, k::Int) + # Ensure that we're working with a square matrix. + LinearAlgebra.checksquare(X) - @inbounds for j in 1:K - w[1, j] = 1 - for i in 2:j - z = tanh(y[i - 1, j]) - tmp = w[i - 1, j] - w[i - 1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) + # Using `similar` allows us to respect device of array, etc., e.g. `CuArray`. + m = similar(X, Bool) + return triu(.~m .| m, k) +end + +triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)] + +function update_triu_from_vec!( + vals::AbstractVector{<:Real}, k::Int, X::AbstractMatrix{<:Real} +) + # Ensure that we're working with one-based indexing. + # `triu` requires this too. + LinearAlgebra.require_one_based_indexing(X) + + # Set the values. 
+ idx = 1 + m, n = size(X) + for j in 1:n + for i in 1:min(j - k, m) + X[i, j] = vals[idx] + idx += 1 end - for i in (j + 1):K - w[i, j] = 0 + end + + return X +end + +function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int) + X = similar(vals, dim, dim) + # TODO: Do we need this? + X .= 0 + return update_triu_from_vec!(vals, k, X) +end + +function ChainRulesCore.rrule( + ::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int +) + function update_triu_from_vec_pullback(ΔX) + return ( + ChainRulesCore.NoTangent(), + triu_to_vec(ChainRulesCore.unthunk(ΔX), k), + ChainRulesCore.NoTangent(), + ChainRulesCore.NoTangent(), + ) + end + return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback +end + +# n * (n - 1) / 2 = d +# ⟺ n^2 - n - 2d = 0 +# ⟹ n = (1 + sqrt(1 + 8d)) / 2 +_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 + +""" + triu1_to_vec(X::AbstractMatrix{<:Real}) + +Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector. +""" +triu1_to_vec(X::AbstractMatrix) = triu_to_vec(X, 1) + +inverse(::typeof(triu1_to_vec)) = vec_to_triu1 + +""" + vec_to_triu1(x::AbstractVector{<:Real}) + +Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`. +""" +function vec_to_triu1(x::AbstractVector) + n = _triu1_dim_from_length(length(x)) + X = update_triu_from_vec(x, 1, n) + return upper_triangular(X) +end + +inverse(::typeof(vec_to_triu1)) = triu1_to_vec + +function vec_to_triu1_row_index(idx) + # Assumes that vector was saved in a column-major order + # and that vector is one-based indexed. + M = _triu1_dim_from_length(idx - 1) + return idx - (M * (M - 1) ÷ 2) +end + +""" + VecCorrBijector <: Bijector + +A bijector to transform a correlation matrix to an unconstrained vector. + +# Reference +https://mc-stan.org/docs/reference-manual/correlation-matrix-transform.html + +See also: [`CorrBijector`](@ref) and ['VecCholeskyBijector'](@ref) + +# Example + +```jldoctest +julia> using LinearAlgebra + +julia> using StableRNGs; rng = StableRNG(42); + +julia> b = Bijectors.VecCorrBijector(); + +julia> X = rand(rng, LKJ(3, 1)) # Sample a correlation matrix. +3×3 Matrix{Float64}: + 1.0 -0.705273 -0.348638 + -0.705273 1.0 0.0534538 + -0.348638 0.0534538 1.0 + +julia> y = b(X) # Transform to unconstrained vector representation. +3-element Vector{Float64}: + -0.8777149781928181 + -0.3638927608636788 + -0.29813769428942216 + +julia> inverse(b)(y) ≈ X # (✓) Round-trip through `b` and its inverse. +true +""" +struct VecCorrBijector <: Bijector end + +with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X)) + +function logabsdetjac(b::VecCorrBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end + +function transform(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + return pd_from_upper(_inv_link_chol_lkj(y)) +end + +function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) + return _logabsdetjac_inv_corr(y) +end + +""" + VecCholeskyBijector <: Bijector + +A bijector to transform a Cholesky factor of a correlation matrix to an unconstrained vector. + +# Fields +- mode :`Symbol`. 
Controls the inverse tranformation : + - if `mode === :U` returns a `LinearAlgebra.Cholesky` holding the `UpperTriangular` factor + - if `mode === :L` returns a `LinearAlgebra.Cholesky` holding the `LowerTriangular` factor + +# Reference +https://mc-stan.org/docs/reference-manual/cholesky-factors-of-correlation-matrices-1 + +See also: [`VecCorrBijector`](@ref) + +# Example + +```jldoctest +julia> using LinearAlgebra + +julia> using StableRNGs; rng = StableRNG(42); + +julia> b = Bijectors.VecCholeskyBijector(:U); + +julia> X = rand(rng, LKJCholesky(3, 1, :U)) # Sample a correlation matrix. +Cholesky{Float64, Matrix{Float64}} +U factor: +3×3 UpperTriangular{Float64, Matrix{Float64}}: + 1.0 0.937494 0.865891 + ⋅ 0.348002 -0.320442 + ⋅ ⋅ 0.384122 + +julia> y = b(X) # Transform to unconstrained vector representation. +3-element Vector{Float64}: + -0.8777149781928181 + -0.3638927608636788 + -0.29813769428942216 + +julia> X_inv = inverse(b)(y); +julia> X_inv.U ≈ X.U # (✓) Round-trip through `b` and its inverse. +true +julia> X_inv.L ≈ X.L # (✓) Also works for the lower triangular factor. +true +""" +struct VecCholeskyBijector <: Bijector + mode::Symbol + function VecCholeskyBijector(uplo) + s = Symbol(uplo) + if (s === :U) || (s === :L) + new(s) + else + throw( + ArgumentError( + "mode must be either :U (upper triangular) or :L (lower triangular)" + ), + ) end end +end + +# TODO: Implement directly to make use of shared computations. +with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(::VecCholeskyBijector, X) = _link_chol_lkj(cholesky_factor(X)) + +function logabsdetjac(b::VecCholeskyBijector, x) + return -logabsdetjac(inverse(b), b(x)) +end + +function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) + if b.orig.mode === :U + # This Cholesky constructor is compatible with Julia v1.6 + # for later versions Cholesky(::UpperTriangular) works + return Cholesky(_inv_link_chol_lkj(y), 'U', 0) + else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor. + # HACK: Need to make materialize the transposed matrix to avoid numerical instabilities. + # If we don't, the return-type can be both `Matrix` and `Transposed`. + return Cholesky(permutedims(_inv_link_chol_lkj(y), (2, 1)), 'L', 0) + end +end - return w +function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) + return _logabsdetjac_inv_chol(y) end """ @@ -140,7 +333,7 @@ end But this implementation will not work when w[i-1, j] = 0. Though it is a zero measure set, unit matrix initialization will not work. -For equivelence, following explanations is given by @torfjelde: +For equivalence, following explanations is given by @torfjelde: For `(i, j)` in the loop below, we define @@ -156,25 +349,150 @@ and so which is the above implementation. """ -function _link_chol_lkj(w) - K = LinearAlgebra.checksquare(w) +function _link_chol_lkj(W::AbstractMatrix) + K = LinearAlgebra.checksquare(W) - z = similar(w) # z is also UpperTriangular. + z = similar(W) # z is also UpperTriangular. # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. - # This block can't be integrated with loop below, because w[1,1] != 0. - @inbounds z[1, 1] = 0 + # This block can't be integrated with loop below, because W[1,1] != 0. 
+ @inbounds z[:, 1] .= 0 @inbounds for j in 2:K - z[1, j] = atanh(w[1, j]) - tmp = sqrt(1 - w[1, j]^2) + z[1, j] = atanh(W[1, j]) + tmp = sqrt(1 - W[1, j]^2) for i in 2:(j - 1) - p = w[i, j] / tmp + p = W[i, j] / tmp tmp *= sqrt(1 - p^2) z[i, j] = atanh(p) end - z[j, j] = 0 + for i in j:K + z[i, j] = 0 + end + end + + return z +end + +function _link_chol_lkj(W::UpperTriangular) + K = LinearAlgebra.checksquare(W) + N = ((K - 1) * K) ÷ 2 # {K \choose 2} free parameters + + z = similar(W, N) + + idx = 1 + @inbounds for j in 2:K + z[idx] = atanh(W[1, j]) + idx += 1 + tmp = sqrt(1 - W[1, j]^2) + for i in 2:(j - 1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + z[idx] = atanh(p) + idx += 1 + end end return z end + +_link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W)) + +""" + _inv_link_chol_lkj(y) + +Inverse link function for cholesky factor. +""" +function _inv_link_chol_lkj(Y::AbstractMatrix) + K = LinearAlgebra.checksquare(Y) + + W = similar(Y) + + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(Y[i - 1, j]) + tmp = W[i - 1, j] + W[i - 1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) + end + for i in (j + 1):K + W[i, j] = 0 + end + end + + return W +end + +function _inv_link_chol_lkj(y::AbstractVector) + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + idx += 1 + tmp = W[i - 1, j] + W[i - 1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) + end + for i in (j + 1):K + W[i, j] = 0 + end + end + + return W +end + +function _logabsdetjac_inv_corr(Y::AbstractMatrix) + K = LinearAlgebra.checksquare(Y) + + result = float(zero(eltype(Y))) + for j in 2:K, i in 1:(j - 1) + @inbounds abs_y_i_j = abs(Y[i, j]) + result += + (K - i + 1) * ( + IrrationalConstants.logtwo - + (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) + ) + end + return result +end + +function _logabsdetjac_inv_corr(y::AbstractVector) + K = _triu1_dim_from_length(length(y)) + + result = float(zero(eltype(y))) + for (i, y_i) in enumerate(y) + abs_y_i = abs(y_i) + row_idx = vec_to_triu1_row_index(i) + result += + (K - row_idx + 1) * ( + IrrationalConstants.logtwo - + (abs_y_i + LogExpFunctions.log1pexp(-2 * abs_y_i)) + ) + end + return result +end + +function _logabsdetjac_inv_chol(y::AbstractVector) + K = _triu1_dim_from_length(length(y)) + + result = float(zero(eltype(y))) + idx = 1 + @inbounds for j in 2:K + tmp = zero(result) + for _ in 1:(j - 1) + z = tanh(y[idx]) + logz = log(1 - z^2) + result += logz + (tmp / 2) + tmp += logz + idx += 1 + end + end + + return result +end diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index 4a68bdd8..6e74523e 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -9,16 +9,14 @@ function replace_diag(f, X) end transform(b::PDBijector, X::AbstractMatrix{<:Real}) = pd_link(X) function pd_link(X) - Y = lower(parent(cholesky(X; check=true).L)) + Y = lower_triangular(parent(cholesky(X; check=true).L)) return replace_diag(log, Y) end -lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) function transform(ib::Inverse{PDBijector}, Y::AbstractMatrix{<:Real}) X = replace_diag(exp, Y) - return getpd(X) + return pd_from_lower(X) end -getpd(X) = LowerTriangular(X) * LowerTriangular(X)' function logabsdetjac(b::PDBijector, X::AbstractMatrix{<:Real}) T = eltype(X) diff --git a/src/chainrules.jl b/src/chainrules.jl index 26826e5a..e095cacf 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -156,5 +156,176 @@ function 
ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM return y, _transform_inverse_ordered_adjoint end +function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular) + K = LinearAlgebra.checksquare(W) + N = ((K - 1) * K) ÷ 2 + + z = zeros(eltype(W), N) + tmp_vec = similar(z) + + idx = 1 + @inbounds for j in 2:K + z[idx] = atanh(W[1, j]) + tmp = sqrt(1 - W[1, j]^2) + tmp_vec[idx] = tmp + idx += 1 + for i in 2:(j - 1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + tmp_vec[idx] = tmp + z[idx] = atanh(p) + idx += 1 + end + end + + function pullback_link_chol_lkj(Δz_thunked) + Δz = ChainRulesCore.unthunk(Δz_thunked) + + ΔW = similar(W) + + @inbounds ΔW[1, 1] = zero(eltype(Δz)) + + @inbounds for j in 2:K + idx_up_to_prev_column = ((j - 1) * (j - 2) ÷ 2) + ΔW[j, j] = 0 + Δtmp = zero(eltype(Δz)) + for i in (j - 1):-1:2 + tmp = tmp_vec[idx_up_to_prev_column + i - 1] + p = W[i, j] / tmp + ftmp = sqrt(1 - p^2) + d_ftmp_p = -p / ftmp + d_p_tmp = -W[i, j] / tmp^2 + + Δp = Δz[idx_up_to_prev_column + i] / (1 - p^2) + Δtmp * tmp * d_ftmp_p + ΔW[i, j] = Δp / tmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp + end + ΔW[1, j] = + Δz[idx_up_to_prev_column + 1] / (1 - W[1, j]^2) - + Δtmp / sqrt(1 - W[1, j]^2) * W[1, j] + end + + return ChainRulesCore.NoTangent(), ΔW + end + + return z, pullback_link_chol_lkj +end + +function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular) + K = LinearAlgebra.checksquare(W) + N = ((K - 1) * K) ÷ 2 + + z = zeros(eltype(W), N) + tmp_vec = similar(z) + + idx = 1 + @inbounds for i in 2:K + z[idx] = atanh(W[i, 1]) + tmp = sqrt(1 - W[i, 1]^2) + tmp_vec[idx] = tmp + idx += 1 + for j in 2:(i - 1) + p = W[i, j] / tmp + tmp *= sqrt(1 - p^2) + tmp_vec[idx] = tmp + z[idx] = atanh(p) + idx += 1 + end + end + + function pullback_link_chol_lkj(Δz_thunked) + Δz = ChainRulesCore.unthunk(Δz_thunked) + + ΔW = similar(W) + + @inbounds ΔW[1, 1] = zero(eltype(Δz)) + + @inbounds for i in 2:K + idx_up_to_prev_row = ((i - 1) * (i - 2) ÷ 2) + ΔW[i, i] = 0 + Δtmp = zero(eltype(Δz)) + for j in (i - 1):-1:2 + tmp = tmp_vec[idx_up_to_prev_row + j - 1] + p = W[i, j] / tmp + ftmp = sqrt(1 - p^2) + d_ftmp_p = -p / ftmp + d_p_tmp = -W[i, j] / tmp^2 + + Δp = Δz[idx_up_to_prev_row + j] / (1 - p^2) + Δtmp * tmp * d_ftmp_p + ΔW[i, j] = Δp / tmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp + end + ΔW[i, 1] = + Δz[idx_up_to_prev_row + 1] / (1 - W[i, 1]^2) - + Δtmp / sqrt(1 - W[i, 1]^2) * W[i, 1] + end + + return ChainRulesCore.NoTangent(), ΔW + end + + return z, pullback_link_chol_lkj +end + +function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector) + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + + z_vec = similar(y) + tmp_vec = similar(y) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + tmp = W[i - 1, j] + + z_vec[idx] = z + tmp_vec[idx] = tmp + idx += 1 + + W[i - 1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) + end + for i in (j + 1):K + W[i, j] = 0 + end + end + + function pullback_inv_link_chol_lkj(ΔW_thunked) + ΔW = ChainRulesCore.unthunk(ΔW_thunked) + + Δy = zero(y) + + @inbounds for j in 1:K + idx_up_to_prev_column = ((j - 1) * (j - 2) ÷ 2) + Δtmp = ΔW[j, j] + for i in j:-1:2 + idx = idx_up_to_prev_column + i - 1 + tmp = tmp_vec[idx] + z = z_vec[idx] + + Δz = ΔW[i - 1, j] * tmp - Δtmp * tmp / sqrt(1 - z^2) * z + Δy[idx] = Δz / cosh(y[idx])^2 + Δtmp = ΔW[i - 1, j] * z + Δtmp * sqrt(1 - z^2) + end + end + + return ChainRulesCore.NoTangent(), Δy + end + + return W, pullback_inv_link_chol_lkj +end + 
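+# Pullback derivation for `pd_from_upper` below: with `U = UpperTriangular(X)` and
+# `S = U'U`, we have `dS = dU'U + U'dU`, hence `⟨Δ, dS⟩ = ⟨U*Δ + U*Δ', dU⟩`, and the
+# cotangent w.r.t. `X` is the upper-triangular part of `U*Δ + U*Δ'`.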
+function ChainRulesCore.rrule(::typeof(pd_from_upper), X::AbstractMatrix) + return UpperTriangular(X)' * UpperTriangular(X), + Δ_thunked -> begin + Δ = ChainRulesCore.unthunk(Δ_thunked) + Xu = UpperTriangular(X) + return ChainRulesCore.NoTangent(), UpperTriangular(Xu * Δ + Xu * Δ') + end +end + # Fixes Zygote's issues with `@debug` ChainRulesCore.@non_differentiable _debug(::Any) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 47c58054..7e95e69c 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -1,7 +1,14 @@ module ReverseDiffCompat using ..ReverseDiff: - ReverseDiff, @grad, value, track, TrackedReal, TrackedVector, TrackedMatrix + ReverseDiff, + @grad, + value, + track, + TrackedReal, + TrackedVector, + TrackedMatrix, + @grad_from_chainrules using Requires, LinearAlgebra using ..Bijectors: @@ -20,13 +27,16 @@ import ..Bijectors: _simplex_inv_bijector, replace_diag, jacobian, - getpd, - lower, + pd_from_lower, + pd_from_upper, + lower_triangular, + upper_triangular, _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered, - find_alpha + find_alpha, + cholesky_factor using ChainRulesCore: ChainRulesCore @@ -162,8 +172,8 @@ end end end -getpd(X::TrackedMatrix) = track(getpd, X) -@grad function getpd(X::AbstractMatrix) +pd_from_lower(X::TrackedMatrix) = track(pd_from_lower, X) +@grad function pd_from_lower(X::AbstractMatrix) Xd = value(X) return LowerTriangular(Xd) * LowerTriangular(Xd)', Δ -> begin @@ -171,10 +181,19 @@ getpd(X::TrackedMatrix) = track(getpd, X) return (LowerTriangular(Δ' * Xl + Δ * Xl),) end end -lower(A::TrackedMatrix) = track(lower, A) -@grad function lower(A::AbstractMatrix) + +@grad_from_chainrules pd_from_upper(X::TrackedMatrix) + +lower_triangular(A::TrackedMatrix) = track(lower_triangular, A) +@grad function lower_triangular(A::AbstractMatrix) + Ad = value(A) + return lower_triangular(Ad), Δ -> (lower_triangular(Δ),) +end + +upper_triangular(A::TrackedMatrix) = track(upper_triangular, A) +@grad function upper_triangular(A::AbstractMatrix) Ad = value(A) - return lower(Ad), Δ -> (lower(Δ),) + return upper_triangular(Ad), Δ -> (upper_triangular(Δ),) end function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal} @@ -208,9 +227,23 @@ end return y, (wrap_chainrules_output ∘ Base.tail ∘ dy) end +@grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int) + +@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix) +@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector) + # NOTE: Probably doesn't work in complete generality. wrap_chainrules_output(x) = x wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) +if VERSION <= v"1.8.0-DEV.1526" + # HACK: This dispatch does not wrap X in Hermitian before calling cholesky. + # cholesky does not work with AbstractMatrix in julia versions before the compared one, + # and it would error with Hermitian{ReverseDiff.TrackedArray}. 
+ # See commit when the fix was introduced : + # https://github.com/JuliaLang/julia/commit/635449dabee81bba315ab066627a98f856141969 + cholesky_factor(X::ReverseDiff.TrackedArray) = cholesky_factor(cholesky(X)) +end + end diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index efa45650..72925fce 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -13,7 +13,7 @@ using ..Tracker: param import ..Bijectors -using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked +using ..Bijectors: Elementwise, SimplexBijector, Inverse, Stacked, _triu1_dim_from_length using ChainRulesCore: ChainRulesCore using LogExpFunctions: LogExpFunctions @@ -290,8 +290,8 @@ end (b::Elementwise{typeof(log)})(x::TrackedVector) = log.(x)::vectorof(float(eltype(x))) (b::Elementwise{typeof(log)})(x::TrackedMatrix) = log.(x)::matrixof(float(eltype(x))) -Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) -@grad function Bijectors.getpd(X::AbstractMatrix) +Bijectors.pd_from_lower(X::TrackedMatrix) = track(Bijectors.pd_from_lower, X) +@grad function Bijectors.pd_from_lower(X::AbstractMatrix) Xd = data(X) return Bijectors.LowerTriangular(Xd) * Bijectors.LowerTriangular(Xd)', Δ -> begin @@ -300,14 +300,67 @@ Bijectors.getpd(X::TrackedMatrix) = track(Bijectors.getpd, X) end end -Bijectors.lower(A::TrackedMatrix) = track(Bijectors.lower, A) -@grad function Bijectors.lower(A::AbstractMatrix) +Bijectors.lower_triangular(A::TrackedMatrix) = track(Bijectors.lower_triangular, A) +@grad function Bijectors.lower_triangular(A::AbstractMatrix) Ad = data(A) - return Bijectors.lower(Ad), Δ -> (Bijectors.lower(Δ),) + return Bijectors.lower_triangular(Ad), Δ -> (Bijectors.lower_triangular(Δ),) +end + +Bijectors._inv_link_chol_lkj(y::TrackedVector) = track(Bijectors._inv_link_chol_lkj, y) +@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedVector) + y = data(y_tracked) + K = _triu1_dim_from_length(length(y)) + + W = similar(y, K, K) + + z_vec = similar(y) + tmp_vec = similar(y) + + idx = 1 + @inbounds for j in 1:K + W[1, j] = 1 + for i in 2:j + z = tanh(y[idx]) + tmp = W[i - 1, j] + + z_vec[idx] = z + tmp_vec[idx] = tmp + idx += 1 + + W[i - 1, j] = z * tmp + W[i, j] = tmp * sqrt(1 - z^2) + end + for i in (j + 1):K + W[i, j] = 0 + end + end + + function pullback_inv_link_chol_lkj(ΔW) + LinearAlgebra.checksquare(ΔW) + + Δy = zero(y) + + @inbounds for j in 1:K + idx_up_to_prev_column = ((j - 1) * (j - 2) ÷ 2) + Δtmp = ΔW[j, j] + for i in j:-1:2 + idx = idx_up_to_prev_column + i - 1 + Δz = + ΔW[i - 1, j] * tmp_vec[idx] - + Δtmp * tmp_vec[idx] / sqrt(1 - z_vec[idx]^2) * z_vec[idx] + Δy[idx] = Δz / cosh(y[idx])^2 + Δtmp = ΔW[i - 1, j] * z_vec[idx] + Δtmp * sqrt(1 - z_vec[idx]^2) + end + end + + return (Δy,) + end + + return W, pullback_inv_link_chol_lkj end Bijectors._inv_link_chol_lkj(y::TrackedMatrix) = track(Bijectors._inv_link_chol_lkj, y) -@grad function Bijectors._inv_link_chol_lkj(y_tracked) +@grad function Bijectors._inv_link_chol_lkj(y_tracked::TrackedMatrix) y = data(y_tracked) K = LinearAlgebra.checksquare(y) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index 8533b58f..f0f23538 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -154,10 +154,10 @@ end end return pullback(_maximum, d) end -@adjoint function lower(A::AbstractMatrix) - return lower(A), Δ -> (lower(Δ),) +@adjoint function lower_triangular(A::AbstractMatrix) + return lower_triangular(A), Δ -> (lower_triangular(Δ),) end -@adjoint function getpd(X::AbstractMatrix) +@adjoint function 
pd_from_lower(X::AbstractMatrix) return LowerTriangular(X) * LowerTriangular(X)', Δ -> begin Xl = LowerTriangular(X) @@ -170,101 +170,3 @@ end return replace_diag(log, Y) end end - -@adjoint function _inv_link_chol_lkj(y) - K = LinearAlgebra.checksquare(y) - - w = similar(y) - - z_mat = similar(y) # cache for adjoint - tmp_mat = similar(y) - - @inbounds for j in 1:K - w[1, j] = 1 - for i in 2:j - z = tanh(y[i - 1, j]) - tmp = w[i - 1, j] - - z_mat[i, j] = z - tmp_mat[i, j] = tmp - - w[i - 1, j] = z * tmp - w[i, j] = tmp * sqrt(1 - z^2) - end - for i in (j + 1):K - w[i, j] = 0 - end - end - - function pullback_inv_link_chol_lkj(Δw) - LinearAlgebra.checksquare(Δw) - - Δy = zero(y) - - @inbounds for j in 1:K - Δtmp = Δw[j, j] - for i in j:-1:2 - Δz = - Δw[i - 1, j] * tmp_mat[i, j] - - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] - Δy[i - 1, j] = Δz / cosh(y[i - 1, j])^2 - Δtmp = Δw[i - 1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) - end - end - - return (Δy,) - end - - return w, pullback_inv_link_chol_lkj -end - -@adjoint function _link_chol_lkj(w) - K = LinearAlgebra.checksquare(w) - - z = similar(w) - - @inbounds z[1, 1] = 0 - - tmp_mat = similar(w) # cache for pullback. - - @inbounds for j in 2:K - z[1, j] = atanh(w[1, j]) - tmp = sqrt(1 - w[1, j]^2) - tmp_mat[1, j] = tmp - for i in 2:(j - 1) - p = w[i, j] / tmp - tmp *= sqrt(1 - p^2) - tmp_mat[i, j] = tmp - z[i, j] = atanh(p) - end - z[j, j] = 0 - end - - function pullback_link_chol_lkj(Δz) - LinearAlgebra.checksquare(Δz) - - Δw = similar(w) - - @inbounds Δw[1, 1] = zero(eltype(Δz)) - - @inbounds for j in 2:K - Δw[j, j] = 0 - Δtmp = zero(eltype(Δz)) # Δtmp_mat[j-1,j] - for i in (j - 1):-1:2 - p = w[i, j] / tmp_mat[i - 1, j] - ftmp = sqrt(1 - p^2) - d_ftmp_p = -p / ftmp - d_p_tmp = -w[i, j] / tmp_mat[i - 1, j]^2 - - Δp = Δz[i, j] / (1 - p^2) + Δtmp * tmp_mat[i - 1, j] * d_ftmp_p - Δw[i, j] = Δp / tmp_mat[i - 1, j] - Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp - end - Δw[1, j] = Δz[1, j] / (1 - w[1, j]^2) - Δtmp / sqrt(1 - w[1, j]^2) * w[1, j] - end - - return (Δw,) - end - - return z, pullback_link_chol_lkj -end diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 5f0a0d8d..eccbb64c 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -13,6 +13,9 @@ struct TransformedDistribution{D,B,V} <: function TransformedDistribution(d::MatrixDistribution, b) return new{typeof(d),typeof(b),Matrixvariate}(d, b) end + function TransformedDistribution(d::Distribution{CholeskyVariate}, b) + return new{typeof(d),typeof(b),CholeskyVariate}(d, b) + end end # fields may contain nested numerical parameters @@ -83,7 +86,8 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) bijector(d::PDMatDistribution) = PDBijector() bijector(d::MatrixBeta) = PDBijector() -bijector(d::LKJ) = CorrBijector() +bijector(d::LKJ) = VecCorrBijector() +bijector(d::LKJCholesky) = VecCholeskyBijector(d.uplo) function bijector(d::Distributions.ReshapedDistribution) inner_dims = size(d.dist) @@ -120,6 +124,13 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) return logpdf(td.dist, mappedarray(x -> x + ϵ, x)) + logjac end +function logpdf( + td::TransformedDistribution{T}, y::AbstractVector{<:Real} +) where {T<:Union{LKJ,LKJCholesky}} + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return logpdf(td.dist, x) + logjac +end + function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) x, logjac = 
with_logabsdet_jacobian(inverse(td.transform), y) return logpdf(td.dist, x) + logjac @@ -170,6 +181,12 @@ function _rand!(rng::AbstractRNG, td::MatrixTransformed, x::DenseMatrix{<:Real}) return x .= td.transform(x) end +function rand( + rng::AbstractRNG, td::TransformedDistribution{T} +) where {T<:Union{LKJ,LKJCholesky}} + return td.transform(rand(rng, td.dist)) +end + # utility stuff Distributions.params(td::Transformed) = Distributions.params(td.dist) function Base.maximum(td::UnivariateTransformed) diff --git a/src/utils.jl b/src/utils.jl index 8203e1b4..439a2d0c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,3 +6,15 @@ aT_b(a::AbstractVector{<:Real}, b::AbstractVector{<:Real}) = dot(a, b) # flatten arrays with fallback for scalars _vec(x::AbstractArray{<:Real}) = vec(x) _vec(x::Real) = x + +# # Because `ReverseDiff` does not play well with structural matrices. +lower_triangular(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A)) +upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A)) + +pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)' +pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X) + +cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(Hermitian(X))) +cholesky_factor(X::Cholesky) = X.U +cholesky_factor(X::UpperTriangular) = X +cholesky_factor(X::LowerTriangular) = X diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index ebb461f0..e639e4ea 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -19,4 +19,14 @@ test_rrule(Bijectors._transform_ordered, randn(5, 2)) test_rrule(Bijectors._transform_inverse_ordered, b(rand(5))) test_rrule(Bijectors._transform_inverse_ordered, b(rand(5, 2))) + + # LKJ and LKJCholesky bijector + dist = LKJCholesky(3, 4) + x = rand(dist) + test_rrule(Bijectors._link_chol_lkj, x.U) + test_rrule(Bijectors._link_chol_lkj, x.L) + + b = bijector(dist) + y = b(x) + test_rrule(Bijectors._inv_link_chol_lkj, y) end diff --git a/test/ad/utils.jl b/test/ad/utils.jl index dc7d8234..bea823fd 100644 --- a/test/ad/utils.jl +++ b/test/ad/utils.jl @@ -4,17 +4,6 @@ const AD = get(ENV, "AD", "All") function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6) finitediff = FiniteDifferences.grad(central_fdm(5, 1), f, x)[1] - if AD == "All" || AD == "Tracker" - if :Tracker in broken - @test_broken Tracker.data(Tracker.gradient(f, x)[1]) ≈ finitediff rtol = rtol atol = - atol - else - ∇tracker = Tracker.gradient(f, x)[1] - @test Tracker.data(∇tracker) ≈ finitediff rtol = rtol atol = atol - @test Tracker.istracked(∇tracker) - end - end - if AD == "All" || AD == "ForwardDiff" if :ForwardDiff in broken @test_broken ForwardDiff.gradient(f, x) ≈ finitediff rtol = rtol atol = atol diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl new file mode 100644 index 00000000..71bfd8d7 --- /dev/null +++ b/test/bijectors/corr.jl @@ -0,0 +1,65 @@ +using Bijectors, DistributionsAD, LinearAlgebra, Test +using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector + +@testset "CorrBijector & VecCorrBijector" begin + for d in [1, 2, 5] + b = CorrBijector() + bvec = VecCorrBijector() + + dist = LKJ(d, 1) + x = rand(dist) + + y = b(x) + yvec = bvec(x) + + # Make sure that they represent the same thing. + @test Bijectors.triu1_to_vec(y) ≈ yvec + + # Check the inverse. + binv = inverse(b) + xinv = binv(y) + bvecinv = inverse(bvec) + xvecinv = bvecinv(yvec) + + @test xinv ≈ xvecinv + + # And finally that the `logabsdetjac` is the same. 
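+        # (Different codomains, a vector vs. a dense matrix, but the log-Jacobian values match.)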
+ @test logabsdetjac(bvec, x) ≈ logabsdetjac(b, x) + + # NOTE: `CorrBijector` technically isn't bijective, and so the default `getjacobian` + # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. + # Hence, we disable those tests. + test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) + test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) + + test_ad(x -> sum(bvec(bvecinv(x))), yvec) + end +end + +@testset "VecCholeskyBijector" begin + for d in [2, 5] + for dist in [LKJCholesky(d, 1, 'U'), LKJCholesky(d, 1, 'L')] + b = bijector(dist) + + b_lkj = VecCorrBijector() + x = rand(dist) + y = b(x) + y_lkj = b_lkj(x) + + @test y ≈ y_lkj + + binv = inverse(b) + xinv = binv(y) + binv_lkj = inverse(b_lkj) + xinv_lkj = binv_lkj(y_lkj) + + @test xinv.U ≈ cholesky(xinv_lkj).U + + test_ad(x -> sum(b(binv(x))), y) + + # test_bijector is commented out for now, + # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) + # test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ee2401eb..cb8d9455 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -46,6 +46,7 @@ if GROUP == "All" || GROUP == "Interface" include("bijectors/ordered.jl") include("bijectors/pd.jl") include("bijectors/reshape.jl") + include("bijectors/corr.jl") end if GROUP == "All" || GROUP == "AD" diff --git a/test/transform.jl b/test/transform.jl index 15cd9178..00e2ea79 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -39,7 +39,15 @@ function single_sample_tests(dist) # Check that invlink is inverse of link. x = rand(dist) - @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol = 1e-9 + + if dist isa LKJCholesky + x_inv = @inferred Cholesky{Float64,Matrix{Float64}} invlink( + dist, link(dist, copy(x)) + ) + @test x_inv.UL ≈ x.UL atol = 1e-9 + else + @test @inferred(invlink(dist, link(dist, copy(x)))) ≈ x atol = 1e-9 + end # Check that link is inverse of invlink. Hopefully this just holds given the above... y = @inferred(link(dist, x)) @@ -66,7 +74,7 @@ function single_sample_tests(dist) # This should probably be exact. @test logpdf(dist, x) == logpdf_with_trans(dist, x, false) @test all( - isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(x)) for _ in 1:100]) + isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(y)) for _ in 1:100]) ) end end @@ -182,8 +190,8 @@ end end end -@testset "correlation matrix" begin - dist = LKJ(2, 1) +@testset "LKJ" begin + dist = LKJ(3, 1) single_sample_tests(dist) @@ -196,7 +204,23 @@ end LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1] ] J = ForwardDiff.jacobian(x -> link(dist, x), x) - J = J[upperinds, upperinds] + J = J[:, upperinds] + logpdf_turing = logpdf_with_trans(dist, x, true) + @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing +end + +@testset "LKJCholesky" begin + dist = LKJCholesky(3, 1) + + single_sample_tests(dist) + + x = rand(dist) + + upperinds = [ + LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1] + ] + J = ForwardDiff.jacobian(x -> link(dist, x), x.U) + J = J[:, upperinds] logpdf_turing = logpdf_with_trans(dist, x, true) @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing end
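
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new `VecCorrBijector` and `VecCholeskyBijector` are intended to be used once this diff is applied. It relies only on functions that appear above (`bijector`, `inverse`, `logabsdetjac`, `logpdf_with_trans`) together with `LKJ`/`LKJCholesky` from Distributions; the variable names are arbitrary.

```julia
using Bijectors, Distributions, LinearAlgebra

d = 3

# Correlation matrices: `bijector(::LKJ)` now returns a `VecCorrBijector`.
lkj = LKJ(d, 2.0)
X = rand(lkj)
b = bijector(lkj)
y = b(X)                               # unconstrained vector
@assert length(y) == d * (d - 1) ÷ 2   # K(K - 1)/2 free parameters
@assert inverse(b)(y) ≈ X              # round-trip recovers the correlation matrix

# Cholesky factors: `bijector(::LKJCholesky)` now returns a `VecCholeskyBijector`.
lkj_chol = LKJCholesky(d, 2.0, :U)
C = rand(lkj_chol)                     # a `LinearAlgebra.Cholesky`
bc = bijector(lkj_chol)
z = bc(C)
@assert inverse(bc)(z).U ≈ C.U         # factors agree up to round-off

# Transformed log-density, following the convention checked in test/transform.jl:
# `logpdf_with_trans(d, x, true) == logpdf(d, x) + logabsdetjac(inverse(b), b(x))`.
@assert logpdf_with_trans(lkj_chol, C, true) ≈
    logpdf(lkj_chol, C) + logabsdetjac(inverse(bc), z)
```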