Skip to content

Commit

Permalink
Fix tr for Symmetric/Hermitian block matrices (#55522)
Browse files Browse the repository at this point in the history
Since `Symmetric` and `Hermitian` symmetrize the diagonal elements of
the parent, we can't forward `tr` to the parent unless it is already
symmetric. This limits the existing `tr` methods to matrices of
`Number`s, which is the common use-case. `tr` for `Symmetric` block
matrices would now use the fallback implementation that explicitly
computes the `diag`.
This resolves the following discrepancy:
```julia
julia> S = Symmetric(fill([1 2; 3 4], 3, 3))
3×3 Symmetric{AbstractMatrix, Matrix{Matrix{Int64}}}:
 [1 2; 2 4]  [1 2; 3 4]  [1 2; 3 4]
 [1 3; 2 4]  [1 2; 2 4]  [1 2; 3 4]
 [1 3; 2 4]  [1 3; 2 4]  [1 2; 2 4]

julia> tr(S)
2×2 Matrix{Int64}:
 3   6
 9  12

julia> sum(diag(S))
2×2 Symmetric{Int64, Matrix{Int64}}:
 3   6
 6  12
```
  • Loading branch information
jishnub authored Aug 19, 2024
1 parent 306cee7 commit 9738bc7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 2 additions & 2 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,8 @@ Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
Base.copy(A::Transpose{<:Any,<:Hermitian}) =
Hermitian(copy(transpose(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U))

tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
tr(A::Hermitian) = real(tr(A.data))
tr(A::Symmetric{<:Number}) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
tr(A::Hermitian{<:Number}) = real(tr(A.data))

Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo))
Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo))
Expand Down
11 changes: 11 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1116,4 +1116,15 @@ end
end
end

@testset "tr for block matrices" begin
m = [1 2; 3 4]
for b in (m, m * (1 + im))
M = fill(b, 3, 3)
for ST in (Symmetric, Hermitian)
S = ST(M)
@test tr(S) == sum(diag(S))
end
end
end

end # module TestSymmetric

0 comments on commit 9738bc7

Please sign in to comment.