Skip to content

Commit

Permalink
add nlfunc to ODEFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Oct 14, 2024
1 parent 17f4548 commit 3ded92b
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ the usage of `f`. These include:
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.
- `nlfunc`: a `NonlinearFunction`
## iip: In-Place vs Out-Of-Place
Expand Down Expand Up @@ -401,8 +402,8 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLF} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -423,6 +424,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
nlfunc::NLF
end

@doc doc"""
Expand Down Expand Up @@ -519,8 +521,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O,
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLF} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -543,6 +545,7 @@ struct SplitFunction{
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
nlfunc::NLF
end

@doc doc"""
Expand Down Expand Up @@ -2420,7 +2423,8 @@ function ODEFunction{iip, specialize}(f;
update_initializeprob! = __has_update_initializeprob!(f) ?
f.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing,
) where {iip,
specialize
}
Expand Down Expand Up @@ -2478,11 +2482,11 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
initializeprobpmap, nlfunc)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2491,13 +2495,15 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
typeof(sys), typeof(initializeprob),
typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap),
typeof(nlfunc)}(_f, mass_matrix,
analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
observed, _colorvec, sys, initializeprob, initializeprobmap, nlfunc)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2508,11 +2514,12 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(initializeprobpmap)
typeof(nlfunc)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
initializeprobpmap, nlfunc)
end
end

Expand All @@ -2529,13 +2536,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any, Any, Any}(
typeof(f.sys), Any, Any, Any, Any, Any}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob,
f.update_initializeprob!, f.initializeprobmap,
f.initializeprobpmap)
f.initializeprobpmap, f.nlfunc)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
Expand All @@ -2545,11 +2552,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
typeof(f.initializeprobmap),
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.initializeprobpmap),
typof(f.nlfunc)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
f.initializeprobmap, f.initializeprobpmap)
f.initializeprobmap, f.initializeprobpmap, f.nlfunc)
end
end

Expand Down Expand Up @@ -4370,6 +4378,7 @@ __has_initializeprob(f) = isdefined(f, :initializeprob)
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
__has_nlfunc(f) = isdefined(f, :nl_func)

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand Down

0 comments on commit 3ded92b

Please sign in to comment.