diff --git a/src/blas/linalg.jl b/src/blas/linalg.jl index e3d6a3a6..d064d05d 100644 --- a/src/blas/linalg.jl +++ b/src/blas/linalg.jl @@ -301,3 +301,11 @@ LinearAlgebra.rdiv!(A::CuMatrix{T}, B::Transpose{T,<:LowerTriangular{T, <:CuMatr CUBLAS.trsm!('R', 'L', 'T', 'N', one(T), parent(parent(B)), A) LinearAlgebra.rdiv!(A::CuMatrix{T}, B::Transpose{T,<:UnitLowerTriangular{T, <:CuMatrix{T}}}) where T<:CublasFloat = CUBLAS.trsm!('R', 'L', 'T', 'U', one(T), parent(parent(B)), A) + +# Direct BLAS calls +for T in Base.uniontypes(CublasFloat) # needed to avoid ambiguous method error + @eval LinearAlgebra.BLAS.trmm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::$T, A::CuMatrix{$T}, B::CuMatrix{$T}) = + CuArrays.CUBLAS.trmm!(side, uplo, transa, diag, alpha, A, B, B) + @eval LinearAlgebra.BLAS.trsm!(side::AbstractChar, uplo::AbstractChar, transa::AbstractChar, diag::AbstractChar, alpha::$T, A::CuMatrix{$T}, B::CuMatrix{$T}) = + CuArrays.CUBLAS.trsm!(side, uplo, transa, diag, alpha, A, B) +end \ No newline at end of file diff --git a/test/blas.jl b/test/blas.jl index 07488f53..17439bc1 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -781,6 +781,19 @@ end # level 1 testset @test bC ≈ h_C end end + + @testset "BLAS.trmm!" begin + A = copy(A) + B = copy(B) + dA = CuArray(A) + dB = CuArray(B) + dC = LinearAlgebra.BLAS.trmm!('L','U','N','N',one(elty),dA,dB) + C = LinearAlgebra.BLAS.trmm!('L','U','N','N',one(elty),A,B) + @test A ≈ Array(dA) + @test B ≈ Array(dB) + @test C ≈ Array(dC) + end + B = rand(elty,m,n) C = rand(elty,m,n) d_B = CuArray(B)