Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce unnecessary allocations and reuse code #297

Merged
merged 25 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolve"
uuid = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
authors = ["SciML"]
version = "2.9.0"
version = "2.10.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -16,6 +16,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Expand All @@ -25,7 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
Expand All @@ -42,8 +43,8 @@ NonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2"
ArrayInterface = "6.0.24, 7"
Aqua = "0.8"
ArrayInterface = "6.0.24, 7"
BandedMatrices = "1"
BenchmarkTools = "1"
ConcreteStructs = "0.2"
Expand All @@ -59,6 +60,7 @@ LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearAlgebra = "<0.0.1, 1"
LinearSolve = "2.12"
MaybeInplace = "0.1"
NaNMath = "1"
NonlinearProblemLibrary = "0.1"
Pkg = "1"
Expand All @@ -72,9 +74,8 @@ SciMLBase = "2.9"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.12"
SparseDiffTools = "2.14"
StaticArrays = "1"
StaticArraysCore = "1.4"
Symbolics = "5"
Test = "1"
UnPack = "1.0"
Expand Down
80 changes: 64 additions & 16 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,26 @@
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload

@recompile_invalidations begin
using DiffEqBase,
LazyArrays, LinearAlgebra, LinearSolve, Printf, SparseArrays,
SparseDiffTools
using FastBroadcast: @..
import ArrayInterface: restructure
using ADTypes, DiffEqBase, LazyArrays, LineSearches, LinearAlgebra, LinearSolve, Printf,
SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools, StaticArrays

import ADTypes: AbstractFiniteDifferencesMode
import ArrayInterface: undefmatrix,
import ArrayInterface: undefmatrix, restructure, can_setindex,
matrix_colors, parameterless_type, ismutable, issingular, fast_scalar_indexing
import ConcreteStructs: @concrete
import EnumX: @enumx
import FastBroadcast: @..
import FiniteDiff
import ForwardDiff
import ForwardDiff: Dual
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
import MaybeInplace: setindex_trait, @bb, CanSetindex, CannotSetindex
import RecursiveArrayTools: ArrayPartition,
AbstractVectorOfArray, recursivecopy!, recursivefill!
import SciMLBase: AbstractNonlinearAlgorithm, NLStats, _unwrap_val, has_jac, isinplace
import SciMLOperators: FunctionOperator
import StaticArraysCore: StaticArray, SVector, SArray, MArray
import StaticArrays: StaticArray, SVector, SArray, MArray, Size, SMatrix, MMatrix
import UnPack: @unpack

using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
end

@reexport using ADTypes, LineSearches, SciMLBase, SimpleNonlinearSolve
Expand All @@ -52,16 +50,65 @@

isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip

function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(cache);
p = cache.p, abstol = cache.abstol, reltol = cache.reltol,
maxiters = cache.maxiters, alias_u0 = false, termination_condition = missing,
kwargs...) where {iip}
cache.p = p
if iip
recursivecopy!(get_u(cache), u0)
cache.f(get_fu(cache), get_u(cache), p)
else
cache.u = __maybe_unaliased(u0, alias_u0)
set_fu!(cache, cache.f(cache.u, p))
end

reset!(cache.trace)

# Some algorithms store multiple termination caches
if hasfield(typeof(cache), :tc_cache)
# TODO: We need an efficient way to reset this upstream
tc = termination_condition === missing ? get_termination_mode(cache.tc_cache) :
termination_condition
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, get_fu(cache),
get_u(cache), tc)
cache.tc_cache = tc_cache
end

if hasfield(typeof(cache), :ls_cache)
# TODO: A more efficient way to do this
cache.ls_cache = init_linesearch_cache(cache.alg.linesearch, cache.f,
get_u(cache), p, get_fu(cache), Val(iip))
end

hasfield(typeof(cache), :uf) && (cache.uf.p = p)

cache.abstol = abstol
cache.reltol = reltol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default

__reinit_internal!(cache; u0, p, abstol, reltol, maxiters, alias_u0,
termination_condition, kwargs...)

return cache
end

__reinit_internal!(::AbstractNonlinearSolveCache; kwargs...) = nothing

function Base.show(io::IO, alg::AbstractNonlinearSolveAlgorithm)
str = "$(nameof(typeof(alg)))("
modifiers = String[]
if _getproperty(alg, Val(:ad)) !== nothing
if __getproperty(alg, Val(:ad)) !== nothing

Check warning on line 105 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L105

Added line #L105 was not covered by tests
push!(modifiers, "ad = $(nameof(typeof(alg.ad)))()")
end
if _getproperty(alg, Val(:linsolve)) !== nothing
if __getproperty(alg, Val(:linsolve)) !== nothing

Check warning on line 108 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L108

Added line #L108 was not covered by tests
push!(modifiers, "linsolve = $(nameof(typeof(alg.linsolve)))()")
end
if _getproperty(alg, Val(:linesearch)) !== nothing
if __getproperty(alg, Val(:linesearch)) !== nothing

Check warning on line 111 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L111

Added line #L111 was not covered by tests
ls = alg.linesearch
if ls isa LineSearch
ls.method !== nothing &&
Expand All @@ -70,7 +117,7 @@
push!(modifiers, "linesearch = $(nameof(typeof(alg.linesearch)))()")
end
end
if _getproperty(alg, Val(:radius_update_scheme)) !== nothing
if __getproperty(alg, Val(:radius_update_scheme)) !== nothing

Check warning on line 120 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L120

Added line #L120 was not covered by tests
push!(modifiers, "radius_update_scheme = $(alg.radius_update_scheme)")
end
str = str * join(modifiers, ", ")
Expand All @@ -87,8 +134,9 @@
function not_terminated(cache::AbstractNonlinearSolveCache)
return !cache.force_stop && cache.stats.nsteps < cache.maxiters
end
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)

get_fu(cache::AbstractNonlinearSolveCache) = cache.fu
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu = fu)
get_u(cache::AbstractNonlinearSolveCache) = cache.u
SciMLBase.set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)

Expand All @@ -107,7 +155,7 @@
end
end

trace = _getproperty(cache, Val{:trace}())
trace = __getproperty(cache, Val{:trace}())
if trace !== nothing
update_trace!(trace, cache.stats.nsteps, get_u(cache), get_fu(cache), nothing,
nothing, nothing; last = Val(true))
Expand Down
136 changes: 34 additions & 102 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@ end
f
alg
u
u_prev
u_cache
du
fu
fu2
fu_cache
dfu
p
J⁻¹
J⁻¹₂
J⁻¹df
J⁻¹dfu
force_stop::Bool
resets::Int
max_resets::Int
Expand All @@ -57,144 +56,77 @@ end
trace
end

get_fu(cache::GeneralBroydenCache) = cache.fu
set_fu!(cache::GeneralBroydenCache, fu) = (cache.fu = fu)

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm::F = DEFAULT_NORM,
kwargs...) where {uType, iip, F}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
u = __maybe_unaliased(u0, alias_u0)
fu = evaluate_f(prob, u)
du = _mutable_zero(u)
@bb du = copy(u)
J⁻¹ = __init_identity_jacobian(u, fu)
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(real(eltype(u)))) :
alg.reset_tolerance
reset_check = x -> abs(x) ≤ reset_tolerance

@bb u_cache = copy(u)
@bb fu_cache = copy(fu)
@bb dfu = similar(fu)
@bb J⁻¹dfu = similar(u)

abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fu, u,
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, J⁻¹, du; uses_jac_inverse = Val(true),
kwargs...)

return GeneralBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
return GeneralBroydenCache{iip}(f, alg, u, u_cache, du, fu, fu_cache, dfu, p,
J⁻¹, J⁻¹dfu, false, 0, alg.max_resets, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), tc_cache, trace)
end

function perform_step!(cache::GeneralBroydenCache{true})
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂ = cache
T = eltype(u)

mul!(_vec(du), J⁻¹, _vec(fu))
α = perform_linesearch!(cache.ls_cache, u, du)
_axpy!(-α, du, u)
f(fu2, u, p)

update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
get_fu(cache), J⁻¹, du, α)

check_and_update!(cache, fu2, u, u_prev)
cache.stats.nf += 1

cache.force_stop && return nothing

# Update the inverse jacobian
dfu .= fu2 .- fu

if all(cache.reset_check, du) || all(cache.reset_check, dfu)
if cache.resets ≥ cache.max_resets
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
fill!(J⁻¹, 0)
J⁻¹[diagind(J⁻¹)] .= T(1)
cache.resets += 1
else
du .*= -1
mul!(_vec(J⁻¹df), J⁻¹, _vec(dfu))
mul!(J⁻¹₂, _vec(du)', J⁻¹)
denom = dot(du, J⁻¹df)
du .= (du .- J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
end
fu .= fu2
@. u_prev = u

return nothing
end

function perform_step!(cache::GeneralBroydenCache{false})
@unpack f, p = cache

function perform_step!(cache::GeneralBroydenCache{iip}) where {iip}
T = eltype(cache.u)

cache.du = _restructure(cache.du, cache.J⁻¹ * _vec(cache.fu))
@bb cache.du = cache.J⁻¹ × vec(cache.fu)
α = perform_linesearch!(cache.ls_cache, cache.u, cache.du)
cache.u = cache.u .- α * cache.du
cache.fu2 = f(cache.u, p)
@bb axpy!(-α, cache.du, cache.u)

update_trace_with_invJ!(cache.trace, cache.stats.nsteps + 1, get_u(cache),
get_fu(cache), cache.J⁻¹, cache.du, α)
evaluate_f(cache, cache.u, cache.p)

check_and_update!(cache, cache.fu2, cache.u, cache.u_prev)
cache.stats.nf += 1
update_trace!(cache, α)
check_and_update!(cache, cache.fu, cache.u, cache.u_cache)

cache.force_stop && return nothing

# Update the inverse jacobian
cache.dfu = cache.fu2 .- cache.fu
@bb @. cache.dfu = cache.fu - cache.fu_cache

if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)
if cache.resets ≥ cache.max_resets
cache.retcode = ReturnCode.ConvergenceFailure
cache.force_stop = true
return nothing
end
cache.J⁻¹ = __init_identity_jacobian(cache.u, cache.fu)
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
cache.resets += 1
else
cache.du = -cache.du
cache.J⁻¹df = _restructure(cache.J⁻¹df, cache.J⁻¹ * _vec(cache.dfu))
cache.J⁻¹₂ = _vec(cache.du)' * cache.J⁻¹
denom = dot(cache.du, cache.J⁻¹df)
cache.du = (cache.du .- cache.J⁻¹df) ./ ifelse(iszero(denom), T(1e-5), denom)
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
@bb cache.du .*= -1
@bb cache.J⁻¹dfu = cache.J⁻¹ × vec(cache.dfu)
@bb cache.u_cache = transpose(cache.J⁻¹) × vec(cache.du)
denom = dot(cache.du, cache.J⁻¹dfu)
@bb @. cache.du = (cache.du - cache.J⁻¹dfu) / ifelse(iszero(denom), T(1e-5), denom)
@bb cache.J⁻¹ += vec(cache.du) × transpose(_vec(cache.u_cache))
end
cache.fu = cache.fu2
cache.u_prev = @. cache.u

@bb copyto!(cache.fu_cache, cache.fu)
@bb copyto!(cache.u_cache, cache.u)

return nothing
end

function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, reltol = cache.reltol, maxiters = cache.maxiters,
termination_condition = get_termination_mode(cache.tc_cache)) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
cache.f(cache.fu, cache.u, p)
else
# don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
cache.u = u0
cache.fu = cache.f(cache.u, p)
end

reset!(cache.trace)
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, cache.fu, cache.u,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.tc_cache = tc_cache
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
function __reinit_internal!(cache::GeneralBroydenCache; kwargs...)
cache.J⁻¹ = __reinit_identity_jacobian!!(cache.J⁻¹)
cache.resets = 0
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
return nothing
end
Loading
Loading