Skip to content

Commit

Permalink
Additional initialization path
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 16, 2023
1 parent da5aaff commit baaad84
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/function_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@ mutable struct TimeGradientWrapper{iip, fType, uType, P} <: AbstractSciMLFunctio
p::P
end

function TimeGradientWrapper{iip}(f::F, uprev, p) where {F, iip}

Check warning on line 7 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L7

Added line #L7 was not covered by tests
return TimeGradientWrapper{iip, F, typeof(uprev), typeof(p)}(f, uprev, p)
end
function TimeGradientWrapper(f::F, uprev, p) where {F}

Check warning on line 10 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L10

Added line #L10 was not covered by tests
return TimeGradientWrapper{isinplace(f, 4), F, typeof(uprev), typeof(p)}(f, uprev, p)
return TimeGradientWrapper{isinplace(f, 4)}(f, uprev, p)
end

(ff::TimeGradientWrapper{true})(t) = (du2 = similar(ff.uprev); ff.f(du2, ff.uprev, ff.p, t); du2)

Check warning on line 14 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L14

Added line #L14 was not covered by tests
Expand All @@ -19,9 +22,10 @@ mutable struct UJacobianWrapper{iip, fType, tType, P} <: AbstractSciMLFunction{i
p::P
end

function UJacobianWrapper(f::F, t, p) where {F}
return UJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p)
function UJacobianWrapper{iip}(f::F, t, p) where {F, iip}

Check warning on line 25 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L25

Added line #L25 was not covered by tests
return UJacobianWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)
end
UJacobianWrapper(f::F, t, p) where {F} = UJacobianWrapper{isinplace(f, 4)}(f, t, p)

(ff::UJacobianWrapper{true})(du1, uprev) = ff.f(du1, uprev, ff.p, ff.t)
(ff::UJacobianWrapper{true})(uprev) = (du1 = similar(uprev); ff.f(du1, uprev, ff.p, ff.t); du1)
Expand All @@ -37,8 +41,11 @@ mutable struct TimeDerivativeWrapper{iip, F, uType, P} <: AbstractSciMLFunction{
p::P
end

function TimeDerivativeWrapper{iip}(f::F, u, p) where {F, iip}
return TimeDerivativeWrapper{iip, F, typeof(u), typeof(p)}(f, u, p)

Check warning on line 45 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L44-L45

Added lines #L44 - L45 were not covered by tests
end
function TimeDerivativeWrapper(f::F, u, p) where {F}
return TimeDerivativeWrapper{isinplace(f, 4), F, typeof(u), typeof(p)}(f, u, p)
return TimeDerivativeWrapper{isinplace(f, 4)}(f, u, p)

Check warning on line 48 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L47-L48

Added lines #L47 - L48 were not covered by tests
end

(ff::TimeDerivativeWrapper{false})(t) = ff.f(ff.u, ff.p, t)
Expand All @@ -51,9 +58,10 @@ mutable struct UDerivativeWrapper{iip, F, tType, P} <: AbstractSciMLFunction{iip
p::P
end

function UDerivativeWrapper(f::F, t, p) where {F}
return UDerivativeWrapper{isinplace(f, 4), F, typeof(t), typeof(p)}(f, t, p)
function UDerivativeWrapper{iip}(f::F, t, p) where {F, iip}
return UDerivativeWrapper{iip, F, typeof(t), typeof(p)}(f, t, p)

Check warning on line 62 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L61-L62

Added lines #L61 - L62 were not covered by tests
end
UDerivativeWrapper(f::F, t, p) where {F} = UDerivativeWrapper{isinplace(f, 4)}(f, t, p)

Check warning on line 64 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L64

Added line #L64 was not covered by tests

(ff::UDerivativeWrapper{false})(u) = ff.f(u, ff.p, ff.t)
(ff::UDerivativeWrapper{true})(du1, u) = ff.f(du1, u, ff.p, ff.t)
Expand All @@ -65,9 +73,10 @@ mutable struct ParamJacobianWrapper{iip, fType, tType, uType} <: AbstractSciMLFu
u::uType
end

function ParamJacobianWrapper(f::F, t, u) where {F}
return ParamJacobianWrapper{isinplace(f, 4), F, typeof(t), typeof(u)}(f, t, u)
function ParamJacobianWrapper{iip}(f::F, t, u) where {F, iip}
return ParamJacobianWrapper{iip, F, typeof(t), typeof(u)}(f, t, u)

Check warning on line 77 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L76-L77

Added lines #L76 - L77 were not covered by tests
end
ParamJacobianWrapper(f::F, t, u) where {F} = ParamJacobianWrapper{isinplace(f, 4)}(f, t, u)

Check warning on line 79 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L79

Added line #L79 was not covered by tests

(ff::ParamJacobianWrapper{true})(du1, p) = ff.f(du1, ff.u, p, ff.t)
function (ff::ParamJacobianWrapper{true})(p)

Check warning on line 82 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L81-L82

Added lines #L81 - L82 were not covered by tests
Expand All @@ -82,9 +91,9 @@ mutable struct JacobianWrapper{iip, fType, pType} <: AbstractSciMLFunction{iip}
p::pType
end

function JacobianWrapper(f::F, p) where {F}
return JacobianWrapper{isinplace(f, 4), F, typeof(p)}(f, p)
end
JacobianWrapper{iip}(f::F, p) where {F, iip} = JacobianWrapper{iip, F, typeof(p)}(f, p)
JacobianWrapper(f::F, p) where {F} = JacobianWrapper{isinplace(f, 3)}(f, p)

Check warning on line 95 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L94-L95

Added lines #L94 - L95 were not covered by tests

(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)

Check warning on line 99 in src/function_wrappers.jl

View check run for this annotation

Codecov / codecov/patch

src/function_wrappers.jl#L97-L99

Added lines #L97 - L99 were not covered by tests

0 comments on commit baaad84

Please sign in to comment.