diff --git a/stdlib/LinearAlgebra/src/structuredbroadcast.jl b/stdlib/LinearAlgebra/src/structuredbroadcast.jl index 3b5ba2475aeb1..a665e21731752 100644 --- a/stdlib/LinearAlgebra/src/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/src/structuredbroadcast.jl @@ -60,13 +60,21 @@ structured_broadcast_alloc(bc, ::Type{<:Diagonal}, ::Type{ElType}, n) where {ElT # Bidiagonal is tricky as we need to know if it's upper or lower. The promotion # system will return Tridiagonal when there's more than one Bidiagonal, but when # there's only one, we need to make figure out upper or lower -find_bidiagonal() = throw(ArgumentError("could not find Bidiagonal within broadcast expression")) -find_bidiagonal(a::Bidiagonal, rest...) = a -find_bidiagonal(bc::Broadcast.Broadcasted, rest...) = find_bidiagonal(find_bidiagonal(bc.args...), rest...) -find_bidiagonal(x, rest...) = find_bidiagonal(rest...) +merge_uplos(::Nothing, ::Nothing) = nothing +merge_uplos(a, ::Nothing) = a +merge_uplos(::Nothing, b) = b +merge_uplos(a, b) = a == b ? a : 'T' + +find_uplo(a::Bidiagonal) = a.uplo +find_uplo(a) = nothing +find_uplo(bc::Broadcasted) = mapreduce(find_uplo, merge_uplos, bc.args, init=nothing) + function structured_broadcast_alloc(bc, ::Type{<:Bidiagonal}, ::Type{ElType}, n) where {ElType} - ex = find_bidiagonal(bc) - return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1), ex.uplo) + uplo = find_uplo(bc) + if uplo == 'T' + return Tridiagonal(Array{ElType}(undef, n-1), Array{ElType}(undef, n), Array{ElType}(undef, n-1)) + end + return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1), uplo) end structured_broadcast_alloc(bc, ::Type{<:SymTridiagonal}, ::Type{ElType}, n) where {ElType} = SymTridiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n-1)) diff --git a/stdlib/LinearAlgebra/test/structuredbroadcast.jl b/stdlib/LinearAlgebra/test/structuredbroadcast.jl index 72d304ac6da3b..b8f5e97311588 100644 --- a/stdlib/LinearAlgebra/test/structuredbroadcast.jl +++ b/stdlib/LinearAlgebra/test/structuredbroadcast.jl @@ -160,4 +160,49 @@ end @test L .+ UnitL .+ UnitU .+ U .+ D == L + UnitL + UnitU + U + D @test L .+ U .+ D .+ D .+ D .+ D == L + U + D + D + D + D end +@testset "Broadcast Returned Types" begin + # Issue 35245 + N = 3 + dV = rand(N) + evu = rand(N-1) + evl = rand(N-1) + + Bu = Bidiagonal(dV, evu, :U) + Bl = Bidiagonal(dV, evl, :L) + T = Tridiagonal(evl, dV * 2, evu) + + @test typeof(Bu .+ Bl) <: Tridiagonal + @test typeof(Bl .+ Bu) <: Tridiagonal + @test typeof(Bu .+ Bu) <: Bidiagonal + @test typeof(Bl .+ Bl) <: Bidiagonal + @test Bu .+ Bl == T + @test Bl .+ Bu == T + @test Bu .+ Bu == Bidiagonal(dV * 2, evu * 2, :U) + @test Bl .+ Bl == Bidiagonal(dV * 2, evl * 2, :L) + + + @test typeof(Bu .* Bl) <: Tridiagonal + @test typeof(Bl .* Bu) <: Tridiagonal + @test typeof(Bu .* Bu) <: Bidiagonal + @test typeof(Bl .* Bl) <: Bidiagonal + + @test Bu .* Bl == Tridiagonal(zeros(N-1), dV .* dV, zeros(N-1)) + @test Bl .* Bu == Tridiagonal(zeros(N-1), dV .* dV, zeros(N-1)) + @test Bu .* Bu == Bidiagonal(dV .* dV, evu .* evu, :U) + @test Bl .* Bl == Bidiagonal(dV .* dV, evl .* evl, :L) + + Bu2 = Bu .* 2 + @test typeof(Bu2) <: Bidiagonal && Bu2.uplo == 'U' + Bu2 = 2 .* Bu + @test typeof(Bu2) <: Bidiagonal && Bu2.uplo == 'U' + Bl2 = Bl .* 2 + @test typeof(Bl2) <: Bidiagonal && Bl2.uplo == 'L' + Bu2 = 2 .* Bl + @test typeof(Bl2) <: Bidiagonal && Bl2.uplo == 'L' + + # Example of Nested Brodacasts + tmp = (1 .* 2) .* (Bidiagonal(1:3, 1:2, 'U') .* (3 .* 4)) .* (5 .* Bidiagonal(1:3, 1:2, 'L')) + @test typeof(tmp) <: Tridiagonal + +end end