From 50cf3419bd6254898233291ee2d84288d0ca3816 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Aug 2023 08:22:18 -0400 Subject: [PATCH 01/10] Support using Metal.jl for the LU factorization --- Project.toml | 7 +++++-- ext/LinearSolveMetalExt.jl | 31 +++++++++++++++++++++++++++++++ src/LinearSolve.jl | 1 + src/extension_algs.jl | 12 +++++++++++- 4 files changed, 48 insertions(+), 3 deletions(-) create mode 100644 ext/LinearSolveMetalExt.jl diff --git a/Project.toml b/Project.toml index 4670ceee9..2e865fc75 100644 --- a/Project.toml +++ b/Project.toml @@ -30,17 +30,19 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" HYPRE = "b5ffcf37-a2bd-41ab-a3da-4bd9bc8ad771" -MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" +MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" [extensions] LinearSolveCUDAExt = "CUDA" LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" -LinearSolveMKLExt = "MKL_jll" LinearSolveKrylovKitExt = "KrylovKit" +LinearSolveMKLExt = "MKL_jll" +LinearSolveMetalExt = "Metal" LinearSolvePardisoExt = "Pardiso" [compat] @@ -74,6 +76,7 @@ IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" MultiFloats = "bdf0d083-296b-4888-a5b6-7498122e68a5" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/ext/LinearSolveMetalExt.jl b/ext/LinearSolveMetalExt.jl new file mode 100644 index 000000000..3616f9663 --- /dev/null +++ b/ext/LinearSolveMetalExt.jl @@ -0,0 +1,31 @@ +module LinearSolveMetalExt + +using Metal, LinearSolve +using LinearAlgebra, SciMLBase +using SciMLBase: AbstractSciMLOperator +using LinearSolve: ArrayInterface, MKLLUFactorization, @get_cacheval, LinearCache, SciMLBase + +default_alias_A(::MetalLUFactorization, ::Any, ::Any) = false +default_alias_b(::MetalLUFactorization, ::Any, ::Any) = false + +function LinearSolve.init_cacheval(alg::MetalLUFactorization, A, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + ArrayInterface.lu_instance(convert(AbstractMatrix, MtlArray(A))) +end + +function SciMLBase.solve!(cache::LinearCache, alg::MetalLUFactorization; + kwargs...) 
+ A = cache.A + A = convert(AbstractMatrix, A) + if cache.isfresh + cacheval = @get_cacheval(cache, :MetalLUFactorization) + res = lu(MtlArray(A)) + cache.cacheval = fact + cache.isfresh = false + end + y = Array(ldiv!(MtlArray(cache.u), @get_cacheval(cache, :MetalLUFactorization), MtlArray(cache.b))) + SciMLBase.build_linear_solution(alg, y, nothing, cache) +end + +end \ No newline at end of file diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 4149c21cf..95115b32b 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -177,6 +177,7 @@ export MKLPardisoFactorize, MKLPardisoIterate export PardisoJL export MKLLUFactorization export AppleAccelerateLUFactorization +export MetalLUFactorization export OperatorAssumptions, OperatorCondition diff --git a/src/extension_algs.jl b/src/extension_algs.jl index 19ccad02e..8d3e3134e 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -347,4 +347,14 @@ MKLLUFactorization() A wrapper over Intel's Math Kernel Library (MKL). Direct calls to MKL in a way that pre-allocates workspace to avoid allocations and does not require libblastrampoline. """ -struct MKLLUFactorization <: AbstractFactorization end \ No newline at end of file +struct MKLLUFactorization <: AbstractFactorization end + +""" +```julia +MetalLUFactorization() +``` + +A wrapper over Apple's Metal GPU library. Direct calls to Metal in a way that pre-allocates workspace +to avoid allocations and automatically offloads to the GPU. +""" +struct MetalLUFactorization <: AbstractFactorization end \ No newline at end of file From 54b7bec783493d02d30313c3f935ce6680761a96 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Aug 2023 09:05:22 -0400 Subject: [PATCH 02/10] Setup Accelerate and MKL for 32-bit, MKL getrf, fix Metal --- ext/LinearSolveMKLExt.jl | 73 ++++++++++++++++++++++++++++++++++++-- ext/LinearSolveMetalExt.jl | 6 ++-- src/appleaccelerate.jl | 38 ++++++++++++++++++++ 3 files changed, 112 insertions(+), 5 deletions(-) diff --git a/ext/LinearSolveMKLExt.jl b/ext/LinearSolveMKLExt.jl index da9f4673d..a505c21aa 100644 --- a/ext/LinearSolveMKLExt.jl +++ b/ext/LinearSolveMKLExt.jl @@ -27,6 +27,63 @@ function getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, BlasInt, min(siz A, ipiv, info[], info #Error code is stored in LU factorization type end +function getrf!(A::AbstractMatrix{<:Float32}; ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))), info = Ref{BlasInt}(), check = false) + require_one_based_indexing(A) + check && chkfinite(A) + chkstride1(A) + m, n = size(A) + lda = max(1,stride(A, 2)) + if isempty(ipiv) + ipiv = similar(A, BlasInt, min(size(A,1),size(A,2))) + end + ccall((@blasfunc(sgetrf_), MKL_jll.libmkl_rt), Cvoid, + (Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, + Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}), + m, n, A, lda, ipiv, info) + chkargsok(info[]) + A, ipiv, info[], info #Error code is stored in LU factorization type +end + +function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64}; info = Ref{Cint}()) + require_one_based_indexing(A, ipiv, B) + LinearAlgebra.LAPACK.chktrans(trans) + chkstride1(A, B, ipiv) + n = LinearAlgebra.checksquare(A) + if n != size(B, 1) + throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) + end + if n != length(ipiv) + throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) + end + nrhs = size(B, 2) + ccall(("dgetrs_", MKL_jll.libmkl_rt), Cvoid, + (Ref{UInt8}, Ref{Cint}, Ref{Cint}, 
Ptr{Float64}, Ref{Cint}, + Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong), + trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1) + LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) + B +end + +function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float32}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float32}; info = Ref{Cint}()) + require_one_based_indexing(A, ipiv, B) + LinearAlgebra.LAPACK.chktrans(trans) + chkstride1(A, B, ipiv) + n = LinearAlgebra.checksquare(A) + if n != size(B, 1) + throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) + end + if n != length(ipiv) + throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) + end + nrhs = size(B, 2) + ccall(("sgetrs_", MKL_jll.libmkl_rt), Cvoid, + (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float32}, Ref{Cint}, + Ptr{Cint}, Ptr{Float32}, Ref{Cint}, Ptr{Cint}, Clong), + trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1) + LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) + B +end + default_alias_A(::MKLLUFactorization, ::Any, ::Any) = false default_alias_b(::MKLLUFactorization, ::Any, ::Any) = false @@ -47,8 +104,20 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; cache.cacheval = fact cache.isfresh = false end - y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization)[1], cache.b) - SciMLBase.build_linear_solution(alg, y, nothing, cache) + + A, info = @get_cacheval(cache, :MKLLUFactorization) + LinearAlgebra.require_one_based_indexing(cache.u, cache.b) + m, n = size(A, 1), size(A, 2) + if m > n + Bc = copy(cache.b) + getrs!('N', A.factors, A.ipiv, Bc; info) + return copyto!(cache.u, 1, Bc, 1, n) + else + copyto!(cache.u, cache.b) + getrs!('N', A.factors, A.ipiv, cache.u; info) + end + + SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) end end \ No newline at end of file diff --git a/ext/LinearSolveMetalExt.jl b/ext/LinearSolveMetalExt.jl index 3616f9663..4ee07dc39 100644 --- a/ext/LinearSolveMetalExt.jl +++ b/ext/LinearSolveMetalExt.jl @@ -11,7 +11,7 @@ default_alias_b(::MetalLUFactorization, ::Any, ::Any) = false function LinearSolve.init_cacheval(alg::MetalLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.lu_instance(convert(AbstractMatrix, MtlArray(A))) + ArrayInterface.lu_instance(convert(AbstractMatrix, A)) end function SciMLBase.solve!(cache::LinearCache, alg::MetalLUFactorization; @@ -21,10 +21,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::MetalLUFactorization; if cache.isfresh cacheval = @get_cacheval(cache, :MetalLUFactorization) res = lu(MtlArray(A)) - cache.cacheval = fact + cache.cacheval = LU(Array(res.factors), Array{Int}(res.ipiv), res.info) cache.isfresh = false end - y = Array(ldiv!(MtlArray(cache.u), @get_cacheval(cache, :MetalLUFactorization), MtlArray(cache.b))) + y = ldiv!(cache.u, @get_cacheval(cache, :MetalLUFactorization), cache.b) SciMLBase.build_linear_solution(alg, y, nothing, cache) end diff --git a/src/appleaccelerate.jl b/src/appleaccelerate.jl index 9b9cc32db..0b7ea7cae 100644 --- a/src/appleaccelerate.jl +++ b/src/appleaccelerate.jl @@ -44,6 +44,24 @@ function aa_getrf!(A::AbstractMatrix{<:Float64}; ipiv = similar(A, Cint, min(siz A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type end +function aa_getrf!(A::AbstractMatrix{<:Float32}; ipiv = similar(A, Cint, min(size(A,1),size(A,2))), info = Ref{Cint}(), check = false) + 
require_one_based_indexing(A) + check && chkfinite(A) + chkstride1(A) + m, n = size(A) + lda = max(1,stride(A, 2)) + if isempty(ipiv) + ipiv = similar(A, Cint, min(size(A,1),size(A,2))) + end + + ccall(("sgetrf_", libacc), Cvoid, + (Ref{Cint}, Ref{Cint}, Ptr{Float32}, + Ref{Cint}, Ptr{Cint}, Ptr{Cint}), + m, n, A, lda, ipiv, info) + info[] < 0 && throw(ArgumentError("Invalid arguments sent to LAPACK dgetrf_")) + A, ipiv, BlasInt(info[]), info #Error code is stored in LU factorization type +end + function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64}; info = Ref{Cint}()) require_one_based_indexing(A, ipiv, B) LinearAlgebra.LAPACK.chktrans(trans) @@ -64,6 +82,26 @@ function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::Abst B end +function aa_getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float32}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float32}; info = Ref{Cint}()) + require_one_based_indexing(A, ipiv, B) + LinearAlgebra.LAPACK.chktrans(trans) + chkstride1(A, B, ipiv) + n = LinearAlgebra.checksquare(A) + if n != size(B, 1) + throw(DimensionMismatch("B has leading dimension $(size(B,1)), but needs $n")) + end + if n != length(ipiv) + throw(DimensionMismatch("ipiv has length $(length(ipiv)), but needs to be $n")) + end + nrhs = size(B, 2) + ccall(("sgetrs_", libacc), Cvoid, + (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float32}, Ref{Cint}, + Ptr{Cint}, Ptr{Float32}, Ref{Cint}, Ptr{Cint}, Clong), + trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1) + LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) + B +end + default_alias_A(::AppleAccelerateLUFactorization, ::Any, ::Any) = false default_alias_b(::AppleAccelerateLUFactorization, ::Any, ::Any) = false From 83e88c2fcafc1a965f2575d49fb13e42c8824537 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Aug 2023 11:51:11 -0400 Subject: [PATCH 03/10] add benchmarks and update solver text --- benchmarks/applelu.jl | 51 +++++++++++++++++++++++++++++ benchmarks/cudalu.jl | 51 +++++++++++++++++++++++++++++ benchmarks/metallu.jl | 51 +++++++++++++++++++++++++++++ docs/src/solvers/solvers.md | 65 ++++++++++++++++++++++++++++++++++--- 4 files changed, 213 insertions(+), 5 deletions(-) create mode 100644 benchmarks/applelu.jl create mode 100644 benchmarks/cudalu.jl create mode 100644 benchmarks/metallu.jl diff --git a/benchmarks/applelu.jl b/benchmarks/applelu.jl new file mode 100644 index 000000000..58ef00b2c --- /dev/null +++ b/benchmarks/applelu.jl @@ -0,0 +1,51 @@ +using BenchmarkTools, Random, VectorizationBase +using LinearAlgebra, LinearSolve, Metal +nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads()) +BLAS.set_num_threads(nc) +BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5 + +function luflop(m, n = m; innerflop = 2) + sum(1:min(m, n)) do k + invflop = 1 + scaleflop = isempty((k + 1):m) ? 0 : sum((k + 1):m) + updateflop = isempty((k + 1):n) ? 0 : + sum((k + 1):n) do j + isempty((k + 1):m) ? 
0 : sum((k + 1):m) do i + innerflop + end + end + invflop + scaleflop + updateflop + end +end + +algs = [LUFactorization(), GenericLUFactorization(), RFLUFactorization(), AppleAccelerateLUFactorization(), MetalLUFactorization()] +res = [Float32[] for i in 1:length(algs)] + +ns = 4:8:500 +for i in 1:length(ns) + n = ns[i] + @info "$n × $n" + rng = MersenneTwister(123) + global A = rand(rng, Float32, n, n) + global b = rand(rng, Float32, n) + global u0= rand(rng, Float32, n) + + for j in 1:length(algs) + bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A), copy(b); u0 = copy(u0), alias_A=true, alias_b=true)) + push!(res[j], luflop(n) / bt / 1e9) + end +end + +using Plots +__parameterless_type(T) = Base.typename(T).wrapper +parameterless_type(x) = __parameterless_type(typeof(x)) +parameterless_type(::Type{T}) where {T} = __parameterless_type(T) + +p = plot(ns, res[1]; ylabel = "GFLOPs", xlabel = "N", title = "GFLOPs for NxN LU Factorization", label = string(Symbol(parameterless_type(algs[1]))), legend=:outertopright) +for i in 2:length(res) + plot!(p, ns, res[i]; label = string(Symbol(parameterless_type(algs[i])))) +end +p + +savefig("metallubench.png") +savefig("metallubench.pdf") \ No newline at end of file diff --git a/benchmarks/cudalu.jl b/benchmarks/cudalu.jl new file mode 100644 index 000000000..28921edf1 --- /dev/null +++ b/benchmarks/cudalu.jl @@ -0,0 +1,51 @@ +using BenchmarkTools, Random, VectorizationBase +using LinearAlgebra, LinearSolve, CUDA, MKL_jll +nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads()) +BLAS.set_num_threads(nc) +BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5 + +function luflop(m, n = m; innerflop = 2) + sum(1:min(m, n)) do k + invflop = 1 + scaleflop = isempty((k + 1):m) ? 0 : sum((k + 1):m) + updateflop = isempty((k + 1):n) ? 0 : + sum((k + 1):n) do j + isempty((k + 1):m) ? 
0 : sum((k + 1):m) do i + innerflop + end + end + invflop + scaleflop + updateflop + end +end + +algs = [MKLLUFactorization(), CUDAOffloadFactorization()] +res = [Float32[] for i in 1:length(algs)] + +ns = 200:400:20000 +for i in 1:length(ns) + n = ns[i] + @info "$n × $n" + rng = MersenneTwister(123) + global A = rand(rng, Float32, n, n) + global b = rand(rng, Float32, n) + global u0= rand(rng, Float32, n) + + for j in 1:length(algs) + bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A), copy(b); u0 = copy(u0), alias_A=true, alias_b=true)) + push!(res[j], luflop(n) / bt / 1e9) + end +end + +using Plots +__parameterless_type(T) = Base.typename(T).wrapper +parameterless_type(x) = __parameterless_type(typeof(x)) +parameterless_type(::Type{T}) where {T} = __parameterless_type(T) + +p = plot(ns, res[1]; ylabel = "GFLOPs", xlabel = "N", title = "GFLOPs for NxN LU Factorization", label = string(Symbol(parameterless_type(algs[1]))), legend=:outertopright) +for i in 2:length(res) + plot!(p, ns, res[i]; label = string(Symbol(parameterless_type(algs[i])))) +end +p + +savefig("cudaoffloadlubench.png") +savefig("cudaoffloadlubench.pdf") \ No newline at end of file diff --git a/benchmarks/metallu.jl b/benchmarks/metallu.jl new file mode 100644 index 000000000..a49d2c036 --- /dev/null +++ b/benchmarks/metallu.jl @@ -0,0 +1,51 @@ +using BenchmarkTools, Random, VectorizationBase +using LinearAlgebra, LinearSolve, Metal +nc = min(Int(VectorizationBase.num_cores()), Threads.nthreads()) +BLAS.set_num_threads(nc) +BenchmarkTools.DEFAULT_PARAMETERS.seconds = 0.5 + +function luflop(m, n = m; innerflop = 2) + sum(1:min(m, n)) do k + invflop = 1 + scaleflop = isempty((k + 1):m) ? 0 : sum((k + 1):m) + updateflop = isempty((k + 1):n) ? 0 : + sum((k + 1):n) do j + isempty((k + 1):m) ? 0 : sum((k + 1):m) do i + innerflop + end + end + invflop + scaleflop + updateflop + end +end + +algs = [AppleAccelerateLUFactorization(), MetalLUFactorization()] +res = [Float32[] for i in 1:length(algs)] + +ns = 200:400:20000 +for i in 1:length(ns) + n = ns[i] + @info "$n × $n" + rng = MersenneTwister(123) + global A = rand(rng, Float32, n, n) + global b = rand(rng, Float32, n) + global u0= rand(rng, Float32, n) + + for j in 1:length(algs) + bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A), copy(b); u0 = copy(u0), alias_A=true, alias_b=true)) + push!(res[j], luflop(n) / bt / 1e9) + end +end + +using Plots +__parameterless_type(T) = Base.typename(T).wrapper +parameterless_type(x) = __parameterless_type(typeof(x)) +parameterless_type(::Type{T}) where {T} = __parameterless_type(T) + +p = plot(ns, res[1]; ylabel = "GFLOPs", xlabel = "N", title = "GFLOPs for NxN LU Factorization", label = string(Symbol(parameterless_type(algs[1]))), legend=:outertopright) +for i in 2:length(res) + plot!(p, ns, res[i]; label = string(Symbol(parameterless_type(algs[i])))) +end +p + +savefig("metal_large_lubench.png") +savefig("metal_large_lubench.pdf") \ No newline at end of file diff --git a/docs/src/solvers/solvers.md b/docs/src/solvers/solvers.md index 93cc35df0..f7e52c2a7 100644 --- a/docs/src/solvers/solvers.md +++ b/docs/src/solvers/solvers.md @@ -7,15 +7,37 @@ Solves for ``Au=b`` in the problem defined by `prob` using the algorithm ## Recommended Methods +### Dense Matrices + The default algorithm `nothing` is good for picking an algorithm that will work, but one may need to change this to receive more performance or precision. 
 If more precision is necessary, `QRFactorization()` and `SVDFactorization()` are the
 best choices, with SVD being the slowest but most precise.
 
-For efficiency, `RFLUFactorization` is the fastest for dense LU-factorizations.
-`FastLUFactorization` will be faster than `LUFactorization` which is the Base.LinearAlgebra
-(`\` default) implementation of LU factorization. `SimpleLUFactorization` will be fast
-on very small matrices.
+For efficiency, `RFLUFactorization` is the fastest for dense LU-factorizations until around
+150x150 matrices, though this can depend on the exact details of the hardware. After this
+point, `MKLLUFactorization` is usually faster on most hardware. Note that on Mac computers,
+`AppleAccelerateLUFactorization` is generally the fastest. `LUFactorization` will use your
+base system BLAS, which can be fast or slow depending on the hardware configuration.
+`SimpleLUFactorization` will be fast only on very small matrices, but can cut down on compile times.
+
+For very large dense factorizations, offloading to the GPU can be preferable. Metal.jl can be used
+on Mac hardware to offload, and has a cutoff point of being faster at around size 20,000 x 20,000
+matrices (and only supports Float32). `CudaOffloadFactorization` can be more efficient at a
+much smaller cutoff, possibly around size 1,000 x 1,000 matrices, though this is highly dependent
+on the chosen GPU hardware. `CudaOffloadFactorization` requires a CUDA-compatible NVIDIA GPU.
+CUDA offload supports Float64, but most consumer GPU hardware is much faster on Float32
+(often >32x faster than on Float64), and thus for most hardware this is only recommended
+for Float32 matrices.
+
+!!! note
+
+    Performance details for dense LU-factorizations can be highly dependent on the hardware configuration.
+    For details see [this issue](https://github.com/SciML/LinearSolve.jl/issues/357).
+    If one is looking to best optimize their system, we suggest running the performance
+    tuning benchmark.
+
+### Sparse Matrices
 
 For sparse LU-factorizations, `KLUFactorization` if there is less structure
 to the sparsity pattern and `UMFPACKFactorization` if there is more structure.
@@ -31,12 +53,25 @@ As sparse matrices get larger, iterative solvers tend to get more
 efficient than factorization methods if a lower tolerance of the solution is required.
 
 Krylov.jl generally outperforms IterativeSolvers.jl and KrylovKit.jl, and is compatible
-with CPUs and GPUs, and thus is the generally preferred form for Krylov methods.
+with CPUs and GPUs, and thus is the generally preferred form for Krylov methods. The
+Krylov method should be chosen to match the strongest known property of the operator:
+for example, use `KrylovJL_CG()` if the operator is symmetric positive definite, and
+`KrylovJL_GMRES()` if nothing special is known about it.
 
 Finally, a user can pass a custom function for handling the linear solve using
 `LinearSolveFunction()` if existing solvers are not optimally suited for their application.
 The interface is detailed [here](@ref custom).
 
+### Lazy SciMLOperators
+
+If the linear operator is given as a lazy non-concrete operator, such as a `FunctionOperator`,
+then using a Krylov method is preferred in order to not concretize the matrix.
+Krylov.jl generally outperforms IterativeSolvers.jl and KrylovKit.jl, and is compatible
+with CPUs and GPUs, and thus is the generally preferred form for Krylov methods. The
+Krylov method should be chosen to match the strongest known property of the operator:
+for example, use `KrylovJL_CG()` if the operator is symmetric positive definite, and
+`KrylovJL_GMRES()` if nothing special is known about it.
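+
+For example, a Krylov method can solve a system whose operator is defined only by its
+action. The following is a minimal sketch: the `FunctionOperator` constructor and the
+in-place five-argument function signature shown here follow SciMLOperators.jl and should
+be checked against its documentation:
+
+```julia
+using LinearSolve, SciMLOperators
+
+# Stand-in for an expensive matrix-free action w = A*v (here simply A = 2I)
+Afunc!(w, v, u, p, t) = (w .= 2.0 .* v)
+
+op = FunctionOperator(Afunc!, zeros(4), zeros(4))
+prob = LinearProblem(op, ones(4))
+sol = solve(prob, KrylovJL_GMRES())  # the matrix A is never formed
+```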
+
 ## Full List of Methods
 
 ### RecursiveFactorization.jl
@@ -121,6 +156,26 @@ KrylovJL
 MKLLUFactorization
 ```
 
+### AppleAccelerate.jl
+
+!!! note
+
+    Using this solver requires a Mac with Apple Accelerate. This should come standard in most "modern" Mac computers.
+
+```@docs
+AppleAccelerateLUFactorization
+```
+
+### Metal.jl
+
+!!! note
+
+    Using this solver requires adding the package Metal.jl, i.e. `using Metal`. This package is only compatible with Mac M-Series computers with a Metal-compatible GPU.
+
+```@docs
+MetalLUFactorization
+```
+
 ### Pardiso.jl
 
 !!! note

From 342307e4cc0191079491d795736d35167e71876c Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Wed, 9 Aug 2023 12:37:21 -0400
Subject: [PATCH 04/10] don't run out of memory

---
 benchmarks/metallu.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/benchmarks/metallu.jl b/benchmarks/metallu.jl
index a49d2c036..a9614e7a4 100644
--- a/benchmarks/metallu.jl
+++ b/benchmarks/metallu.jl
@@ -21,7 +21,7 @@ end
 algs = [AppleAccelerateLUFactorization(), MetalLUFactorization()]
 res = [Float32[] for i in 1:length(algs)]
 
-ns = 200:400:20000
+ns = 200:600:15000
 for i in 1:length(ns)
     n = ns[i]
     @info "$n × $n"
@@ -32,6 +32,7 @@ for i in 1:length(ns)
 
     for j in 1:length(algs)
         bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy(A), copy(b); u0 = copy(u0), alias_A=true, alias_b=true))
+        GC.gc()
         push!(res[j], luflop(n) / bt / 1e9)
     end
 end

From 78158af31fcc28a651a4088ef5ba8bfc35563286 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Wed, 9 Aug 2023 12:54:38 -0400
Subject: [PATCH 05/10] use BlasInt everywhere with MKL

---
 ext/LinearSolveMKLExt.jl | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/ext/LinearSolveMKLExt.jl b/ext/LinearSolveMKLExt.jl
index a505c21aa..c209c2dcd 100644
--- a/ext/LinearSolveMKLExt.jl
+++ b/ext/LinearSolveMKLExt.jl
@@ -44,7 +44,7 @@ function getrf!(A::AbstractMatrix{<:Float32}; ipiv = similar(A, BlasInt, min(siz
     A, ipiv, info[], info #Error code is stored in LU factorization type
 end
 
-function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float64}; info = Ref{Cint}())
+function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::AbstractVector{BlasInt}, B::AbstractVecOrMat{<:Float64}; info = Ref{BlasInt}())
     require_one_based_indexing(A, ipiv, B)
     LinearAlgebra.LAPACK.chktrans(trans)
     chkstride1(A, B, ipiv)
@@ -57,14 +57,14 @@ function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float64}, ipiv::Abstrac
     end
     nrhs = size(B, 2)
     ccall(("dgetrs_", MKL_jll.libmkl_rt), Cvoid,
-          (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float64}, Ref{Cint},
-           Ptr{Cint}, Ptr{Float64}, Ref{Cint}, Ptr{Cint}, Clong),
+          (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float64}, Ref{BlasInt},
+           Ptr{BlasInt}, Ptr{Float64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
          trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1)
     LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[]))
     B
 end
 
-function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float32}, ipiv::AbstractVector{Cint}, B::AbstractVecOrMat{<:Float32}; info = Ref{Cint}())
+function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float32}, ipiv::AbstractVector{BlasInt}, B::AbstractVecOrMat{<:Float32}; info = Ref{BlasInt}())
require_one_based_indexing(A, ipiv, B) LinearAlgebra.LAPACK.chktrans(trans) chkstride1(A, B, ipiv) @@ -77,8 +77,8 @@ function getrs!(trans::AbstractChar, A::AbstractMatrix{<:Float32}, ipiv::Abstrac end nrhs = size(B, 2) ccall(("sgetrs_", MKL_jll.libmkl_rt), Cvoid, - (Ref{UInt8}, Ref{Cint}, Ref{Cint}, Ptr{Float32}, Ref{Cint}, - Ptr{Cint}, Ptr{Float32}, Ref{Cint}, Ptr{Cint}, Clong), + (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{Float32}, Ref{BlasInt}, + Ptr{BlasInt}, Ptr{Float32}, Ref{BlasInt}, Ptr{BlasInt}, Clong), trans, n, size(B,2), A, max(1,stride(A,2)), ipiv, B, max(1,stride(B,2)), info, 1) LinearAlgebra.LAPACK.chklapackerror(BlasInt(info[])) B From ffc057476117696e15c050b1d234b2ecfd4b9252 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Aug 2023 13:26:02 -0400 Subject: [PATCH 06/10] Guess it's transposed? --- benchmarks/cudalu.jl | 2 +- ext/LinearSolveMKLExt.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmarks/cudalu.jl b/benchmarks/cudalu.jl index 28921edf1..e7b38856b 100644 --- a/benchmarks/cudalu.jl +++ b/benchmarks/cudalu.jl @@ -21,7 +21,7 @@ end algs = [MKLLUFactorization(), CUDAOffloadFactorization()] res = [Float32[] for i in 1:length(algs)] -ns = 200:400:20000 +ns = 200:400:10000 for i in 1:length(ns) n = ns[i] @info "$n × $n" diff --git a/ext/LinearSolveMKLExt.jl b/ext/LinearSolveMKLExt.jl index c209c2dcd..d2b5b6d9d 100644 --- a/ext/LinearSolveMKLExt.jl +++ b/ext/LinearSolveMKLExt.jl @@ -110,11 +110,11 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; m, n = size(A, 1), size(A, 2) if m > n Bc = copy(cache.b) - getrs!('N', A.factors, A.ipiv, Bc; info) + getrs!('T', A.factors, A.ipiv, Bc; info) return copyto!(cache.u, 1, Bc, 1, n) else copyto!(cache.u, cache.b) - getrs!('N', A.factors, A.ipiv, cache.u; info) + getrs!('T', A.factors, A.ipiv, cache.u; info) end SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) From 30c8fcd675ee4309b2bad9b73dec28f37335010e Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Aug 2023 13:54:20 -0400 Subject: [PATCH 07/10] revert T --- ext/LinearSolveMKLExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/LinearSolveMKLExt.jl b/ext/LinearSolveMKLExt.jl index d2b5b6d9d..c209c2dcd 100644 --- a/ext/LinearSolveMKLExt.jl +++ b/ext/LinearSolveMKLExt.jl @@ -110,11 +110,11 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; m, n = size(A, 1), size(A, 2) if m > n Bc = copy(cache.b) - getrs!('T', A.factors, A.ipiv, Bc; info) + getrs!('N', A.factors, A.ipiv, Bc; info) return copyto!(cache.u, 1, Bc, 1, n) else copyto!(cache.u, cache.b) - getrs!('T', A.factors, A.ipiv, cache.u; info) + getrs!('N', A.factors, A.ipiv, cache.u; info) end SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) From 789ae07513073e5e2cf8706796e59eea0e77e0f2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Aug 2023 13:56:23 -0400 Subject: [PATCH 08/10] remove MKL getrs! 
for now --- ext/LinearSolveMKLExt.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ext/LinearSolveMKLExt.jl b/ext/LinearSolveMKLExt.jl index c209c2dcd..1ed917f58 100644 --- a/ext/LinearSolveMKLExt.jl +++ b/ext/LinearSolveMKLExt.jl @@ -104,7 +104,11 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; cache.cacheval = fact cache.isfresh = false end - + + y = ldiv!(cache.u, @get_cacheval(cache, :MKLLUFactorization)[1], cache.b) + SciMLBase.build_linear_solution(alg, y, nothing, cache) + + #= A, info = @get_cacheval(cache, :MKLLUFactorization) LinearAlgebra.require_one_based_indexing(cache.u, cache.b) m, n = size(A, 1), size(A, 2) @@ -118,6 +122,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::MKLLUFactorization; end SciMLBase.build_linear_solution(alg, cache.u, nothing, cache) + =# end end \ No newline at end of file From 4387dcc4b0d668fcae93f0540fd21f86a623165b Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Aug 2023 14:04:21 -0400 Subject: [PATCH 09/10] don't test resolve on Metal --- test/resolve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/resolve.jl b/test/resolve.jl index d622a8865..d49d5b1b6 100644 --- a/test/resolve.jl +++ b/test/resolve.jl @@ -2,7 +2,7 @@ using LinearSolve, LinearAlgebra, SparseArrays, InteractiveUtils, Test for alg in subtypes(LinearSolve.AbstractFactorization) @show alg - if !(alg in [DiagonalFactorization, CudaOffloadFactorization, AppleAccelerateLUFactorization]) && + if !(alg in [DiagonalFactorization, CudaOffloadFactorization, AppleAccelerateLUFactorization, MetalLUFactorization]) && (!(alg == AppleAccelerateLUFactorization) || LinearSolve.appleaccelerate_isavailable()) A = [1.0 2.0; 3.0 4.0] From c3a93d961636874c2263523ede5a5b17e8d07971 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 9 Aug 2023 16:56:48 -0400 Subject: [PATCH 10/10] Update benchmarks/cudalu.jl Co-authored-by: Christian Guinard --- benchmarks/cudalu.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/cudalu.jl b/benchmarks/cudalu.jl index e7b38856b..b0186f2ec 100644 --- a/benchmarks/cudalu.jl +++ b/benchmarks/cudalu.jl @@ -18,7 +18,7 @@ function luflop(m, n = m; innerflop = 2) end end -algs = [MKLLUFactorization(), CUDAOffloadFactorization()] +algs = [MKLLUFactorization(), CudaOffloadFactorization()] res = [Float32[] for i in 1:length(algs)] ns = 200:400:10000
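
A closing note on the benchmark scripts in this series: `luflop` converts each measured
solve time into a GFLOP rate. Reading the code, for an `n x n` matrix with the default
`innerflop = 2` it charges, at each elimination step `k`, one reciprocal, a scaling term,
and a rank-one update of the trailing block:

```math
\mathrm{luflop}(n) = \sum_{k=1}^{n} \left( 1 + \sum_{i=k+1}^{n} i + 2\,(n-k)^2 \right) = \Theta(n^3)
```

This grows on the same order as the classical `(2/3)n^3` LU operation count, and since
every algorithm in a plot is normalized by the same quantity, the GFLOP curves are
directly comparable across solvers.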
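Patches 06 and 07 flip the `trans` argument of the new MKL `getrs!` wrapper between `'T'`
and `'N'` before patch 08 falls back to `ldiv!` entirely. The LAPACK convention is that,
given factors from `getrf!`, `trans = 'N'` solves `A*x = b` while `trans = 'T'` solves
`A'*x = b`, so `'N'` is correct when the factorization was computed from `A` itself. The
convention can be sanity-checked with the `getrs!` wrapper that ships in Julia's
`LinearAlgebra` — a minimal sketch, independent of the MKL wrapper in this series:

```julia
using LinearAlgebra

A = rand(100, 100)
b = rand(100)

F = lu(A)                 # getrf!-style LU factorization with pivoting
x = copy(b)
LinearAlgebra.LAPACK.getrs!('N', F.factors, F.ipiv, x)  # solve A*x = b from the factors

x ≈ A \ b                 # true; 'T' would instead solve A'*x = b
```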
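With the series applied, the new solver is used like any other LinearSolve.jl algorithm:
loading Metal.jl activates the `LinearSolveMetalExt` extension. A minimal usage sketch on
an Apple-silicon Mac, using Float32 since the docs above note that Metal offload only
supports Float32 (the 2000x2000 size here is arbitrary; the docs put the crossover versus
CPU factorization near 20,000 x 20,000):

```julia
using LinearSolve, Metal   # loading Metal activates LinearSolveMetalExt

A = rand(Float32, 2000, 2000)
b = rand(Float32, 2000)

prob = LinearProblem(A, b)
sol = solve(prob, MetalLUFactorization())

sol.u ≈ A \ b              # agrees up to Float32 roundoff
```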