Skip to content

Commit

Permalink
Line Search for Gauss Newton
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 1, 2023
1 parent 042ab37 commit 13b27bf
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 16 deletions.
23 changes: 16 additions & 7 deletions src/gaussnewton.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing,
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
precs = DEFAULT_PRECS, adkwargs...)
An advanced GaussNewton implementation with support for efficient handling of sparse
Expand Down Expand Up @@ -30,6 +30,9 @@ for large-scale and numerically-difficult nonlinear least squares problems.
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
!!! warning
Expand All @@ -40,16 +43,18 @@ for large-scale and numerically-difficult nonlinear least squares problems.
ad::AD
linsolve
precs
linesearch
end

function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)

Check warning on line 50 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L50

Added line #L50 was not covered by tests
end

function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)

Check warning on line 57 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L56-L57

Added lines #L56 - L57 were not covered by tests
end

@concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
Expand Down Expand Up @@ -78,6 +83,7 @@ end
stats::NLStats
tc_cache_1
tc_cache_2
ls_cache
end

function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
Expand Down Expand Up @@ -107,7 +113,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::

return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm, ReturnCode.Default,
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2)
abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
end

function perform_step!(cache::GaussNewtonCache{true})
Expand All @@ -128,7 +135,8 @@ function perform_step!(cache::GaussNewtonCache{true})
linu = _vec(du), p, reltol = cache.abstol)
end
cache.linsolve = linres.cache
@. u = u - du
α = perform_linesearch!(cache.ls_cache, u, du)
_axpy!(-α, du, u)

Check warning on line 139 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L138-L139

Added lines #L138 - L139 were not covered by tests
f(cache.fu_new, u, p)

check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
Expand Down Expand Up @@ -169,7 +177,8 @@ function perform_step!(cache::GaussNewtonCache{false})
end
cache.linsolve = linres.cache
end
cache.u = @. u - cache.du # `u` might not support mutation
α = perform_linesearch!(cache.ls_cache, u, cache.du)
cache.u = @. u - α * cache.du # `u` might not support mutation

Check warning on line 181 in src/gaussnewton.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussnewton.jl#L180-L181

Added lines #L180 - L181 were not covered by tests
cache.fu_new = f(cache.u, p)

check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
Expand Down
1 change: 1 addition & 0 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
end

function g!(u, fu)
# FIXME: Upstream patch to allow non-square Jacobians
op = VecJac((args...) -> f(args..., p), u; autodiff)
if iip
mul!(g₀, op, fu)
Expand Down
2 changes: 1 addition & 1 deletion src/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
precs = DEFAULT_PRECS, adkwargs...)
An advanced NewtonRaphson implementation with support for efficient handling of sparse
Expand Down
18 changes: 10 additions & 8 deletions test/nonlinear_least_squares.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
resid_prototype = zero(y_target)), θ_init, x)

nlls_problems = [prob_oop, prob_iip]
solvers = [
GaussNewton(),
GaussNewton(; linsolve = LUFactorization()),
LevenbergMarquardt(),
LevenbergMarquardt(; linsolve = LUFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
]
solvers = vec(Any[GaussNewton(; linsolve, linesearch)
for linsolve in [nothing, LUFactorization()],
linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]])
append!(solvers,
[
LevenbergMarquardt(),
LevenbergMarquardt(; linsolve = LUFactorization()),
LeastSquaresOptimJL(:lm),
LeastSquaresOptimJL(:dogleg),
])

for prob in nlls_problems, solver in solvers
@time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
Expand Down

0 comments on commit 13b27bf

Please sign in to comment.