From 0c620d8cd656119ffc193bb5b81837be8bd640bf Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Thu, 17 Mar 2022 12:26:10 +0100 Subject: [PATCH] Fix performance issue with diagonal multiplication Co-authored-by: Dilum Aluthge --- stdlib/LinearAlgebra/src/diagonal.jl | 111 ++++++++++++++++++--------- stdlib/LinearAlgebra/src/generic.jl | 3 +- 2 files changed, 75 insertions(+), 39 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 4b7d9bd9d4af1..832f1b49c819f 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -276,38 +276,91 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix}) lmul!(D, At) end -@inline function __muldiag!(out, D::Diagonal, B, alpha, beta) - if iszero(beta) - out .= (D.diag .* B) .*ₛ alpha +# in __muldiag! below we unroll the loops manually, since broadcasting may be unable to +# prove that they are vectorizable +function __muldiag!(out, D::Diagonal, B, alpha, beta) + # TODO: check if this code can be replaced by a single line + # out .= (D.diag .* B) .*ₛ alpha .+ out .*ₛ beta + require_one_based_indexing(out) + if iszero(alpha) + _rmul_or_fill!(out, beta) else - out .= (D.diag .* B) .*ₛ alpha .+ out .* beta + if iszero(beta) + @inbounds for j in axes(B, 2) + @simd for i in axes(B, 1) + out[i,j] = D.diag[i] * B[i,j] * alpha + end + end + else + @inbounds for j in axes(B, 2) + @simd for i in axes(B, 1) + out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta + end + end + end end return out end - -@inline function __muldiag!(out, A, D::Diagonal, alpha, beta) - if iszero(beta) - out .= (A .* permutedims(D.diag)) .*ₛ alpha +function __muldiag!(out, A, D::Diagonal, alpha, beta) + # TODO: check if this code can be replaced by a single line + # out .= (B .* permutedims(D.diag)) .*ₛ alpha .+ out .*ₛ beta + require_one_based_indexing(out) + if iszero(alpha) + _rmul_or_fill!(out, beta) else - out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta + if iszero(beta) + @inbounds for j in axes(A, 2) + dja = D.diag[j] * alpha + @simd for i in axes(A, 1) + out[i,j] = A[i,j] * dja + end + end + else + @inbounds for j in axes(A, 2) + dja = D.diag[j] * alpha + @simd for i in axes(A, 1) + out[i,j] = A[i,j] * dja + out[i,j] * beta + end + end + end end return out end - -@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta) - if iszero(beta) - out.diag .= (D1.diag .* D2.diag) .*ₛ alpha +function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta) + # TODO: check if this code can be replaced by a single line + # out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .*ₛ beta + d1 = D1.diag + d2 = D2.diag + if iszero(alpha) + _rmul_or_fill!(out.diag, beta) else - out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta + if iszero(beta) + @inbounds @simd for i in eachindex(out.diag) + out.diag[i] = d1[i] * d2[i] * alpha + end + else + @inbounds @simd for i in eachindex(out.diag) + out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta + end + end + end + return out +end +function __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) + require_one_based_indexing(out) + mA = size(D1, 1) + d1 = D1.diag + d2 = D2.diag + _rmul_or_fill!(out, beta) + if !iszero(alpha) + @inbounds @simd for i in 1:mA + out[i,i] += d1[i] * d2[i] * alpha + end end return out end -# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments -@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) = - mul!(out, D1, D2, alpha, beta) - -@inline function _muldiag!(out, A, B, alpha, beta) +function _muldiag!(out, A, B, alpha, beta) _muldiag_size_check(out, A, B) __muldiag!(out, A, B, alpha, beta) return out @@ -332,24 +385,8 @@ end @inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) = _muldiag!(C, Da, Db, alpha, beta) -function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) - _muldiag_size_check(C, Da, Db) - require_one_based_indexing(C) - mA = size(Da, 1) - da = Da.diag - db = Db.diag - _rmul_or_fill!(C, beta) - if iszero(beta) - @inbounds @simd for i in 1:mA - C[i,i] = Ref(da[i] * db[i]) .*ₛ alpha - end - else - @inbounds @simd for i in 1:mA - C[i,i] += Ref(da[i] * db[i]) .*ₛ alpha - end - end - return C -end +mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) = + _muldiag!(C, Da, Db, alpha, beta) _init(op, A::AbstractArray{<:Number}, B::AbstractArray{<:Number}) = (_ -> zero(typeof(op(oneunit(eltype(A)), oneunit(eltype(B)))))) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index aa38419614b73..159c3b5db8843 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -8,8 +8,7 @@ # inside this function. function *ₛ end Broadcast.broadcasted(::typeof(*ₛ), out, beta) = - iszero(beta::Number) ? false : - isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta) + iszero(beta::Number) ? false : broadcasted(*, out, beta) """ MulAddMul(alpha, beta)