Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cholesky factorisation #202

Merged
merged 8 commits into from
Aug 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ jobs:
fail-fast: false
matrix:
version:
- "1.3" # Would test on 1.0 (LTS), but that causes trouble at test time as test can't downgrade main deps in 1.0
- "1" # Latest Release
os:
- ubuntu-latest
Expand All @@ -30,11 +29,6 @@ jobs:
arch: x86
- os: windows-latest
arch: x86
include:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

drop julia < 1.6 in CI

# Add a 1.5 job because that's what Invenia actually uses
- os: ubuntu-latest
version: 1.5
arch: x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedDims"
uuid = "356022a1-0364-5f58-8944-0da4b18d706f"
authors = ["Invenia Technical Computing Corporation"]
version = "0.2.49"
version = "0.3.0"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

major bump due to constrainting julia to 1.6 and above

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the future, this should be a patch bump to be non-breaking (since the package is pre-1.0) unless one intends to support backports: https://github.com/SciML/ColPrac/pull/20/files


[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -20,7 +20,7 @@ ChainRulesTestUtils = "1"
CovarianceEstimation = "0.2.4"
Requires = "0.5, 1"
Tracker = "0.2.2"
julia = "1"
julia = "1.6"

[extras]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Expand Down
35 changes: 27 additions & 8 deletions src/functions_linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@ Base.size(named::NamedFactorization) = size(parent(named))
Base.propertynames(named::NamedFactorization; kwargs...) = propertynames(parent(named))

# Factorization type specific initial iterate calls
Base.iterate(named::NamedFactorization{L, T, <:LU}) where {L, T} = (named.L, Val(:U))
Base.iterate(named::NamedFactorization{L, T, <:LQ}) where {L, T} = (named.L, Val(:Q))
Base.iterate(named::NamedFactorization{L, T, <:SVD}) where {L, T} = (named.U, Val(:S))
Base.iterate(named::NamedFactorization{L,T,<:LU}) where {L,T} = (named.L, Val(:U))
Base.iterate(named::NamedFactorization{L,T,<:LQ}) where {L,T} = (named.L, Val(:Q))
Base.iterate(named::NamedFactorization{L,T,<:SVD}) where {L,T} = (named.U, Val(:S))
Base.iterate(named::NamedFactorization{L,T,<:Cholesky}) where {L,T} = (named.L, Val(:U))
function Base.iterate(
named::NamedFactorization{L, T, <:Union{QR, LinearAlgebra.QRCompactWY, QRPivoted}}
) where {L, T}
named::NamedFactorization{L,T,<:Union{QR,LinearAlgebra.QRCompactWY,QRPivoted}}
) where {L,T}
return (named.Q, Val(:R))
end

Expand All @@ -40,9 +41,11 @@ function Base.iterate(named::NamedFactorization, st::Val{D}) where D
end

# Convenience constructors
for func in (:lu, :lu!, :lq, :lq!, :svd, :svd!, :qr, :qr!)
for func in (:lu, :lu!, :lq, :lq!, :svd, :svd!, :qr, :qr!, :cholesky)
@eval begin
function LinearAlgebra.$func(nda::NamedDimsArray{L, T}, args...; kwargs...) where {L, T}
function LinearAlgebra.$func(
nda::NamedDimsArray{L,T}, args...; kwargs...
) where {L,T}
return NamedFactorization{L}($func(parent(nda), args...; kwargs...))
end
end
Expand Down Expand Up @@ -82,8 +85,24 @@ function Base.getproperty(fact::NamedFactorization{L, T, <:LQ}, d::Symbol) where
end
end

# cholesky

function Base.getproperty(fact::NamedFactorization{L,T,<:Cholesky}, d::Symbol) where {L,T}
inner = getproperty(parent(fact), d)
n1, n2 = L
return d in (:L, :U) ? NamedDimsArray{(n1, n2)}(inner) : inner
end
function NamedFactorization{L}(fact::Cholesky{T}) where {L,T}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define a new constructor to handle the dimension mismatch

n1, n2 = L
return if isequal(n1, n2)
NamedFactorization{L,T,Cholesky{T}}(fact)
else
throw(DimensionMismatch("$n1 != $n2"))
end
end
AlexRobson marked this conversation as resolved.
Show resolved Hide resolved

## svd
function Base.getproperty(fact::NamedFactorization{L, T, <:SVD}, d::Symbol) where {L, T}
function Base.getproperty(fact::NamedFactorization{L,T,<:SVD}, d::Symbol) where {L,T}
inner = getproperty(parent(fact), d)
n1, n2 = L
# Naming based off the SVD visualization on wikipedia
Expand Down
37 changes: 31 additions & 6 deletions test/functions_linearalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@ if !isdefined(@__MODULE__, :ColumnNorm)
NoPivot() = Val(false)
end

function baseline_tests(fact, identity)
_test_data(::Val{:rectangle}) = [1.0 2 3; 4 5 6];
function _test_data(::Val{:pdmat})
return [8.0 7.0 6.0 5.0; 7.0 8.0 6.0 6.0; 6.0 6.0 6.0 5.0; 5.0 6.0 5.0 5.0]
end
_test_names(::Val{:rectangle}) = (:foo, :bar)
_test_names(::Val{:pdmat}) = (:foo, :foo)
AlexRobson marked this conversation as resolved.
Show resolved Hide resolved

function baseline_tests(fact, identity; test_data_type=:rectangle)
# A set of generic tests to ensure that our components don't accidentally reverse the
# `:foo` and `:bar` labels for any components
@testset "Baseline" begin
names = (:foo, :bar)
sz = (2, 3)
data = [1.0 2 3; 4 5 6]
names = _test_names(Val{test_data_type}())
data = _test_data(Val{test_data_type}())
nda = NamedDimsArray{names}(data)

base_fact = fact(data)
Expand All @@ -38,10 +44,11 @@ function baseline_tests(fact, identity)
@test size(_base) == size(_named)

# If our property is a NamedDimsArray make sure that the names make sense
_named isa NamedDimsArray && @testset "Test name for dim $d" for d in 1:ndims(_named)
_named isa NamedDimsArray && @testset "Test name for dim $d" for d in
1:ndims(_named)
# Don't think it make sense for an factorization to produce properties with
# dimension sizes outside 1, 2 or 3
@test size(_named, d) in (1, 2, 3)
AlexRobson marked this conversation as resolved.
Show resolved Hide resolved
@test d in (1, 2, 3)

if size(_named, d) == 1
# Neither name makes sense here
Expand All @@ -52,6 +59,9 @@ function baseline_tests(fact, identity)
elseif size(_named, d) == 3
# Name must either be :bar or :_
@test dimnames(_named, d) in (:bar, :_)
elseif size(_named, d) == 4
# Name can only be foo, as this is the pdmat case
@test dimnames(_named, d) in (:foo,)
end
end
end
Expand Down Expand Up @@ -143,6 +153,21 @@ end
end
end

@testset "cholesky" begin
baseline_tests(cholesky, S -> S.L * S.L'; test_data_type=:pdmat)
baseline_tests(cholesky, S -> S.U' * S.U; test_data_type=:pdmat)

# Explicit `dimnames` tests for readability
nda = NamedDimsArray{(:foo, :foo)}(_test_data(Val{:pdmat}()))
nda_mismatch = NamedDimsArray{(:foo, :bar)}(_test_data(Val{:pdmat}()))
x = cholesky(nda)
@test size(x) == size(parent(x))
@test dimnames(x.L) == (:foo, :foo)
@test dimnames(x.U) == (:foo, :foo)

@test_throws DimensionMismatch cholesky(nda_mismatch)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now throw an error if there is a dim-mismatch.

end

@testset "#164 factorization eltype not same as input eltype" begin
# https://github.com/invenia/NamedDims.jl/issues/164
nda = NamedDimsArray{(:foo, :bar)}([1 2 3; 4 5 6; 7 8 9]) # Int eltype
Expand Down