diff --git a/ext/LinearSolveHYPREExt.jl b/ext/LinearSolveHYPREExt.jl index 279aba75..4dd06af1 100644 --- a/ext/LinearSolveHYPREExt.jl +++ b/ext/LinearSolveHYPREExt.jl @@ -86,12 +86,13 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm, assumptions) Tc = typeof(cacheval) isfresh = true + precsisfresh = false cache = LinearCache{ typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), typeof(__issquare(assumptions)), typeof(sensealg) - }(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + }(A, b, u0, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) return cache end diff --git a/ext/LinearSolveIterativeSolversExt.jl b/ext/LinearSolveIterativeSolversExt.jl index cb458964..198cc0a5 100644 --- a/ext/LinearSolveIterativeSolversExt.jl +++ b/ext/LinearSolveIterativeSolversExt.jl @@ -90,6 +90,12 @@ function LinearSolve.init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, max end function SciMLBase.solve!(cache::LinearCache, alg::IterativeSolversJL; kwargs...) + if cache.precsisfresh && !isnothing(alg.precs) + Pl, Pr = alg.precs(cache.Pl, cache.Pr) + cache.Pl = Pl + cache.Pr = Pr + cache.precsisfresh = false + end if cache.isfresh || !(alg isa IterativeSolvers.GMRESIterable) solver = LinearSolve.init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, diff --git a/src/common.jl b/src/common.jl index f212fe25..eb1adacd 100644 --- a/src/common.jl +++ b/src/common.jl @@ -73,6 +73,7 @@ mutable struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq, S} alg::Talg cacheval::Tc # store alg cache here isfresh::Bool # false => cacheval is set wrt A, true => update cacheval wrt A + precsisfresh::Bool # false => PR,PL is set wrt A, true => update PR,PL wrt A Pl::Tl # preconditioners Pr::Tr abstol::Ttol @@ -85,18 +86,10 @@ end function Base.setproperty!(cache::LinearCache, name::Symbol, x) if name === :A - if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs) - Pl, Pr = cache.alg.precs(x, cache.p) - setfield!(cache, :Pl, Pl) - setfield!(cache, :Pr, Pr) - end setfield!(cache, :isfresh, true) + setfield!(cache, :precsisfresh, true) elseif name === :p - if hasproperty(cache.alg, :precs) && !isnothing(cache.alg.precs) - Pl, Pr = cache.alg.precs(cache.A, x) - setfield!(cache, :Pl, Pl) - setfield!(cache, :Pr, Pr) - end + setfield!(cache, :precsisfresh, true) elseif name === :b # In case there is something that needs to be done when b is updated update_cacheval!(cache, :b, x) @@ -208,11 +201,12 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose, assumptions) isfresh = true + precsisfresh = false Tc = typeof(cacheval) cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), - typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + typeof(sensealg)}(A, b, u0_, p, alg, cacheval, isfresh, precsisfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) return cache end @@ -226,24 +220,20 @@ function SciMLBase.reinit!(cache::LinearCache; reinit_cache = false,) (; alg, cacheval, abstol, reltol, maxiters, verbose, assumptions, sensealg) = cache - precs = (hasproperty(alg, :precs) && !isnothing(alg.precs)) ? alg.precs : DEFAULT_PRECS - Pl, Pr = if isnothing(A) || isnothing(p) - if isnothing(A) - A = cache.A - end - if isnothing(p) - p = cache.p - end - precs(A, p) - else - (cache.Pl, cache.Pr) - end - isfresh = true + isfresh = isnothing(A) + precsisfresh = isfresh || isnothing(p) + + A = isnothing(A) ? cache.A : A + b = isnothing(b) ? cache.b : b + u = isnothing(u) ? cache.u : u + p = isnothing(p) ? cache.p : p + Pl = cache.Pl + Pr = cache.Pr if reinit_cache return LinearCache{typeof(A), typeof(b), typeof(u), typeof(p), typeof(alg), typeof(cacheval), typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq), - typeof(sensealg)}(A, b, u, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, + typeof(sensealg)}(A, b, u, p, alg, cacheval, precsisfresh, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions, sensealg) else cache.A = A @@ -253,6 +243,7 @@ function SciMLBase.reinit!(cache::LinearCache; cache.Pl = Pl cache.Pr = Pr cache.isfresh = true + cache.isfresh = true end end diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index 16a50a27..f487463c 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -225,6 +225,12 @@ function init_cacheval(alg::KrylovJL, A, b, u, Pl, Pr, maxiters::Int, abstol, re end function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) + if cache.precsisfresh && !isnothing(alg.precs) + Pl, Pr = alg.precs(cache.A, cache.p) + cache.Pl = Pl + cache.Pr = Pr + cache.precsisfresh = false + end if cache.isfresh solver = init_cacheval(alg, cache.A, cache.b, cache.u, cache.Pl, cache.Pr, cache.maxiters, cache.abstol, cache.reltol, cache.verbose,