Skip to content

Commit

Permalink
Fix Broadcasting of Bidiagonal (#35281)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssikdar1 authored Jun 8, 2020
1 parent 07385ab commit 0e062e9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 6 deletions.
20 changes: 14 additions & 6 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,21 @@ structured_broadcast_alloc(bc, ::Type{<:Diagonal}, ::Type{ElType}, n) where {ElT
# Bidiagonal is tricky as we need to know if it's upper or lower. The promotion
# system will return Tridiagonal when there's more than one Bidiagonal, but when
# there's only one, we need to make figure out upper or lower
find_bidiagonal() = throw(ArgumentError("could not find Bidiagonal within broadcast expression"))
find_bidiagonal(a::Bidiagonal, rest...) = a
find_bidiagonal(bc::Broadcast.Broadcasted, rest...) = find_bidiagonal(find_bidiagonal(bc.args...), rest...)
find_bidiagonal(x, rest...) = find_bidiagonal(rest...)
merge_uplos(::Nothing, ::Nothing) = nothing
merge_uplos(a, ::Nothing) = a
merge_uplos(::Nothing, b) = b
merge_uplos(a, b) = a == b ? a : 'T'

find_uplo(a::Bidiagonal) = a.uplo
find_uplo(a) = nothing
find_uplo(bc::Broadcasted) = mapreduce(find_uplo, merge_uplos, bc.args, init=nothing)

function structured_broadcast_alloc(bc, ::Type{<:Bidiagonal}, ::Type{ElType}, n) where {ElType}
ex = find_bidiagonal(bc)
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1), ex.uplo)
uplo = find_uplo(bc)
if uplo == 'T'
return Tridiagonal(Array{ElType}(undef, n-1), Array{ElType}(undef, n), Array{ElType}(undef, n-1))
end
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1), uplo)
end
structured_broadcast_alloc(bc, ::Type{<:SymTridiagonal}, ::Type{ElType}, n) where {ElType} =
SymTridiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1))
Expand Down
45 changes: 45 additions & 0 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,49 @@ end
@test L .+ UnitL .+ UnitU .+ U .+ D == L + UnitL + UnitU + U + D
@test L .+ U .+ D .+ D .+ D .+ D == L + U + D + D + D + D
end
@testset "Broadcast Returned Types" begin
# Issue 35245
N = 3
dV = rand(N)
evu = rand(N-1)
evl = rand(N-1)

Bu = Bidiagonal(dV, evu, :U)
Bl = Bidiagonal(dV, evl, :L)
T = Tridiagonal(evl, dV * 2, evu)

@test typeof(Bu .+ Bl) <: Tridiagonal
@test typeof(Bl .+ Bu) <: Tridiagonal
@test typeof(Bu .+ Bu) <: Bidiagonal
@test typeof(Bl .+ Bl) <: Bidiagonal
@test Bu .+ Bl == T
@test Bl .+ Bu == T
@test Bu .+ Bu == Bidiagonal(dV * 2, evu * 2, :U)
@test Bl .+ Bl == Bidiagonal(dV * 2, evl * 2, :L)


@test typeof(Bu .* Bl) <: Tridiagonal
@test typeof(Bl .* Bu) <: Tridiagonal
@test typeof(Bu .* Bu) <: Bidiagonal
@test typeof(Bl .* Bl) <: Bidiagonal

@test Bu .* Bl == Tridiagonal(zeros(N-1), dV .* dV, zeros(N-1))
@test Bl .* Bu == Tridiagonal(zeros(N-1), dV .* dV, zeros(N-1))
@test Bu .* Bu == Bidiagonal(dV .* dV, evu .* evu, :U)
@test Bl .* Bl == Bidiagonal(dV .* dV, evl .* evl, :L)

Bu2 = Bu .* 2
@test typeof(Bu2) <: Bidiagonal && Bu2.uplo == 'U'
Bu2 = 2 .* Bu
@test typeof(Bu2) <: Bidiagonal && Bu2.uplo == 'U'
Bl2 = Bl .* 2
@test typeof(Bl2) <: Bidiagonal && Bl2.uplo == 'L'
Bu2 = 2 .* Bl
@test typeof(Bl2) <: Bidiagonal && Bl2.uplo == 'L'

# Example of Nested Brodacasts
tmp = (1 .* 2) .* (Bidiagonal(1:3, 1:2, 'U') .* (3 .* 4)) .* (5 .* Bidiagonal(1:3, 1:2, 'L'))
@test typeof(tmp) <: Tridiagonal

end
end

0 comments on commit 0e062e9

Please sign in to comment.