Skip to content

Commit

Permalink
Make cholesky handle AbstractMatrix (JuliaLang#44076)
Browse files Browse the repository at this point in the history
Co-authored-by: Sheehan Olver <[email protected]>
  • Loading branch information
2 people authored and LilithHafner committed Mar 8, 2022
1 parent 6f702c8 commit 3468124
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 24 deletions.
46 changes: 26 additions & 20 deletions stdlib/LinearAlgebra/src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ Base.iterate(C::CholeskyPivoted, ::Val{:done}) = nothing

# make a copy that allow inplace Cholesky factorization
@inline choltype(A) = promote_type(typeof(sqrt(oneunit(eltype(A)))), Float32)
@inline cholcopy(A) = copy_oftype(A, choltype(A))
@inline cholcopy(A::StridedMatrix) = copy_oftype(A, choltype(A))
@inline cholcopy(A::RealHermSymComplexHerm) = copy_oftype(A, choltype(A))
@inline cholcopy(A::AbstractMatrix) = copy_similar(A, choltype(A))

# _chol!. Internal methods for calling unpivoted Cholesky
## BLAS/LAPACK element types
Expand Down Expand Up @@ -269,9 +271,9 @@ function cholesky!(A::RealHermSymComplexHerm, ::NoPivot = NoPivot(); check::Bool
return Cholesky(C.data, A.uplo, info)
end

### for StridedMatrices, check that matrix is symmetric/Hermitian
### for AbstractMatrix, check that matrix is symmetric/Hermitian
"""
cholesky!(A::StridedMatrix, NoPivot(); check = true) -> Cholesky
cholesky!(A::AbstractMatrix, NoPivot(); check = true) -> Cholesky
The same as [`cholesky`](@ref), but saves space by overwriting the input `A`,
instead of creating a copy. An [`InexactError`](@ref) exception is thrown if
Expand All @@ -291,7 +293,7 @@ Stacktrace:
[...]
```
"""
function cholesky!(A::StridedMatrix, ::NoPivot = NoPivot(); check::Bool = true)
function cholesky!(A::AbstractMatrix, ::NoPivot = NoPivot(); check::Bool = true)
checksquare(A)
if !ishermitian(A) # return with info = -1 if not Hermitian
check && checkpositivedefinite(-1)
Expand Down Expand Up @@ -320,16 +322,16 @@ cholesky!(A::RealHermSymComplexHerm{<:Real}, ::RowMaximum; tol = 0.0, check::Boo
throw(ArgumentError("generic pivoted Cholesky factorization is not implemented yet"))
@deprecate cholesky!(A::RealHermSymComplexHerm{<:Real}, ::Val{true}; kwargs...) cholesky!(A, RowMaximum(); kwargs...) false

### for StridedMatrices, check that matrix is symmetric/Hermitian
### for AbstractMatrix, check that matrix is symmetric/Hermitian
"""
cholesky!(A::StridedMatrix, RowMaximum(); tol = 0.0, check = true) -> CholeskyPivoted
cholesky!(A::AbstractMatrix, RowMaximum(); tol = 0.0, check = true) -> CholeskyPivoted
The same as [`cholesky`](@ref), but saves space by overwriting the input `A`,
instead of creating a copy. An [`InexactError`](@ref) exception is thrown if the
factorization produces a number not representable by the element type of `A`,
e.g. for integer types.
"""
function cholesky!(A::StridedMatrix, ::RowMaximum; tol = 0.0, check::Bool = true)
function cholesky!(A::AbstractMatrix, ::RowMaximum; tol = 0.0, check::Bool = true)
checksquare(A)
if !ishermitian(A)
C = CholeskyPivoted(A, 'U', Vector{BlasInt}(),convert(BlasInt, 1),
Expand All @@ -350,7 +352,7 @@ end
Compute the Cholesky factorization of a dense symmetric positive definite matrix `A`
and return a [`Cholesky`](@ref) factorization. The matrix `A` can either be a [`Symmetric`](@ref) or [`Hermitian`](@ref)
[`StridedMatrix`](@ref) or a *perfectly* symmetric or Hermitian `StridedMatrix`.
[`AbstractMatrix`](@ref) or a *perfectly* symmetric or Hermitian `AbstractMatrix`.
The triangular Cholesky factor can be obtained from the factorization `F` via `F.L` and `F.U`,
where `A ≈ F.U' * F.U ≈ F.L * F.L'`.
Expand Down Expand Up @@ -397,11 +399,11 @@ julia> C.L * C.U == A
true
```
"""
cholesky(A::Union{StridedMatrix,RealHermSymComplexHerm{<:Real,<:StridedMatrix}},
::NoPivot=NoPivot(); check::Bool = true) = cholesky!(cholcopy(A); check = check)
cholesky(A::AbstractMatrix, ::NoPivot=NoPivot(); check::Bool = true) =
cholesky!(cholcopy(A); check)
@deprecate cholesky(A::Union{StridedMatrix,RealHermSymComplexHerm{<:Real,<:StridedMatrix}}, ::Val{false}; check::Bool = true) cholesky(A, NoPivot(); check) false

function cholesky(A::Union{StridedMatrix{Float16},RealHermSymComplexHerm{Float16,<:StridedMatrix}}, ::NoPivot=NoPivot(); check::Bool = true)
function cholesky(A::AbstractMatrix{Float16}, ::NoPivot=NoPivot(); check::Bool = true)
X = cholesky!(cholcopy(A); check = check)
return Cholesky{Float16}(X)
end
Expand All @@ -413,7 +415,7 @@ end
Compute the pivoted Cholesky factorization of a dense symmetric positive semi-definite matrix `A`
and return a [`CholeskyPivoted`](@ref) factorization. The matrix `A` can either be a [`Symmetric`](@ref)
or [`Hermitian`](@ref) [`StridedMatrix`](@ref) or a *perfectly* symmetric or Hermitian `StridedMatrix`.
or [`Hermitian`](@ref) [`AbstractMatrix`](@ref) or a *perfectly* symmetric or Hermitian `AbstractMatrix`.
The triangular Cholesky factor can be obtained from the factorization `F` via `F.L` and `F.U`,
and the permutation via `F.p`, where `A[F.p, F.p] ≈ Ur' * Ur ≈ Lr * Lr'` with `Ur = F.U[1:F.rank, :]`
Expand Down Expand Up @@ -463,11 +465,15 @@ julia> l == C.L && u == C.U
true
```
"""
cholesky(A::Union{StridedMatrix,RealHermSymComplexHerm{<:Real,<:StridedMatrix}},
::RowMaximum; tol = 0.0, check::Bool = true) =
cholesky!(cholcopy(A), RowMaximum(); tol = tol, check = check)
cholesky(A::AbstractMatrix, ::RowMaximum; tol = 0.0, check::Bool = true) =
cholesky!(cholcopy(A), RowMaximum(); tol, check)
@deprecate cholesky(A::Union{StridedMatrix,RealHermSymComplexHerm{<:Real,<:StridedMatrix}}, ::Val{true}; tol = 0.0, check::Bool = true) cholesky(A, RowMaximum(); tol, check) false

function cholesky(A::AbstractMatrix{Float16}, ::RowMaximum; tol = 0.0, check::Bool = true)
X = cholesky!(cholcopy(A), RowMaximum(); tol, check)
return CholeskyPivoted{Float16}(X)
end

## Number
function cholesky(x::Number, uplo::Symbol=:U)
C, info = _chol!(x, uplo)
Expand Down Expand Up @@ -524,7 +530,7 @@ end
Base.propertynames(F::Cholesky, private::Bool=false) =
(:U, :L, :UL, (private ? fieldnames(typeof(F)) : ())...)

function getproperty(C::CholeskyPivoted{T}, d::Symbol) where T<:BlasFloat
function getproperty(C::CholeskyPivoted{T}, d::Symbol) where {T}
Cfactors = getfield(C, :factors)
Cuplo = getfield(C, :uplo)
if d === :U
Expand Down Expand Up @@ -595,7 +601,7 @@ function ldiv!(C::CholeskyPivoted{T}, B::StridedMatrix{T}) where T<:BlasFloat
B
end

function ldiv!(C::CholeskyPivoted, B::StridedVector)
function ldiv!(C::CholeskyPivoted, B::AbstractVector)
if C.uplo == 'L'
ldiv!(adjoint(LowerTriangular(C.factors)),
ldiv!(LowerTriangular(C.factors), permute!(B, C.piv)))
Expand All @@ -606,7 +612,7 @@ function ldiv!(C::CholeskyPivoted, B::StridedVector)
invpermute!(B, C.piv)
end

function ldiv!(C::CholeskyPivoted, B::StridedMatrix)
function ldiv!(C::CholeskyPivoted, B::AbstractMatrix)
n = size(C, 1)
for i in 1:size(B, 2)
permute!(view(B, 1:n, i), C.piv)
Expand All @@ -624,15 +630,15 @@ function ldiv!(C::CholeskyPivoted, B::StridedMatrix)
B
end

function rdiv!(B::StridedMatrix, C::Cholesky{<:Any,<:AbstractMatrix})
function rdiv!(B::AbstractMatrix, C::Cholesky{<:Any,<:AbstractMatrix})
if C.uplo == 'L'
return rdiv!(rdiv!(B, adjoint(LowerTriangular(C.factors))), LowerTriangular(C.factors))
else
return rdiv!(rdiv!(B, UpperTriangular(C.factors)), adjoint(UpperTriangular(C.factors)))
end
end

function LinearAlgebra.rdiv!(B::StridedMatrix, C::CholeskyPivoted)
function LinearAlgebra.rdiv!(B::AbstractMatrix, C::CholeskyPivoted)
n = size(C, 2)
for i in 1:size(B, 1)
permute!(view(B, i, 1:n), C.piv)
Expand Down
6 changes: 3 additions & 3 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -753,11 +753,11 @@ function cholesky!(A::Diagonal, ::NoPivot = NoPivot(); check::Bool = true)
Cholesky(A, 'U', convert(BlasInt, info))
end
@deprecate cholesky!(A::Diagonal, ::Val{false}; check::Bool = true) cholesky!(A::Diagonal, NoPivot(); check) false

cholesky(A::Diagonal, ::NoPivot = NoPivot(); check::Bool = true) =
cholesky!(cholcopy(A), NoPivot(); check = check)
@deprecate cholesky(A::Diagonal, ::Val{false}; check::Bool = true) cholesky(A::Diagonal, NoPivot(); check) false

@inline cholcopy(A::Diagonal) = copy_oftype(A, choltype(A))
@inline cholcopy(A::RealHermSymComplexHerm{<:Real,<:Diagonal}) = copy_oftype(A, choltype(A))

function getproperty(C::Cholesky{<:Any,<:Diagonal}, d::Symbol)
Cfactors = getfield(C, :factors)
if d in (:U, :L, :UL)
Expand Down
7 changes: 7 additions & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,10 @@ Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.cat_t(T, xs..
vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...)
hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...)

# factorizations
function cholesky(S::RealHermSymComplexHerm{<:Real,<:SymTridiagonal}, ::NoPivot = NoPivot(); check::Bool = true)
T = choltype(eltype(S))
B = Bidiagonal{T}(diag(S, 0), diag(S, S.uplo == 'U' ? 1 : -1), sym_uplo(S.uplo))
cholesky!(Hermitian(B, sym_uplo(S.uplo)), NoPivot(); check = check)
end
9 changes: 9 additions & 0 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -854,3 +854,12 @@ function dot(x::AbstractVector, A::Tridiagonal, y::AbstractVector)
r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx])
return r
end

function cholesky(S::SymTridiagonal, ::NoPivot = NoPivot(); check::Bool = true)
if !ishermitian(S)
check && checkpositivedefinite(-1)
return Cholesky(S, 'U', convert(BlasInt, -1))
end
T = choltype(eltype(S))
cholesky!(Hermitian(Bidiagonal{T}(diag(S, 0), diag(S, 1), :U)), NoPivot(); check = check)
end
22 changes: 21 additions & 1 deletion stdlib/LinearAlgebra/test/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,15 @@ end
end

# test cholesky of 2x2 Strang matrix
S = Matrix{eltya}(SymTridiagonal([2, 2], [-1]))
S = SymTridiagonal{eltya}([2, 2], [-1])
for uplo in (:U, :L)
@test Matrix(@inferred cholesky(Hermitian(S, uplo))) S
if eltya <: Real
@test Matrix(@inferred cholesky(Symmetric(S, uplo))) S
end
end
@test Matrix(cholesky(S).U) [2 -1; 0 sqrt(eltya(3))] / sqrt(eltya(2))
@test Matrix(cholesky(S)) S

# test extraction of factor and re-creating original matrix
if eltya <: Real
Expand Down Expand Up @@ -371,6 +378,10 @@ end
@test D CD.L * CD.U
@test CD.info == 0

F = cholesky(Hermitian(I(3)))
@test F isa Cholesky{Float64,<:Diagonal}
@test Matrix(F) I(3)

# real, failing
@test_throws PosDefException cholesky(Diagonal([1.0, -2.0]))
Dnpd = cholesky(Diagonal([1.0, -2.0]); check = false)
Expand Down Expand Up @@ -502,6 +513,15 @@ end
@test B.U B32.U
@test B.L B32.L
@test B.UL B32.UL
@test Matrix(B) A
B = cholesky(A, RowMaximum())
B32 = cholesky(Float32.(A), RowMaximum())
@test B isa CholeskyPivoted{Float16,Matrix{Float16}}
@test B.U isa UpperTriangular{Float16, Matrix{Float16}}
@test B.L isa LowerTriangular{Float16, Matrix{Float16}}
@test B.U B32.U
@test B.L B32.L
@test Matrix(B) A
end

@testset "det and logdet" begin
Expand Down

0 comments on commit 3468124

Please sign in to comment.