Skip to content

Commit

Permalink
Broadcast binary ops involving strided triangular (#55798)
Browse files Browse the repository at this point in the history
Currently, we evaluate expressions like `(A::UpperTriangular) +
(B::UpperTriangular)` using broadcasting if both `A` and `B` have
strided parents, and forward the summation to the parents otherwise.
This PR changes this to use broadcasting if either of the two has a
strided parent. This avoids accessing the parent corresponding to the
structural zero elements, as the index might not be initialized.

Fixes #55590

This isn't a general fix, as we still sum the parents if neither is
strided. However, it will address common cases.

This also improves performance, as we only need to loop over one half:
```julia
julia> using LinearAlgebra

julia> U = UpperTriangular(zeros(100,100));

julia> B = Bidiagonal(zeros(100), zeros(99), :U);

julia> @Btime $U + $B;
  35.530 μs (4 allocations: 78.22 KiB) # nightly
  13.441 μs (4 allocations: 78.22 KiB) # This PR
```
  • Loading branch information
jishnub authored Sep 19, 2024
1 parent a73ba3b commit b8093de
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 30 deletions.
8 changes: 4 additions & 4 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -687,10 +687,10 @@ for f in (:+, :-)
@eval begin
$f(A::Hermitian, B::Symmetric{<:Real}) = $f(A, Hermitian(parent(B), sym_uplo(B.uplo)))
$f(A::Symmetric{<:Real}, B::Hermitian) = $f(Hermitian(parent(A), sym_uplo(A.uplo)), B)
$f(A::SymTridiagonal, B::Symmetric) = Symmetric($f(A, B.data), sym_uplo(B.uplo))
$f(A::Symmetric, B::SymTridiagonal) = Symmetric($f(A.data, B), sym_uplo(A.uplo))
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = Hermitian($f(A, B.data), sym_uplo(B.uplo))
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
$f(A::SymTridiagonal, B::Symmetric) = $f(Symmetric(A, sym_uplo(B.uplo)), B)
$f(A::Symmetric, B::SymTridiagonal) = $f(A, Symmetric(B, sym_uplo(A.uplo)))
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = $f(Hermitian(A, sym_uplo(B.uplo)), B)
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = $f(A, Hermitian(B, sym_uplo(A.uplo)))
end
end

Expand Down
91 changes: 65 additions & 26 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -850,35 +850,74 @@ fillstored!(A::UpperTriangular, x) = (fillband!(A.data, x, 0, size(A,2)-1);
fillstored!(A::UnitUpperTriangular, x) = (fillband!(A.data, x, 1, size(A,2)-1); A)

# Binary operations
+(A::UpperTriangular, B::UpperTriangular) = UpperTriangular(A.data + B.data)
+(A::LowerTriangular, B::LowerTriangular) = LowerTriangular(A.data + B.data)
+(A::UpperTriangular, B::UnitUpperTriangular) = UpperTriangular(A.data + triu(B.data, 1) + I)
+(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(A.data + tril(B.data, -1) + I)
+(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular(triu(A.data, 1) + B.data + I)
+(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular(tril(A.data, -1) + B.data + I)
+(A::UnitUpperTriangular, B::UnitUpperTriangular) = UpperTriangular(triu(A.data, 1) + triu(B.data, 1) + 2I)
+(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
# use broadcasting if the parents are strided, where we loop only over the triangular part
function +(A::UpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(A.data + B.data)
end
function +(A::LowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(A.data + B.data)
end
function +(A::UpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(A.data + triu(B.data, 1) + I)
end
function +(A::LowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(A.data + tril(B.data, -1) + I)
end
function +(A::UnitUpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(triu(A.data, 1) + B.data + I)
end
function +(A::UnitLowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(tril(A.data, -1) + B.data + I)
end
function +(A::UnitUpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
UpperTriangular(triu(A.data, 1) + triu(B.data, 1) + 2I)
end
function +(A::UnitLowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .+ B
LowerTriangular(tril(A.data, -1) + tril(B.data, -1) + 2I)
end
+(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) + copyto!(similar(parent(B)), B)

-(A::UpperTriangular, B::UpperTriangular) = UpperTriangular(A.data - B.data)
-(A::LowerTriangular, B::LowerTriangular) = LowerTriangular(A.data - B.data)
-(A::UpperTriangular, B::UnitUpperTriangular) = UpperTriangular(A.data - triu(B.data, 1) - I)
-(A::LowerTriangular, B::UnitLowerTriangular) = LowerTriangular(A.data - tril(B.data, -1) - I)
-(A::UnitUpperTriangular, B::UpperTriangular) = UpperTriangular(triu(A.data, 1) - B.data + I)
-(A::UnitLowerTriangular, B::LowerTriangular) = LowerTriangular(tril(A.data, -1) - B.data + I)
-(A::UnitUpperTriangular, B::UnitUpperTriangular) = UpperTriangular(triu(A.data, 1) - triu(B.data, 1))
-(A::UnitLowerTriangular, B::UnitLowerTriangular) = LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)

# use broadcasting if the parents are strided, where we loop only over the triangular part
for op in (:+, :-)
for TM1 in (:LowerTriangular, :UnitLowerTriangular), TM2 in (:LowerTriangular, :UnitLowerTriangular)
@eval $op(A::$TM1{<:Any, <:StridedMaybeAdjOrTransMat}, B::$TM2{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast($op, A, B)
end
for TM1 in (:UpperTriangular, :UnitUpperTriangular), TM2 in (:UpperTriangular, :UnitUpperTriangular)
@eval $op(A::$TM1{<:Any, <:StridedMaybeAdjOrTransMat}, B::$TM2{<:Any, <:StridedMaybeAdjOrTransMat}) = broadcast($op, A, B)
end
function -(A::UpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(A.data - B.data)
end
function -(A::LowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(A.data - B.data)
end
function -(A::UpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(A.data - triu(B.data, 1) - I)
end
function -(A::LowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(A.data - tril(B.data, -1) - I)
end
function -(A::UnitUpperTriangular, B::UpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(triu(A.data, 1) - B.data + I)
end
function -(A::UnitLowerTriangular, B::LowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(tril(A.data, -1) - B.data + I)
end
function -(A::UnitUpperTriangular, B::UnitUpperTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
UpperTriangular(triu(A.data, 1) - triu(B.data, 1))
end
function -(A::UnitLowerTriangular, B::UnitLowerTriangular)
(parent(A) isa StridedMatrix || parent(B) isa StridedMatrix) && return A .- B
LowerTriangular(tril(A.data, -1) - tril(B.data, -1))
end
-(A::AbstractTriangular, B::AbstractTriangular) = copyto!(similar(parent(A)), A) - copyto!(similar(parent(B)), B)

function kron(A::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat}, B::UpperTriangular{<:Number,<:StridedMaybeAdjOrTransMat})
C = UpperTriangular(Matrix{promote_op(*, eltype(A), eltype(B))}(undef, _kronsize(A, B)))
Expand Down
25 changes: 25 additions & 0 deletions stdlib/LinearAlgebra/test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1135,4 +1135,29 @@ end
end
end

@testset "partly iniitalized matrices" begin
a = Matrix{BigFloat}(undef, 2,2)
a[1] = 1; a[3] = 1; a[4] = 1
h = Hermitian(a)
s = Symmetric(a)
d = Diagonal([1,1])
symT = SymTridiagonal([1 1;1 1])
@test h+d == Array(h) + Array(d)
@test h+symT == Array(h) + Array(symT)
@test s+d == Array(s) + Array(d)
@test s+symT == Array(s) + Array(symT)
@test h-d == Array(h) - Array(d)
@test h-symT == Array(h) - Array(symT)
@test s-d == Array(s) - Array(d)
@test s-symT == Array(s) - Array(symT)
@test d+h == Array(d) + Array(h)
@test symT+h == Array(symT) + Array(h)
@test d+s == Array(d) + Array(s)
@test symT+s == Array(symT) + Array(s)
@test d-h == Array(d) - Array(h)
@test symT-h == Array(symT) - Array(h)
@test d-s == Array(d) - Array(s)
@test symT-s == Array(symT) - Array(s)
end

end # module TestSymmetric

1 comment on commit b8093de

@SRojas28
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: Typo in stdlib/LinearAlgebra/test/symmetric.jl
Location: Line 1138
Description: The word "iniitalized" should be corrected to "initialized."

Please sign in to comment.