From 683d5c343ccd0c03e3a99a85bcbfa96a72205704 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 31 Jul 2023 09:29:00 -0400 Subject: [PATCH 1/5] WIP: Setup MKL direct factorizations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MWE: ```julia using LinearSolve, MKL_jll A = rand(4, 4); b = rand(4); u0 = zeros(4); lp = LinearProblem(A, b); truesol = solve(lp, LUFactorization()) mklsol = solve(lp, MKLLUFactorization()) @test truesol ≈ mklsol ``` The segfault can be reproduced just with the triangular solver. MWE without LinearSolve: ```julia using MKL_jll using LinearAlgebra: BlasInt, LU using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, chkargsok const usemkl = MKL_jll.is_available() function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false) require_one_based_indexing(A) check && chkfinite(A) chkstride1(A) m, n = size(A) lda = max(1,stride(A, 2)) ccall((:dgetrf_, MKL_jll.libmkl_rt), Cvoid, (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), m, n, A, lda, ipiv, info) chkargsok(info[]) A, ipiv, info[] #Error code is stored in LU factorization type end A = rand(4,4); b = rand(4) getrf!(A) LU(getrf!(A)...) \ b ``` --- Project.toml | 2 ++ ext/LinearSolveMKLExt.jl | 54 ++++++++++++++++++++++++++++++++++++++++ src/LinearSolve.jl | 1 + src/extension_algs.jl | 10 ++++++++ 4 files changed, 67 insertions(+) create mode 100644 ext/LinearSolveMKLExt.jl diff --git a/Project.toml b/Project.toml index 16b066a54..8d3da3381 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" +MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" @@ -37,6 +38,7 @@ Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" LinearSolveCUDAExt = "CUDA" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" +LinearSolveMKLExt = "MKL_jll" LinearSolveKrylovKitExt = "KrylovKit" LinearSolvePardisoExt = "Pardiso" diff --git a/ext/LinearSolveMKLExt.jl b/ext/LinearSolveMKLExt.jl new file mode 100644 index 000000000..a9d706d21 --- /dev/null +++ b/ext/LinearSolveMKLExt.jl @@ -0,0 +1,54 @@ +module LinearSolveMKLExt + +using MKL_jll +using LinearAlgebra: BlasInt, LU +using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, chkargsok +using LinearAlgebra +const usemkl = MKL_jll.is_available() + +using LinearSolve +using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase + +function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false) + require_one_based_indexing(A) + check && chkfinite(A) + chkstride1(A) + m, n = size(A) + lda = max(1,stride(A, 2)) + + if isempty(ipiv) + ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))) + end + + ccall((:dgetrf_, MKL_jll.libmkl_rt), Cvoid, + (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, + Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), + m, n, A, lda, ipiv, info) + chkargsok(info[]) + A, ipiv, info[] #Error code is stored in LU factorization type +end + +default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false +default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false + +function LinearSolve.init_cacheval(alg::MKLLUFactorization, A, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + ArrayInterface.lu_instance(convert(AbstractMatrix, A)) +end + +function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; + kwargs...) + A = cache.A + A = convert(AbstractMatrix, A) + if cache.isfresh + cacheval = @get_cacheval(cache, :MKLLUFactorization) + fact = LU(getrf!(A)...) + cache.cacheval = fact + cache.isfresh = false + end + y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization), cache.b) + SciMLBase.build_linear_solution(alg, y, nothing, cache) +end + +end \ No newline at end of file diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index af40bed4b..578f58a23 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -181,6 +181,7 @@ export HYPREAlgorithm export CudaOffloadFactorization export MKLPardisoFactorize, MKLPardisoIterate export PardisoJL +export MKLLUFactorization export OperatorAssumptions, OperatorCondition diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 2bdb9d2da..db3d89675 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -337,3 +337,13 @@ A wrapper over the IterativeSolvers.jl MINRES. """ function IterativeSolversJL_MINRES end + +""" +```julia +MKLLUFactorization() +``` + +A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace +to avoid allocations and does not require libblastrampoline. +""" +struct MKLLUFactorization <: AbstractFactorization end \ No newline at end of file From 16c775e1a04761ab246c8c2700ed9b74a11adbcf Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 31 Jul 2023 23:16:09 -0400 Subject: [PATCH 2/5] fix and test --- ext/LinearSolveMKLExt.jl | 10 +++------- test/basictests.jl | 1 + 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/ext/LinearSolveMKLExt.jl b/ext/LinearSolveMKLExt.jl index a9d706d21..bc40d049d 100644 --- a/ext/LinearSolveMKLExt.jl +++ b/ext/LinearSolveMKLExt.jl @@ -2,7 +2,8 @@ module LinearSolveMKLExt using MKL_jll using LinearAlgebra: BlasInt, LU -using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, chkargsok +using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, + @blasfunc, chkargsok using LinearAlgebra const usemkl = MKL_jll.is_available() @@ -15,12 +16,7 @@ function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(siz chkstride1(A) m, n = size(A) lda = max(1,stride(A, 2)) - - if isempty(ipiv) - ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))) - end - - ccall((:dgetrf_, MKL_jll.libmkl_rt), Cvoid, + ccall((@blasfunc(dgetrf_), MKL_jll.libmkl_rt), Cvoid, (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), m, n, A, lda, ipiv, info) diff --git a/test/basictests.jl b/test/basictests.jl index 084533e0f..219b1d1ff 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -207,6 +207,7 @@ end QRFactorization(), SVDFactorization(), RFLUFactorization(), + MKLLUFactorization(), LinearSolve.defaultalg(prob1.A, prob1.b)) @testset "$alg" begin test_interface(alg, prob1, prob2) From d00d02593e641047278fc1cc649d303cc68d77a8 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Mon, 31 Jul 2023 23:47:34 -0400 Subject: [PATCH 3/5] fix MKL test --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8d3da3381..9ddde9976 100644 --- a/Project.toml +++ b/Project.toml @@ -72,6 +72,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -80,4 +81,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI"] +test = ["Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "MKL_jll"] From 3c1e66d68770fac361c58027cbfa46b6e646a852 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 1 Aug 2023 00:04:51 -0400 Subject: [PATCH 4/5] add missing using --- test/basictests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basictests.jl b/test/basictests.jl index 219b1d1ff..e24b87a6e 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -1,6 +1,6 @@ using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff using SciMLOperators -using IterativeSolvers, KrylovKit +using IterativeSolvers, KrylovKit, MKL_jll using Test import Random From aaf64d33b6f5f0a8a1b92f2122928b2756d49b39 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 1 Aug 2023 01:42:38 -0400 Subject: [PATCH 5/5] add backwards compat --- src/LinearSolve.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 578f58a23..272639e34 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -124,6 +124,9 @@ end @require KrylovKit="0b1a1467-8014-51b9-945f-bf0ae24f4b77" begin include("../ext/LinearSolveKrylovKitExt.jl") end + @require MKL_jll="856f044c-d86e-5d09-b602-aeab76dc8ba7" begin + include("../ext/LinearSolveMKLExt.jl") + end end end