Skip to content

Commit

Permalink
Merge pull request #268 from avik-pal/ap/gn_linesearch
Browse files Browse the repository at this point in the history
Gauss Newton with Line Search
  • Loading branch information
ChrisRackauckas authored Nov 5, 2023
2 parents b8d43a3 + 77be3a2 commit 4f1676c
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Reexport = "0.2, 1"
SciMLBase = "2.4"
SimpleNonlinearSolve = "0.1.23"
SparseArrays = "1.9"
SparseDiffTools = "2.9"
SparseDiffTools = "2.11"
StaticArraysCore = "1.4"
UnPack = "1.0"
Zygote = "0.6"
Expand Down
23 changes: 16 additions & 7 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
precs = DEFAULT_PRECS, adkwargs...)
An advanced GaussNewton implementation with support for efficient handling of sparse
Expand Down Expand Up @@ -30,6 +30,9 @@ for large-scale and numerically-difficult nonlinear least squares problems.
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
!!! warning
Expand All @@ -40,16 +43,18 @@ for large-scale and numerically-difficult nonlinear least squares problems.
ad::AD
linsolve
precs
linesearch
end

"""
    set_ad(alg::GaussNewton{CJ}, ad) where {CJ}

Return a new `GaussNewton` algorithm identical to `alg` but with its AD backend
replaced by `ad`. All other settings (`linsolve`, `precs`, `linesearch`) are
carried over unchanged.
"""
function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
    # NOTE: the diff span contained both the pre- and post-change return lines,
    # leaving an unreachable duplicate `return`; only the updated call — which
    # forwards `alg.linesearch` — is kept.
    return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
end

"""
    GaussNewton(; concrete_jac = nothing, linsolve = nothing,
        linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)

Keyword constructor for the `GaussNewton` algorithm.

# Keywords
  - `concrete_jac`: whether to build a concrete Jacobian (forwarded as a type
    parameter via `_unwrap_val`).
  - `linsolve`: the LinearSolve.jl solver to use, or `nothing` for the default.
  - `linesearch`: a `LineSearch`, or any line-search algorithm (e.g. from
    `LineSearches.jl`), which is wrapped into a `LineSearch` automatically.
  - `precs`: preconditioner constructor, defaults to `DEFAULT_PRECS`.
  - `adkwargs...`: remaining keywords resolved into an AD backend by
    `default_adargs_to_adtype`.
"""
function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
        linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
    ad = default_adargs_to_adtype(; adkwargs...)
    # Normalize: allow raw LineSearches.jl algorithms by wrapping them in a
    # LineSearch so downstream code can assume a uniform type.
    linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
Expand Down Expand Up @@ -78,6 +83,7 @@ end
stats::NLStats
tc_cache_1
tc_cache_2
ls_cache
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
Expand Down Expand Up @@ -107,7 +113,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2)
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
end

function perform_step!(cache::GaussNewtonCache{true})
Expand All @@ -128,7 +135,8 @@ function perform_step!(cache::GaussNewtonCache{true})
linu = _vec(du), p, reltol = cache.abstol)
end
cache.linsolve = linres.cache
@. u = u - du
α = perform_linesearch!(cache.ls_cache, u, du)
_axpy!(-α, du, u)
f(cache.fu_new, u, p)

check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
Expand Down Expand Up @@ -169,7 +177,8 @@ function perform_step!(cache::GaussNewtonCache{false})
end
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
α = perform_linesearch!(cache.ls_cache, u, cache.du)
cache.u = @. u - α * cache.du # `u` might not support mutation
cache.fu_new = f(cache.u, p)

check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
Expand Down
2 changes: 1 addition & 1 deletion src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
end

function g!(u, fu)
op = VecJac((args...) -> f(args..., p), u; autodiff)
op = VecJac(f, u, p; fu = fu1, autodiff)
if iip
mul!(g₀, op, fu)
return g₀
Expand Down
2 changes: 1 addition & 1 deletion src/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
precs = DEFAULT_PRECS, adkwargs...)
An advanced NewtonRaphson implementation with support for efficient handling of sparse
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ function __get_concrete_algorithm(alg, prob)
use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
else
(use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(;
tag = NonlinearSolveTag())
tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0)))
end
return set_ad(alg, ad)
end
Expand Down
18 changes: 10 additions & 8 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]
solvers = [
GaussNewton(),
GaussNewton(; linsolve = LUFactorization()),
LevenbergMarquardt(),
LevenbergMarquardt(; linsolve = LUFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
]
solvers = vec(Any[GaussNewton(; linsolve, linesearch)
for linsolve in [nothing, LUFactorization()],
linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]])
append!(solvers,
[
LevenbergMarquardt(),
LevenbergMarquardt(; linsolve = LUFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
])

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
Expand Down

0 comments on commit 4f1676c

Please sign in to comment.