Skip to content

Commit

Permalink
p_type definition
Browse files Browse the repository at this point in the history
  • Loading branch information
Brad Carman committed Aug 2, 2023
1 parent b9f285d commit 4d8cde4
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 16 deletions.
10 changes: 6 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,9 @@ Take dictionaries with initial conditions and parameters and convert them to num
function get_u0_p(sys,
u0map,
parammap;
p_type = nothing,
use_union = false,
tofloat = !use_union,
tofloat = !use_union & p_type === nothing,
symbolic_u0 = false)
eqs = equations(sys)
dvs = states(sys)
Expand All @@ -700,7 +701,7 @@ function get_u0_p(sys,
else
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
end
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union, type = p_type)
p = p === nothing ? SciMLBase.NullParameters() : p
u0, p, defs
end
Expand All @@ -713,16 +714,17 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
simplify = false,
linenumbers = true, parallel = SerialForm(),
eval_expression = true,
p_type = nothing,
use_union = false,
tofloat = !use_union,
tofloat = !use_union & isnothing(p_type),
symbolic_u0 = false,
kwargs...)
eqs = equations(sys)
dvs = states(sys)
ps = parameters(sys)
iv = get_iv(sys)

u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0)
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0, p_type)

if implicit_dae && du0map !== nothing
ddvs = map(Differential(iv), dvs)
Expand Down
27 changes: 17 additions & 10 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,10 @@ end
throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list."))
end

function promote_to_concrete(vs; tofloat = true, use_union = false)
function promote_to_concrete(vs;
type::Union{Type{K}, Nothing} = nothing,
tofloat = type === nothing,
use_union = false) where {K}
if isempty(vs)
return vs
end
Expand All @@ -693,16 +696,20 @@ function promote_to_concrete(vs; tofloat = true, use_union = false)
I = promote_type(I, E)
end
end
if tofloat && !has_array
C = float(C)
elseif has_array || (use_union && has_int && C !== I)
if has_array
C = Union{C, array_T}
if type === nothing
if tofloat && !has_array
C = float(C)

Check warning on line 701 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L701

Added line #L701 was not covered by tests
elseif has_array || (use_union && has_int && C !== I)
if has_array
C = Union{C, array_T}

Check warning on line 704 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L703-L704

Added lines #L703 - L704 were not covered by tests
end
if has_int
C = Union{C, I}

Check warning on line 707 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L706-L707

Added lines #L706 - L707 were not covered by tests
end
return copyto!(similar(vs, C), vs)
end
if has_int
C = Union{C, I}
end
return copyto!(similar(vs, C), vs)
else
C = K

Check warning on line 712 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L712

Added line #L712 was not covered by tests
end
convert.(C, vs)
end
Expand Down
4 changes: 2 additions & 2 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ applicable.
"""
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
toterm = default_toterm, promotetoconcrete = nothing,
tofloat = true, use_union = false)
tofloat = true, use_union = false, kwargs...)
varlist = collect(map(unwrap, varlist))

# Edge cases where one of the arguments is effectively empty.
Expand Down Expand Up @@ -89,7 +89,7 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,

promotetoconcrete === nothing && (promotetoconcrete = container_type <: AbstractArray)
if promotetoconcrete
vals = promote_to_concrete(vals; tofloat = tofloat, use_union = use_union)
vals = promote_to_concrete(vals; tofloat, use_union, kwargs...)
end

if isempty(vals)
Expand Down
34 changes: 34 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1012,3 +1012,37 @@ let
prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0))
@test !isnothing(prob.f.sys)
end

# p_type
let
# needs ModelingToolkitStandardLibrary > v2.1.0
using ModelingToolkitStandardLibrary.Blocks: SampledData, Parameter, Integrator

dt = 4e-4
t_end = 10.0
time = 0:dt:t_end
x = @. time^2 + 1.0

@parameters t
D = Differential(t)

vars = @variables y(t)=1 dy(t)=0 ddy(t)=0
@named src = SampledData(Float64)
@named int = Integrator()
@named iosys = ODESystem([y ~ src.output.u
D(y) ~ dy
D(dy) ~ ddy
connect(src.output, int.input)],
t,
systems = [int, src])
sys = structural_simplify(iosys)
s = complete(iosys)
prob = ODEProblem(sys,
[],
(0.0, t_end),
[s.src.buffer => Parameter(x, dt)];
p_type = Parameter{Float64})

@test eltype(prob.p) == Parameter{Float64}
@test eltype(prob.u0) == Float64
end

0 comments on commit 4d8cde4

Please sign in to comment.