diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 4b7d9bd9d4af1..d608c2f9f8b68 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -241,8 +241,8 @@ end (*)(D::Diagonal, A::AbstractMatrix) = mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A) -rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D) -lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B) +rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D) +lmul!(D::Diagonal, B::AbstractVecOrMat) = @inline mul!(B, D, B) #TODO: It seems better to call (D' * adjA')' directly? function *(adjA::Adjoint{<:Any,<:AbstractMatrix}, D::Diagonal) @@ -277,35 +277,80 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix}) end @inline function __muldiag!(out, D::Diagonal, B, alpha, beta) - if iszero(beta) - out .= (D.diag .* B) .*ₛ alpha + require_one_based_indexing(out) + if iszero(alpha) + _rmul_or_fill!(out, beta) else - out .= (D.diag .* B) .*ₛ alpha .+ out .* beta + if iszero(beta) + @inbounds for j in axes(B, 2) + @simd for i in axes(B, 1) + out[i,j] = D.diag[i] * B[i,j] * alpha + end + end + else + @inbounds for j in axes(B, 2) + @simd for i in axes(B, 1) + out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta + end + end + end end return out end - @inline function __muldiag!(out, A, D::Diagonal, alpha, beta) - if iszero(beta) - out .= (A .* permutedims(D.diag)) .*ₛ alpha + require_one_based_indexing(out) + if iszero(alpha) + _rmul_or_fill!(out, beta) else - out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta + if iszero(beta) + @inbounds for j in axes(A, 2) + dja = D.diag[j] * alpha + @simd for i in axes(A, 1) + out[i,j] = A[i,j] * dja + end + end + else + @inbounds for j in axes(A, 2) + dja = D.diag[j] * alpha + @simd for i in axes(A, 1) + out[i,j] = A[i,j] * dja + out[i,j] * beta + end + end + end end return out end - @inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta) - if iszero(beta) - out.diag .= (D1.diag .* D2.diag) .*ₛ alpha + d1 = D1.diag + d2 = D2.diag + if iszero(alpha) + _rmul_or_fill!(out.diag, beta) else - out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta + if iszero(beta) + @inbounds @simd for i in eachindex(out.diag) + out.diag[i] = d1[i] * d2[i] * alpha + end + else + @inbounds @simd for i in eachindex(out.diag) + out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta + end + end + end + return out +end +@inline function __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) + require_one_based_indexing(out) + mA = size(D1, 1) + d1 = D1.diag + d2 = D2.diag + _rmul_or_fill!(out, beta) + if !iszero(alpha) + @inbounds @simd for i in 1:mA + out[i,i] += d1[i] * d2[i] * alpha + end end return out end - -# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments -@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) = - mul!(out, D1, D2, alpha, beta) @inline function _muldiag!(out, A, B, alpha, beta) _muldiag_size_check(out, A, B) @@ -332,24 +377,8 @@ end @inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) = _muldiag!(C, Da, Db, alpha, beta) -function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) - _muldiag_size_check(C, Da, Db) - require_one_based_indexing(C) - mA = size(Da, 1) - da = Da.diag - db = Db.diag - _rmul_or_fill!(C, beta) - if iszero(beta) - @inbounds @simd for i in 1:mA - C[i,i] = Ref(da[i] * db[i]) .*ₛ alpha - end - else - @inbounds @simd for i in 1:mA - C[i,i] += Ref(da[i] * db[i]) .*ₛ alpha - end - end - return C -end +mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) = + _muldiag!(C, Da, Db, alpha, beta) _init(op, A::AbstractArray{<:Number}, B::AbstractArray{<:Number}) = (_ -> zero(typeof(op(oneunit(eltype(A)), oneunit(eltype(B)))))) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index aa38419614b73..159c3b5db8843 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -8,8 +8,7 @@ # inside this function. function *ₛ end Broadcast.broadcasted(::typeof(*ₛ), out, beta) = - iszero(beta::Number) ? false : - isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta) + iszero(beta::Number) ? false : broadcasted(*, out, beta) """ MulAddMul(alpha, beta)