Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sys field to function types I missed #220

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ ODEFunction{iip,recompile}(f;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the function `f` itself is required. This function should
Expand Down Expand Up @@ -259,7 +260,8 @@ SplitFunction{iip,recompile}(f1,f2;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the functions `f_i` themselves are required. These functions should
Expand Down Expand Up @@ -380,7 +382,8 @@ DynamicalODEFunction{iip,recompile}(f1,f2;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the functions `f_i` themselves are required. These functions should
Expand Down Expand Up @@ -487,7 +490,8 @@ DDEFunction{iip,recompile}(f;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the function `f` itself is required. This function should
Expand Down Expand Up @@ -596,7 +600,8 @@ DynamicalDDEFunction{iip,recompile}(f1,f2;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the functions `f_i` themselves are required. These functions should
Expand Down Expand Up @@ -765,7 +770,8 @@ SDEFunction{iip,recompile}(f,g;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the function `f` itself is required. This function should
Expand Down Expand Up @@ -876,7 +882,8 @@ SplitSDEFunction{iip,recompile}(f1,f2,g;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the function `f` itself is required. All of the remaining functions
Expand Down Expand Up @@ -989,7 +996,8 @@ DynamicalSDEFunction{iip,recompile}(f1,f2;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the functions `f_i` themselves are required. These functions should
Expand Down Expand Up @@ -1101,7 +1109,8 @@ RODEFunction{iip,recompile}(f;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the function `f` itself is required. This function should
Expand Down Expand Up @@ -1203,7 +1212,8 @@ DAEFunction{iip,recompile}(f;
sparsity=jac_prototype,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the function `f` itself is required. This function should
Expand Down Expand Up @@ -1349,7 +1359,8 @@ SDDEFunction{iip,recompile}(f,g;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing
sys = nothing)
```

Note that only the function `f` itself is required. This function should
Expand Down Expand Up @@ -1403,7 +1414,7 @@ For more details on this argument, see the ODEFunction documentation.
The fields of the DDEFunction type directly match the names of the inputs.
"""
struct SDDEFunction{iip, F, G, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, GG, S, O,
TCV} <: AbstractSDDEFunction{iip}
TCV, SYS} <: AbstractSDDEFunction{iip}
f::F
g::G
mass_matrix::TMM
Expand All @@ -1421,6 +1432,7 @@ struct SDDEFunction{iip, F, G, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ,
syms::S
observed::O
colorvec::TCV
sys::SYS
end

"""
Expand Down Expand Up @@ -1454,7 +1466,8 @@ NonlinearFunction{iip, recompile}(f;
paramjac = nothing,
syms = nothing,
indepsym = nothing,
colorvec = nothing)
colorvec = nothing,
sys = nothing)
```

Note that only the function `f` itself is required. This function should
Expand Down Expand Up @@ -1542,7 +1555,8 @@ OptimizationFunction{iip}(f, adtype::AbstractADType = NoAD();
cons_hess_prototype = nothing,
syms = nothing, hess_colorvec = nothing,
cons_jac_colorvec = nothing,
cons_hess_colorvec = nothing)
cons_hess_colorvec = nothing,
sys = nothing)
```

- `adtype`: see the section "Defining Optimization Functions via AD"
Expand Down Expand Up @@ -1998,8 +2012,9 @@ end
function DiscreteFunction{iip, false}(f;
analytic = nothing,
syms = nothing,
observed = DEFAULT_OBSERVED) where {iip}
DiscreteFunction{iip, Any, Any, Any, Any}(f, analytic, syms, observed)
observed = DEFAULT_OBSERVED,
sys = nothing) where {iip}
DiscreteFunction{iip, Any, Any, Any, Any, Any}(f, analytic, syms, observed, sys)
end
function DiscreteFunction{iip}(f; kwargs...) where {iip}
DiscreteFunction{iip, RECOMPILE_BY_DEFAULT}(f; kwargs...)
Expand Down Expand Up @@ -3177,7 +3192,12 @@ function Base.convert(::Type{DiscreteFunction}, f)
else
observed = DEFAULT_OBSERVED
end
DiscreteFunction(f; analytic = analytic, syms = syms, observed = observed)
if __has_sys(f)
sys = f.sys
else
sys = nothing
end
DiscreteFunction(f; analytic = analytic, syms = syms, observed = observed, sys = sys)
end
function Base.convert(::Type{DiscreteFunction{iip}}, f) where {iip}
if __has_analytic(f)
Expand All @@ -3195,8 +3215,13 @@ function Base.convert(::Type{DiscreteFunction{iip}}, f) where {iip}
else
observed = DEFAULT_OBSERVED
end
if __has_sys(f)
sys = f.sys
else
sys = nothing
end
DiscreteFunction{iip, RECOMPILE_BY_DEFAULT}(f; analytic = analytic, syms = syms,
observed = observed)
observed = observed, sys = sys)
end

function Base.convert(::Type{DAEFunction}, f)
Expand Down