From 4d8cde4bcf84aa5f61d38df81cfce3a89f143b29 Mon Sep 17 00:00:00 2001 From: Brad Carman Date: Wed, 2 Aug 2023 17:28:21 -0400 Subject: [PATCH] p_type definition --- src/systems/diffeqs/abstractodesystem.jl | 10 ++++--- src/utils.jl | 27 ++++++++++++------- src/variables.jl | 4 +-- test/odesystem.jl | 34 ++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 16 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 30b237c807..53ec758f7e 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -684,8 +684,9 @@ Take dictionaries with initial conditions and parameters and convert them to num function get_u0_p(sys, u0map, parammap; + p_type = nothing, use_union = false, - tofloat = !use_union, + tofloat = !use_union & p_type === nothing, symbolic_u0 = false) eqs = equations(sys) dvs = states(sys) @@ -700,7 +701,7 @@ function get_u0_p(sys, else u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true) end - p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union) + p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union, type = p_type) p = p === nothing ? SciMLBase.NullParameters() : p u0, p, defs end @@ -713,8 +714,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; simplify = false, linenumbers = true, parallel = SerialForm(), eval_expression = true, + p_type = nothing, use_union = false, - tofloat = !use_union, + tofloat = !use_union & isnothing(p_type), symbolic_u0 = false, kwargs...) eqs = equations(sys) @@ -722,7 +724,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap; ps = parameters(sys) iv = get_iv(sys) - u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0) + u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union, symbolic_u0, p_type) if implicit_dae && du0map !== nothing ddvs = map(Differential(iv), dvs) diff --git a/src/utils.jl b/src/utils.jl index 4dc2a636df..26238b439c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -666,7 +666,10 @@ end throw(ArgumentError("$vars are either missing from the variable map or missing from the system's states/parameters list.")) end -function promote_to_concrete(vs; tofloat = true, use_union = false) +function promote_to_concrete(vs; + type::Union{Type{K}, Nothing} = nothing, + tofloat = type === nothing, + use_union = false) where {K} if isempty(vs) return vs end @@ -693,16 +696,20 @@ function promote_to_concrete(vs; tofloat = true, use_union = false) I = promote_type(I, E) end end - if tofloat && !has_array - C = float(C) - elseif has_array || (use_union && has_int && C !== I) - if has_array - C = Union{C, array_T} + if type === nothing + if tofloat && !has_array + C = float(C) + elseif has_array || (use_union && has_int && C !== I) + if has_array + C = Union{C, array_T} + end + if has_int + C = Union{C, I} + end + return copyto!(similar(vs, C), vs) end - if has_int - C = Union{C, I} - end - return copyto!(similar(vs, C), vs) + else + C = K end convert.(C, vs) end diff --git a/src/variables.jl b/src/variables.jl index 4d11193462..15bcf8984f 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -58,7 +58,7 @@ applicable. """ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, toterm = default_toterm, promotetoconcrete = nothing, - tofloat = true, use_union = false) + tofloat = true, use_union = false, kwargs...) varlist = collect(map(unwrap, varlist)) # Edge cases where one of the arguments is effectively empty. @@ -89,7 +89,7 @@ function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true, promotetoconcrete === nothing && (promotetoconcrete = container_type <: AbstractArray) if promotetoconcrete - vals = promote_to_concrete(vals; tofloat = tofloat, use_union = use_union) + vals = promote_to_concrete(vals; tofloat, use_union, kwargs...) end if isempty(vals) diff --git a/test/odesystem.jl b/test/odesystem.jl index a45a45b1b0..f3e758958c 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1012,3 +1012,37 @@ let prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0)) @test !isnothing(prob.f.sys) end + +# p_type +let + # needs ModelingToolkitStandardLibrary > v2.1.0 + using ModelingToolkitStandardLibrary.Blocks: SampledData, Parameter, Integrator + + dt = 4e-4 + t_end = 10.0 + time = 0:dt:t_end + x = @. time^2 + 1.0 + + @parameters t + D = Differential(t) + + vars = @variables y(t)=1 dy(t)=0 ddy(t)=0 + @named src = SampledData(Float64) + @named int = Integrator() + @named iosys = ODESystem([y ~ src.output.u + D(y) ~ dy + D(dy) ~ ddy + connect(src.output, int.input)], + t, + systems = [int, src]) + sys = structural_simplify(iosys) + s = complete(iosys) + prob = ODEProblem(sys, + [], + (0.0, t_end), + [s.src.buffer => Parameter(x, dt)]; + p_type = Parameter{Float64}) + + @test eltype(prob.p) == Parameter{Float64} + @test eltype(prob.u0) == Float64 +end