From 4e6ccf683084ed282c9837cf39ba02ed2910608e Mon Sep 17 00:00:00 2001 From: Fredrik Ekre Date: Wed, 7 Jun 2017 22:33:20 +0200 Subject: [PATCH] add constructors for Symmetric(::Hermitian) and Hermitian(::Symmetric) (#22264) --- base/linalg/symmetric.jl | 33 +++++++++++++++++++-------------- test/linalg/symmetric.jl | 7 +++++++ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/base/linalg/symmetric.jl b/base/linalg/symmetric.jl index 519c330ff3475..a3c2aa39c435e 100644 --- a/base/linalg/symmetric.jl +++ b/base/linalg/symmetric.jl @@ -41,14 +41,6 @@ julia> Slower = Symmetric(A, :L) Note that `Supper` will not be equal to `Slower` unless `A` is itself symmetric (e.g. if `A == A.'`). """ Symmetric(A::AbstractMatrix, uplo::Symbol=:U) = (checksquare(A); Symmetric{eltype(A),typeof(A)}(A, char_uplo(uplo))) -Symmetric(A::Symmetric) = A -function Symmetric(A::Symmetric, uplo::Symbol) - if A.uplo == char_uplo(uplo) - return A - else - throw(ArgumentError("Cannot construct Symmetric; uplo doesn't match")) - end -end struct Hermitian{T,S<:AbstractMatrix} <: AbstractMatrix{T} data::S @@ -91,12 +83,25 @@ function Hermitian(A::AbstractMatrix, uplo::Symbol=:U) end Hermitian{eltype(A),typeof(A)}(A, char_uplo(uplo)) end -Hermitian(A::Hermitian) = A -function Hermitian(A::Hermitian, uplo::Symbol) - if A.uplo == char_uplo(uplo) - return A - else - throw(ArgumentError("Cannot construct Hermitian; uplo doesn't match")) + +for (S, H) in ((:Symmetric, :Hermitian), (:Hermitian, :Symmetric)) + @eval begin + $S(A::$S) = A + function $S(A::$S, uplo::Symbol) + if A.uplo == char_uplo(uplo) + return A + else + throw(ArgumentError("Cannot construct $($S); uplo doesn't match")) + end + end + $S(A::$H) = $S(A.data, Symbol(A.uplo)) + function $S(A::$H, uplo::Symbol) + if A.uplo == char_uplo(uplo) + return $S(A.data, Symbol(A.uplo)) + else + throw(ArgumentError("Cannot construct $($S); uplo doesn't match")) + end + end end end diff --git a/test/linalg/symmetric.jl b/test/linalg/symmetric.jl index e72017ee5dba1..17e457d8c1082 100644 --- a/test/linalg/symmetric.jl +++ b/test/linalg/symmetric.jl @@ -53,6 +53,13 @@ let n=10 @test Hermitian(Hermitian(asym, :U), :U) === Hermitian(asym, :U) @test_throws ArgumentError Symmetric(Symmetric(asym, :U), :L) @test_throws ArgumentError Hermitian(Hermitian(asym, :U), :L) + # mixed cases with Hermitian/Symmetric + @test Symmetric(Hermitian(asym, :U)) === Symmetric(asym, :U) + @test Hermitian(Symmetric(asym, :U)) === Hermitian(asym, :U) + @test Symmetric(Hermitian(asym, :U), :U) === Symmetric(asym, :U) + @test Hermitian(Symmetric(asym, :U), :U) === Hermitian(asym, :U) + @test_throws ArgumentError Symmetric(Hermitian(asym, :U), :L) + @test_throws ArgumentError Hermitian(Symmetric(asym, :U), :L) # similar @test isa(similar(Symmetric(asym)), Symmetric{eltya})