diff --git a/Project.toml b/Project.toml
index 6411e91a5..30c6df2e6 100644
--- a/Project.toml
+++ b/Project.toml
@@ -58,7 +58,7 @@ Reexport = "0.2, 1"
 SciMLBase = "2.4"
 SimpleNonlinearSolve = "0.1.23"
 SparseArrays = "1.9"
-SparseDiffTools = "2.9"
+SparseDiffTools = "2.11"
 StaticArraysCore = "1.4"
 UnPack = "1.0"
 Zygote = "0.6"
diff --git a/src/gaussnewton.jl b/src/gaussnewton.jl
index c857f2d23..012767dcf 100644
--- a/src/gaussnewton.jl
+++ b/src/gaussnewton.jl
@@ -1,5 +1,5 @@
 """
-    GaussNewton(; concrete_jac = nothing, linsolve = nothing,
+    GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = LineSearch(),
         precs = DEFAULT_PRECS, adkwargs...)
 
 An advanced GaussNewton implementation with support for efficient handling of sparse
@@ -30,6 +30,9 @@ for large-scale and numerically-difficult nonlinear least squares problems.
     preconditioners. For more information on specifying preconditioners for LinearSolve
     algorithms, consult the
     [LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
+  - `linesearch`: the line search algorithm to use. Defaults to [`LineSearch()`](@ref),
+    which means that no line search is performed. Algorithms from `LineSearches.jl` can be
+    used here directly, and they will be converted to the correct `LineSearch`.
 
 !!! warning
 
@@ -40,16 +43,18 @@ for large-scale and numerically-difficult nonlinear least squares problems.
     ad::AD
     linsolve
     precs
+    linesearch
 end
 
 function set_ad(alg::GaussNewton{CJ}, ad) where {CJ}
-    return GaussNewton{CJ}(ad, alg.linsolve, alg.precs)
+    return GaussNewton{CJ}(ad, alg.linsolve, alg.precs, alg.linesearch)
 end
 
 function GaussNewton(; concrete_jac = nothing, linsolve = nothing,
-    precs = DEFAULT_PRECS, adkwargs...)
+    linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
     ad = default_adargs_to_adtype(; adkwargs...)
-    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs)
+    linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
+    return GaussNewton{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
 end
 
 @concrete mutable struct GaussNewtonCache{iip} <: AbstractNonlinearSolveCache{iip}
@@ -78,6 +83,7 @@ end
     stats::NLStats
     tc_cache_1
     tc_cache_2
+    ls_cache
 end
 
 function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::GaussNewton,
@@ -107,7 +113,8 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem{uType, iip}, alg_::
     return GaussNewtonCache{iip}(f, alg, u, copy(u), fu1, fu2, zero(fu1), du, p, uf,
         linsolve, J, JᵀJ, Jᵀf, jac_cache, false, maxiters, internalnorm,
         ReturnCode.Default,
-        abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2)
+        abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache_1, tc_cache_2,
+        init_linesearch_cache(alg.linesearch, f, u, p, fu1, Val(iip)))
 end
 
 function perform_step!(cache::GaussNewtonCache{true})
@@ -128,7 +135,8 @@ function perform_step!(cache::GaussNewtonCache{true})
             linu = _vec(du), p, reltol = cache.abstol)
     end
     cache.linsolve = linres.cache
-    @. u = u - du
+    α = perform_linesearch!(cache.ls_cache, u, du)
+    _axpy!(-α, du, u)
     f(cache.fu_new, u, p)
 
     check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
@@ -169,7 +177,8 @@ function perform_step!(cache::GaussNewtonCache{false})
         end
         cache.linsolve = linres.cache
     end
-    cache.u = @. u - cache.du # `u` might not support mutation
+    α = perform_linesearch!(cache.ls_cache, u, cache.du)
+    cache.u = @. u - α * cache.du # `u` might not support mutation
     cache.fu_new = f(cache.u, p)
 
     check_and_update!(cache.tc_cache_1, cache, cache.fu_new, cache.u, cache.u_prev)
diff --git a/src/linesearch.jl b/src/linesearch.jl
index 760f67769..a2b396b06 100644
--- a/src/linesearch.jl
+++ b/src/linesearch.jl
@@ -122,7 +122,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
     end
 
     function g!(u, fu)
-        op = VecJac((args...) -> f(args..., p), u; autodiff)
+        op = VecJac(f, u, p; fu = fu1, autodiff)
         if iip
             mul!(g₀, op, fu)
             return g₀
diff --git a/src/raphson.jl b/src/raphson.jl
index 1b75d231e..a28dec699 100644
--- a/src/raphson.jl
+++ b/src/raphson.jl
@@ -1,5 +1,5 @@
 """
-    NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
+    NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
         linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
 
 An advanced NewtonRaphson implementation with support for efficient handling of sparse
diff --git a/src/utils.jl b/src/utils.jl
index 0b0447772..c9013ead5 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -198,7 +198,7 @@ function __get_concrete_algorithm(alg, prob)
         use_sparse_ad ? AutoSparseFiniteDiff() : AutoFiniteDiff()
     else
         (use_sparse_ad ? AutoSparseForwardDiff : AutoForwardDiff)(;
-            tag = NonlinearSolveTag())
+            tag = ForwardDiff.Tag(NonlinearSolveTag(), eltype(prob.u0)))
     end
     return set_ad(alg, ad)
 end
diff --git a/test/nonlinear_least_squares.jl b/test/nonlinear_least_squares.jl
index c7a02dc58..07a310196 100644
--- a/test/nonlinear_least_squares.jl
+++ b/test/nonlinear_least_squares.jl
@@ -27,14 +27,16 @@
 prob_iip = NonlinearLeastSquaresProblem(NonlinearFunction(loss_function;
         resid_prototype = zero(y_target)), θ_init, x)
 nlls_problems = [prob_oop, prob_iip]
-solvers = [
-    GaussNewton(),
-    GaussNewton(; linsolve = LUFactorization()),
-    LevenbergMarquardt(),
-    LevenbergMarquardt(; linsolve = LUFactorization()),
-    LeastSquaresOptimJL(:lm),
-    LeastSquaresOptimJL(:dogleg),
-]
+solvers = vec(Any[GaussNewton(; linsolve, linesearch)
+                  for linsolve in [nothing, LUFactorization()],
+linesearch in [Static(), BackTracking(), HagerZhang(), StrongWolfe(), MoreThuente()]])
+append!(solvers,
+    [
+        LevenbergMarquardt(),
+        LevenbergMarquardt(; linsolve = LUFactorization()),
+        LeastSquaresOptimJL(:lm),
+        LeastSquaresOptimJL(:dogleg),
+    ])
 
 for prob in nlls_problems, solver in solvers
     @time sol = solve(prob, solver; maxiters = 10000, abstol = 1e-8)
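
Usage sketch (not part of the patch): a minimal, hypothetical example of the new `linesearch` keyword on `GaussNewton`, driven the same way the updated tests drive it. `BackTracking` comes from `LineSearches.jl`; per the constructor change above it is wrapped automatically as `LineSearch(; method = BackTracking())`. The problem setup (`model`, `residual`, the data) is invented for illustration.

    using NonlinearSolve, LineSearches

    # Hypothetical least-squares fit: recover θ = [2.0, -1.5] from noiseless samples.
    model(θ, x) = θ[1] .* exp.(θ[2] .* x)
    x = collect(range(0.0, 1.0; length = 20))
    data = model([2.0, -1.5], x)
    residual(θ, p) = model(θ, p.x) .- p.data

    # Out-of-place NonlinearLeastSquaresProblem; no resid_prototype needed here.
    prob = NonlinearLeastSquaresProblem(residual, [1.0, -1.0], (; x, data))

    # GaussNewton now accepts `linesearch`; the default LineSearch() performs no
    # line search, reproducing the pre-PR step u ← u - α·du with α = 1.
    sol = solve(prob, GaussNewton(; linesearch = BackTracking()); abstol = 1e-8)

The same keyword already exists on `NewtonRaphson`, whose docstring is only touched up here.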