Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nlfunc to ODEFunction #800

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ the usage of `f`. These include:
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.
- `nlfunc`: a `NonlinearFunction`

## iip: In-Place vs Out-Of-Place

Expand Down Expand Up @@ -401,8 +402,8 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLF} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -423,6 +424,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
nlfunc::NLF
end

@doc doc"""
Expand Down Expand Up @@ -519,8 +521,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O,
TCV, SYS, IProb, UIProb, IProbMap, IProbPmap} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLF} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -543,6 +545,7 @@ struct SplitFunction{
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
nlfunc::NLF
end

@doc doc"""
Expand Down Expand Up @@ -2420,7 +2423,8 @@ function ODEFunction{iip, specialize}(f;
update_initializeprob! = __has_update_initializeprob!(f) ?
f.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
nlfunc = __has_nlfunc(f) ? f.nlfunc : nothing,
) where {iip,
specialize
}
Expand Down Expand Up @@ -2478,11 +2482,11 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
initializeprobpmap, nlfunc)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2491,13 +2495,15 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap), typeof(initializeprobpmap)}(_f, mass_matrix,
typeof(sys), typeof(initializeprob),
typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap),
typeof(nlfunc)}(_f, mass_matrix,
analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
observed, _colorvec, sys, initializeprob, initializeprobmap, nlfunc)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2508,11 +2514,12 @@ function ODEFunction{iip, specialize}(f;
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap)}(_f, mass_matrix, analytic, tgrad, jac,
typeof(initializeprobpmap),
typeof(nlfunc)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap)
initializeprobpmap, nlfunc)
end
end

Expand All @@ -2529,13 +2536,13 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any, Any, Any}(
typeof(f.sys), Any, Any, Any, Any, Any}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob,
f.update_initializeprob!, f.initializeprobmap,
f.initializeprobpmap)
f.initializeprobpmap, f.nlfunc)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
Expand All @@ -2545,11 +2552,12 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
typeof(f.initializeprobmap),
typeof(f.initializeprobpmap)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.initializeprobpmap),
typeof(f.nlfunc)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
f.initializeprobmap, f.initializeprobpmap)
f.initializeprobmap, f.initializeprobpmap, f.nlfunc)
end
end

Expand Down Expand Up @@ -4370,6 +4378,7 @@ __has_initializeprob(f) = isdefined(f, :initializeprob)
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
__has_nlfunc(f) = isdefined(f, :nl_func)

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand Down
Loading