From 89bcbe97556e68e470eed233c95bd092bdc9b130 Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Tue, 23 Apr 2024 19:13:39 +0200 Subject: [PATCH] Avoid constructing `MulAddMul`s --- src/blas/highlevel.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/blas/highlevel.jl b/src/blas/highlevel.jl index ad847ff3e..77bb89c33 100644 --- a/src/blas/highlevel.jl +++ b/src/blas/highlevel.jl @@ -142,10 +142,14 @@ else end # GEMV - +# legacy method +LinearAlgebra.generic_matvecmul!( + Y::ROCVector, tA::AbstractChar, A::StridedROCMatrix, B::StridedROCVector, + _add::MulAddMul +) = LinearAlgebra.generic_matvecmul!(Y, tA, A, B, _add.alpha, _add.beta) function LinearAlgebra.generic_matvecmul!( Y::ROCVector, tA::AbstractChar, A::StridedROCMatrix, B::StridedROCVector, - _add::MulAddMul, + alpha::Number, beta::Number, ) mA, nA = tA == 'N' ? size(A) : reverse(size(A)) @@ -158,7 +162,6 @@ function LinearAlgebra.generic_matvecmul!( nA == 0 && return rmul!(Y, 0) T = eltype(Y) - alpha, beta = _add.alpha, _add.beta if alpha isa Union{Bool,T} && beta isa Union{Bool,T} α, β = T(alpha), T(beta) if T <: ROCBLASFloat && eltype(A) == eltype(B) == T @@ -171,7 +174,7 @@ function LinearAlgebra.generic_matvecmul!( end end end - LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, MulAddMul(alpha, beta)) + LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, alpha, beta) end if VERSION < v"1.10.0-DEV.1365" @@ -191,13 +194,16 @@ end # # BLAS 3 # - -function LinearAlgebra.generic_matmatmul!( +# legacy method +LinearAlgebra.generic_matmatmul!( C::StridedROCVecOrMat, tA, tB, A::StridedROCVecOrMat, B::StridedROCVecOrMat, _add::MulAddMul, +) = LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) +function LinearAlgebra.generic_matmatmul!( + C::StridedROCVecOrMat, tA, tB, A::StridedROCVecOrMat, + B::StridedROCVecOrMat, alpha::Number, beta::Number, ) T = eltype(C) - alpha, beta = _add.alpha, _add.beta mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1) mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)