diff --git a/Project.toml b/Project.toml index 9385b14a2..b4977a080 100644 --- a/Project.toml +++ b/Project.toml @@ -60,6 +60,7 @@ LeastSquaresOptim = "0.8" LineSearches = "7" LinearAlgebra = "<0.0.1, 1" LinearSolve = "2.12" +MaybeInplace = "0.1" NaNMath = "1" NonlinearProblemLibrary = "0.1" Pkg = "1" @@ -71,7 +72,7 @@ Reexport = "0.2, 1" SafeTestsets = "0.1" SciMLBase = "2.9" SciMLOperators = "0.3" -SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using it to test +SimpleNonlinearSolve = "1" SparseArrays = "<0.0.1, 1" SparseDiffTools = "2.14" StaticArrays = "1" @@ -98,6 +99,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 278667790..c6b4fca66 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -17,6 +17,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work import ConcreteStructs: @concrete import EnumX: @enumx import FastBroadcast: @.. + import FiniteDiff import ForwardDiff import ForwardDiff: Dual import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A @@ -56,7 +57,7 @@ function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(c cache.p = p if iip recursivecopy!(get_u(cache), u0) - cache.f(cache.fu1, get_u(cache), p) + cache.f(get_fu(cache), get_u(cache), p) else cache.u = __maybe_unaliased(u0, alias_u0) set_fu!(cache, cache.f(cache.u, p)) @@ -76,7 +77,7 @@ function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(c if hasfield(typeof(cache), :ls_cache) # TODO: A more efficient way to do this - cache.ls_cache = init_linesearch_cache(cache.prob, cache.alg.linesearch, cache.f, + cache.ls_cache = init_linesearch_cache(cache.alg.linesearch, cache.f, get_u(cache), p, get_fu(cache), Val(iip)) end diff --git a/src/dfsane.jl b/src/dfsane.jl index 570dd7ccd..689c24485 100644 --- a/src/dfsane.jl +++ b/src/dfsane.jl @@ -55,6 +55,7 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_ end @concrete mutable struct DFSaneCache{iip} <: AbstractNonlinearSolveCache{iip} + f alg u u_cache @@ -110,8 +111,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args. termination_condition) trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...) - return DFSaneCache{iip}(alg, u, u_cache, u_cache_2, fu, fu_cache, du, history, f_norm, - f_norm_0, alg.M, T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), + return DFSaneCache{iip}(prob.f, alg, u, u_cache, u_cache_2, fu, fu_cache, du, history, + f_norm, f_norm_0, alg.M, T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ), T(alg.τ_min), T(alg.τ_max), alg.n_exp, prob.p, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace) end diff --git a/src/jacobian.jl b/src/jacobian.jl index cd84b5d1d..2174fbc8e 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -209,10 +209,14 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) end # jvp fallback scalar -__jacvec(args...; kwargs...) = JacVec(args...; kwargs...) -function __jacvec(uf, u::Number; autodiff, kwargs...) - @assert autodiff isa AutoForwardDiff "Only ForwardDiff is currently supported." - return JVPScalar(uf, u, autodiff) +function __jacvec(uf, u; autodiff, kwargs...) + if !(autodiff isa AutoForwardDiff || autodiff isa AutoFiniteDiff) + _ad = autodiff + autodiff = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(), + AutoFiniteDiff()) + @warn "$(_ad) not supported for JacVec. Using $(autodiff) instead." + end + return u isa Number ? JVPScalar(uf, u, autodiff) : JacVec(uf, u; autodiff, kwargs...) end @concrete mutable struct JVPScalar @@ -221,10 +225,17 @@ end autodiff end -function Base.:*(jvp::JVPScalar, v) - T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u))) - out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v)) - return ForwardDiff.extract_derivative(T, out) +function Base.:*(jvp::JVPScalar, v::Number) + if jvp.autodiff isa AutoForwardDiff + T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u))) + out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v)) + return ForwardDiff.extract_derivative(T, out) + elseif jvp.autodiff isa AutoFiniteDiff + J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype) + return J * v + else + error("Only ForwardDiff & FiniteDiff is currently supported.") + end end # Generic Handling of Krylov Methods for Normal Form Linear Solves diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 9ed243d26..9087b0d53 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -255,7 +255,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, @bb u_cache_2 = similar(u) @bb u_cauchy = similar(u) @bb u_gauss_newton = similar(u) - @bb J_cache = similar(J) + J_cache = J isa SciMLOperators.AbstractSciMLOperator || + setindex_trait(J) === CannotSetindex() ? J : similar(J) @bb lr_mul_cache = similar(du) loss_new = loss diff --git a/src/utils.jl b/src/utils.jl index 4d8496015..56a976aa8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -178,7 +178,7 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip}, return fu end -function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip <: Bool} +function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip} if iip f(fu, u, p) return fu