Skip to content

Commit

Permalink
Fix most tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 5, 2023
1 parent cefe5b0 commit ee15d80
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 15 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ LeastSquaresOptim = "0.8"
LineSearches = "7"
LinearAlgebra = "<0.0.1, 1"
LinearSolve = "2.12"
MaybeInplace = "0.1"
NaNMath = "1"
NonlinearProblemLibrary = "0.1"
Pkg = "1"
Expand All @@ -71,7 +72,7 @@ Reexport = "0.2, 1"
SafeTestsets = "0.1"
SciMLBase = "2.9"
SciMLOperators = "0.3"
SimpleNonlinearSolve = "1" # FIXME: Don't update the version in this PR. Using it to test
SimpleNonlinearSolve = "1"
SparseArrays = "<0.0.1, 1"
SparseDiffTools = "2.14"
StaticArrays = "1"
Expand All @@ -98,6 +99,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down
5 changes: 3 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_work
import ConcreteStructs: @concrete
import EnumX: @enumx
import FastBroadcast: @..
import FiniteDiff
import ForwardDiff
import ForwardDiff: Dual
import LinearSolve: ComposePreconditioner, InvPreconditioner, needs_concrete_A
Expand Down Expand Up @@ -56,7 +57,7 @@ function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(c
cache.p = p
if iip
recursivecopy!(get_u(cache), u0)
cache.f(cache.fu1, get_u(cache), p)
cache.f(get_fu(cache), get_u(cache), p)
else
cache.u = __maybe_unaliased(u0, alias_u0)
set_fu!(cache, cache.f(cache.u, p))
Expand All @@ -76,7 +77,7 @@ function SciMLBase.reinit!(cache::AbstractNonlinearSolveCache{iip}, u0 = get_u(c

if hasfield(typeof(cache), :ls_cache)
# TODO: A more efficient way to do this
cache.ls_cache = init_linesearch_cache(cache.prob, cache.alg.linesearch, cache.f,
cache.ls_cache = init_linesearch_cache(cache.alg.linesearch, cache.f,
get_u(cache), p, get_fu(cache), Val(iip))
end

Expand Down
5 changes: 3 additions & 2 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_
end

@concrete mutable struct DFSaneCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
u_cache
Expand Down Expand Up @@ -110,8 +111,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args.
termination_condition)
trace = init_nonlinearsolve_trace(alg, u, fu, nothing, du; kwargs...)

return DFSaneCache{iip}(alg, u, u_cache, u_cache_2, fu, fu_cache, du, history, f_norm,
f_norm_0, alg.M, T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ),
return DFSaneCache{iip}(prob.f, alg, u, u_cache, u_cache_2, fu, fu_cache, du, history,
f_norm, f_norm_0, alg.M, T(alg.σ_1), T(alg.σ_min), T(alg.σ_max), one(T), T(alg.γ),
T(alg.τ_min), T(alg.τ_max), alg.n_exp, prob.p, false, maxiters, internalnorm,
ReturnCode.Default, abstol, reltol, prob, NLStats(1, 0, 0, 0, 0), tc_cache, trace)
end
Expand Down
27 changes: 19 additions & 8 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,14 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
end

# jvp fallback scalar
__jacvec(args...; kwargs...) = JacVec(args...; kwargs...)
function __jacvec(uf, u::Number; autodiff, kwargs...)
@assert autodiff isa AutoForwardDiff "Only ForwardDiff is currently supported."
return JVPScalar(uf, u, autodiff)
function __jacvec(uf, u; autodiff, kwargs...)
if !(autodiff isa AutoForwardDiff || autodiff isa AutoFiniteDiff)
_ad = autodiff
autodiff = ifelse(ForwardDiff.can_dual(eltype(u)), AutoForwardDiff(),
AutoFiniteDiff())
@warn "$(_ad) not supported for JacVec. Using $(autodiff) instead."
end
return u isa Number ? JVPScalar(uf, u, autodiff) : JacVec(uf, u; autodiff, kwargs...)
end

@concrete mutable struct JVPScalar
Expand All @@ -221,10 +225,17 @@ end
autodiff
end

function Base.:*(jvp::JVPScalar, v)
T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v))
return ForwardDiff.extract_derivative(T, out)
function Base.:*(jvp::JVPScalar, v::Number)
if jvp.autodiff isa AutoForwardDiff
T = typeof(ForwardDiff.Tag(typeof(jvp.uf), typeof(jvp.u)))
out = jvp.uf(ForwardDiff.Dual{T}(jvp.u, v))
return ForwardDiff.extract_derivative(T, out)
elseif jvp.autodiff isa AutoFiniteDiff
J = FiniteDiff.finite_difference_derivative(jvp.uf, jvp.u, jvp.autodiff.fdtype)
return J * v
else
error("Only ForwardDiff & FiniteDiff is currently supported.")
end
end

# Generic Handling of Krylov Methods for Normal Form Linear Solves
Expand Down
3 changes: 2 additions & 1 deletion src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion,
@bb u_cache_2 = similar(u)
@bb u_cauchy = similar(u)
@bb u_gauss_newton = similar(u)
@bb J_cache = similar(J)
J_cache = J isa SciMLOperators.AbstractSciMLOperator ||
setindex_trait(J) === CannotSetindex() ? J : similar(J)
@bb lr_mul_cache = similar(du)

loss_new = loss
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ function evaluate_f(prob::Union{NonlinearProblem{uType, iip},
return fu
end

function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip <: Bool}
function evaluate_f(f::F, u, p, ::Val{iip}; fu = nothing) where {F, iip}
if iip
f(fu, u, p)
return fu
Expand Down

0 comments on commit ee15d80

Please sign in to comment.