Skip to content

Commit

Permalink
Merge pull request #297 from avik-pal/ap/inplace
Browse files Browse the repository at this point in the history
Reduce unnecessary allocations and reuse code
  • Loading branch information
ChrisRackauckas authored Dec 7, 2023
2 parents a39130b + 3b52a5a commit 5b14b20
Show file tree
Hide file tree
Showing 20 changed files with 1,138 additions and 1,712 deletions.
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.9.0"
version = "2.10.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -16,6 +16,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -25,7 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
Expand All @@ -42,8 +43,8 @@ NonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
Aqua = "0.8"
ArrayInterface = "6.0.24, 7"
BandedMatrices = "1"
BenchmarkTools = "1"
ConcreteStructs = "0.2"
Expand All @@ -59,6 +60,7 @@ LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearAlgebra = "<0.0.1, 1"
LinearSolve = "2.12"
MaybeInplace = "0.1"
NaNMath = "1"
NonlinearProblemLibrary = "0.1"
Pkg = "1"
Expand All @@ -72,9 +74,8 @@ SciMLBase = "2.9"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.12"
SparseDiffTools = "2.14"
StaticArrays = "1"
StaticArraysCore = "1.4"
Symbolics = "5"
Test = "1"
UnPack = "1.0"
Expand Down
80 changes: 64 additions & 16 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,26 @@ import Reexport: @reexport
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload

@recompile_invalidations begin
using DiffEqBase,
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
SparseDiffTools
using FastBroadcast: @..
import ArrayInterface: restructure
using ADTypes, DiffEqBase, LazyArrays, LineSearches, LinearAlgebra, LinearSolve, Printf,
SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, StaticArrays

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix,
import ArrayInterface: undefmatrix, restructure, can_setindex,
matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing
import ConcreteStructs: @concrete
import EnumX: @enumx
import FastBroadcast: @..
import FiniteDiff
import ForwardDiff
import ForwardDiff: Dual
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
import MaybeInplace: setindex_trait, @bb, CanSetindex, CannotSetindex
import RecursiveArrayTools: ArrayPartition,
AbstractVectorOfArray, recursivecopy!, recursivefill!
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
import SciMLOperators: FunctionOperator
import StaticArraysCore: StaticArray, SVector, SArray, MArray
import StaticArrays: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
import UnPack: @unpack

using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
end

@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
Expand All @@ -52,16 +50,65 @@ abstract type AbstractNonlinearSolveCache{iip} end

# The in-place flag of a solver cache is carried as its `iip` type parameter;
# return it directly.
function isinplace(::AbstractNonlinearSolveCache{iip}) where {iip}
    return iip
end

# Generic `reinit!` for any `AbstractNonlinearSolveCache`: reset the cache so a new
# solve can start from `u0` without rebuilding the cache.
#
# Arguments
#   cache  - the solver cache to reset (mutated and returned).
#   u0     - new initial guess; defaults to the cache's current iterate.
#
# Keywords
#   p, abstol, reltol, maxiters - default to the values currently stored in `cache`.
#   alias_u0                    - only consulted in the out-of-place branch, where it
#                                 is forwarded to `__maybe_unaliased`.
#   termination_condition       - `missing` means "reuse the mode of the existing
#                                 `tc_cache`" (only relevant if the cache has one).
#   kwargs...                   - forwarded untouched to `__reinit_internal!`.
function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(cache);
p = cache.p, abstol = cache.abstol, reltol = cache.reltol,
maxiters = cache.maxiters, alias_u0 = false, termination_condition = missing,
kwargs...) where {iip}
cache.p = p
if iip
# In-place: copy `u0` into the existing iterate and re-evaluate the residual into
# the existing residual buffer.
recursivecopy!(get_u(cache), u0)
cache.f(get_fu(cache), get_u(cache), p)
else
# Out-of-place: rebind the iterate and store the freshly computed residual.
cache.u = __maybe_unaliased(u0, alias_u0)
set_fu!(cache, cache.f(cache.u, p))
end

reset!(cache.trace)

# Some algorithms store multiple termination caches
if hasfield(typeof(cache), :tc_cache)
# TODO: We need an efficient way to reset this upstream
tc = termination_condition === missing ? get_termination_mode(cache.tc_cache) :
termination_condition
# `abstol`/`reltol` are rebound here to whatever `init_termination_cache` returns;
# those rebound values are what get stored on the cache and forwarded below.
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, get_fu(cache),
get_u(cache), tc)
cache.tc_cache = tc_cache
end

if hasfield(typeof(cache), :ls_cache)
# TODO: A more efficient way to do this
cache.ls_cache = init_linesearch_cache(cache.alg.linesearch, cache.f,
get_u(cache), p, get_fu(cache), Val(iip))
end

# Keep the parameters held by `cache.uf` (when the cache has such a field) in sync.
hasfield(typeof(cache), :uf) && (cache.uf.p = p)

cache.abstol = abstol
cache.reltol = reltol
cache.maxiters = maxiters
# One residual evaluation was performed above.
cache.stats.nf = 1
# NOTE(review): `nsteps` restarts at 1 here; confirm this matches the value a freshly
# constructed cache starts with.
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default

# Algorithm-specific hook; the fallback method for the abstract cache type does nothing.
__reinit_internal!(cache; u0, p, abstol, reltol, maxiters, alias_u0,
termination_condition, kwargs...)

return cache
end

# Fallback for the hook invoked at the end of the generic `reinit!`: caches with no
# algorithm-specific state to reset accept (and ignore) any keywords and do nothing.
function __reinit_internal!(::AbstractNonlinearSolveCache; kwargs...)
    return nothing
end

function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
str = "$(nameof(typeof(alg)))("
modifiers = String[]
if _getproperty(alg, Val(:ad)) !== nothing
if __getproperty(alg, Val(:ad)) !== nothing
push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()")
end
if _getproperty(alg, Val(:linsolve)) !== nothing
if __getproperty(alg, Val(:linsolve)) !== nothing
push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()")
end
if _getproperty(alg, Val(:linesearch)) !== nothing
if __getproperty(alg, Val(:linesearch)) !== nothing
ls = alg.linesearch
if ls isa LineSearch
ls.method !== nothing &&
Expand All @@ -70,7 +117,7 @@ function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()")
end
end
if _getproperty(alg, Val(:radius_update_scheme)) !== nothing
if __getproperty(alg, Val(:radius_update_scheme)) !== nothing
push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)")
end
str = str * join(modifiers, ", ")
Expand All @@ -87,8 +134,9 @@ end
# Return `true` while the solve should keep iterating: no forced stop has been
# requested and the step counter is still below `maxiters`.
function not_terminated(cache::AbstractNonlinearSolveCache)
    cache.force_stop && return false
    return cache.stats.nsteps < cache.maxiters
end
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)

get_fu(cache::AbstractNonlinearSolveCache) = cache.fu
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu)
get_u(cache::AbstractNonlinearSolveCache) = cache.u
SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)

Expand All @@ -107,7 +155,7 @@ function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
end
end

trace = _getproperty(cache, Val{:trace}())
trace = __getproperty(cache, Val{:trace}())
if trace !== nothing
update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
nothing, nothing; last = Val(true))
Expand Down
136 changes: 34 additions & 102 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@ end
f
alg
u
u_prev
u_cache
du
fu
fu2
fu_cache
dfu
p
J⁻¹
J⁻¹₂
J⁻¹df
J⁻¹dfu
force_stop::Bool
resets::Int
max_resets::Int
Expand All @@ -57,144 +56,77 @@ end
trace
end

get_fu(cache::GeneralBroydenCache) = cache.fu
set_fu!(cache::GeneralBroydenCache, fu) = (cache.fu = fu)

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
kwargs...) where {uType, iip, F}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
u = __maybe_unaliased(u0, alias_u0)
fu = evaluate_f(prob, u)
du = _mutable_zero(u)
@bb du = copy(u)
J⁻¹ = __init_identity_jacobian(u, fu)
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
alg.reset_tolerance
reset_check = x -> abs(x) <= reset_tolerance

@bb u_cache = copy(u)
@bb fu_cache = copy(fu)
@bb dfu = similar(fu)
@bb J⁻¹dfu = similar(u)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
kwargs...)

return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
return GeneralBroydenCache{iip}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p,
J⁻¹, J⁻¹dfu, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
end

# One in-place Broyden iteration for `GeneralBroydenCache{true}`.
#
# Steps performed, all mutating the buffers held by `cache`:
#   1. `du = J⁻¹ * fu`, line search for `α`, then `u -= α * du`.
#   2. Evaluate the residual at the new `u` into `fu2` and record the trace.
#   3. Run the termination check; stop early if `cache.force_stop` is set afterwards.
#   4. Either reset `J⁻¹` to the identity (when every entry of `du` or of `dfu`
#      satisfies `cache.reset_check`) or apply a rank-one update to `J⁻¹`.
#   5. Roll `fu2` into `fu` and `u` into `u_prev` for the next iteration.
#
# Returns `nothing`.
function perform_step!(cache::GeneralBroydenCache{true})
    @unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache
    T = eltype(u)

    # Search direction from the current inverse-Jacobian estimate, then take the step.
    mul!(_vec(du), J⁻¹, _vec(fu))
    α = perform_linesearch!(cache.ls_cache, u, du)
    _axpy!(-α, du, u)
    f(fu2, u, p)

    update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
        get_fu(cache), J⁻¹, du, α)

    # Presumably sets `cache.force_stop` / `cache.retcode` on termination — the flag is
    # checked immediately below.
    check_and_update!(cache, fu2, u, u_prev)
    cache.stats.nf += 1

    cache.force_stop && return nothing

    # Update the inverse jacobian
    dfu .= fu2 .- fu

    if all(cache.reset_check, du) || all(cache.reset_check, dfu)
        # Out of reset budget: give up instead of resetting again.
        if cache.resets >= cache.max_resets
            cache.retcode = ReturnCode.ConvergenceFailure
            cache.force_stop = true
            return nothing
        end
        # Reset the inverse-Jacobian estimate to the identity.
        fill!(J⁻¹, 0)
        J⁻¹[diagind(J⁻¹)] .= T(1)
        cache.resets += 1
    else
        # Rank-one update: J⁻¹ += ((du - J⁻¹ * dfu) / (du ⋅ (J⁻¹ * dfu))) * (du' * J⁻¹),
        # where `du` has first been negated. `du` is reused as scratch for the
        # left factor.
        du .*= -1
        mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu))
        mul!(J⁻¹₂, _vec(du)', J⁻¹)
        denom = dot(du, J⁻¹df)
        # Guard against division by an exactly zero denominator.
        du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
        mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
    end
    fu .= fu2
    @. u_prev = u

    return nothing
end

function perform_step!(cache::GeneralBroydenCache{false})
@unpack f, p = cache

function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
T = eltype(cache.u)

cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu))
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
cache.u = cache.u .- α * cache.du
cache.fu2 = f(cache.u, p)
@bb axpy!(-α, cache.du, cache.u)

update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
get_fu(cache), cache.J⁻¹, cache.du, α)
evaluate_f(cache, cache.u, cache.p)

check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
cache.stats.nf += 1
update_trace!(cache, α)
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)

cache.force_stop && return nothing

# Update the inverse jacobian
cache.dfu = cache.fu2 .- cache.fu
@bb @. cache.dfu = cache.fu - cache.fu_cache

if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
if cache.resets >= cache.max_resets
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
cache.resets += 1
else
cache.du = -cache.du
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹
denom = dot(cache.du, cache.J⁻¹df)
cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
@bb cache.du .*= -1
@bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu)
@bb cache.u_cache = transpose(cache.J⁻¹) × vec(cache.du)
denom = dot(cache.du, cache.J⁻¹dfu)
@bb @. cache.du = (cache.du - cache.J⁻¹dfu) / ifelse(iszero(denom), T(1e-5), denom)
@bb cache.J⁻¹ += vec(cache.du) × transpose(_vec(cache.u_cache))
end
cache.fu = cache.fu2
cache.u_prev = @. cache.u

@bb copyto!(cache.fu_cache, cache.fu)
@bb copyto!(cache.u_cache, cache.u)

return nothing
end

function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)
end

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
function __reinit_internal!(cache::GeneralBroydenCache; kwargs...)
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
cache.resets = 0
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
return nothing
end
Loading

0 comments on commit 5b14b20

Please sign in to comment.