Skip to content

Commit

Permalink
Change Cholesky to use Composite (#164)
Browse files Browse the repository at this point in the history
* Update Cholesky to use Composite

* Bump version to v0.3.3

* Add comment on why we call `Matrix`

* Replace `extern` with `unthunk`

* Move `Matrix` conversion

* Thunk whole `Composite` instead of the field
  • Loading branch information
nickrobinson251 authored Jan 30, 2020
1 parent 98c5458 commit fee130d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.3.2"
version = "0.3.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
28 changes: 17 additions & 11 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,26 +71,31 @@ end

function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
F = cholesky(X)
function cholesky_pullback(Ȳ)
∂X = @thunk(chol_blocked_rev(Matrix(Ȳ), Matrix(F.U), 25, true))
function cholesky_pullback(Ȳ::Composite{<:Cholesky})
∂X = if F.uplo === 'U'
@thunk(chol_blocked_rev(Ȳ.U, F.U, 25, true))
else
@thunk(chol_blocked_rev(Ȳ.L, F.L, 25, false))
end
return (NO_FIELDS, ∂X)
end
return F, cholesky_pullback
end

function rrule(::typeof(getproperty), F::Cholesky, x::Symbol)
function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
function getproperty_cholesky_pullback(Ȳ)
if x === :U
C = Composite{T}
∂F = @thunk if x === :U
if F.uplo === 'U'
∂F = @thunk UpperTriangular(Ȳ)
C(U=UpperTriangular(Ȳ),)
else
∂F = @thunk LowerTriangular(Ȳ')
C(L=LowerTriangular(Ȳ'),)
end
elseif x === :L
if F.uplo === 'L'
∂F = @thunk LowerTriangular(Ȳ)
C(L=LowerTriangular(Ȳ),)
else
∂F = @thunk UpperTriangular(Ȳ')
C(U=UpperTriangular(Ȳ'),)
end
end
return NO_FIELDS, ∂F, DoesNotExist()
Expand Down Expand Up @@ -194,15 +199,15 @@ function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool
end

"""
chol_blocked_rev!(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool)
Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly
procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities
of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used
to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be
indicated by passing `upper = true`.
"""
function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::Integer, upper::Bool) where T<:Real
function chol_blocked_rev!(Σ̄::StridedMatrix{T}, L::StridedMatrix{T}, nb::Integer, upper::Bool) where T<:Real
n = checksquare(Σ̄)
tmp = Matrix{T}(undef, nb, nb)
k = n
Expand Down Expand Up @@ -252,5 +257,6 @@ function chol_blocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, nb::In
end

function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
return chol_blocked_rev!(copy(Σ̄), L, nb, upper)
# Convert to `Matrix`s because blas functions require StridedMatrix input.
return chol_blocked_rev!(Matrix(Σ̄), Matrix(L), nb, upper)
end
4 changes: 2 additions & 2 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
# machinery from FiniteDifferences because that isn't set up to respect
# necessary special properties of the input. In the case of the Cholesky
# factorization, we need the input to be Hermitian.
ΔF = extern(dF)
ΔF = unthunk(dF)
_, dX = dX_pullback(ΔF)
X̄_ad = dot(extern(dX), V)
X̄_ad = dot(unthunk(dX), V)
X̄_fd = _fdm() do ε
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
end
Expand Down

2 comments on commit fee130d

@nickrobinson251
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/8689

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.3 -m "<description of version>" fee130d924966e21ea3f6a5619092abf6be442ab
git push origin v0.3.3

Please sign in to comment.