Skip to content

Commit

Permalink
fixup! refactor: turn tunables portion into a Vector{T}
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jul 30, 2024
1 parent c848028 commit b5be3e4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
12 changes: 8 additions & 4 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,8 @@ function wrap_array_vars(
Let(
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_pars]
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_pars]
),
expr.body,
false
Expand All @@ -275,7 +276,8 @@ function wrap_array_vars(
Let(
vcat(
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_pars]
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_pars]
),
expr.body,
false
Expand All @@ -288,8 +290,10 @@ function wrap_array_vars(
[],
Let(
vcat(
[k :(view($(expr.args[uind+1].name), $v)) for (k, v) in array_vars],
[k :(view($(expr.args[uind+2].name), $v)) for (k, v) in array_pars]
[k :(view($(expr.args[uind + 1].name), $v))
for (k, v) in array_vars],
[k :(view($(expr.args[uind + 2].name), $v))
for (k, v) in array_pars]
),
expr.body,
false
Expand Down
8 changes: 4 additions & 4 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,9 @@ function IndexCache(sys::AbstractSystem)
insert_by_type!(
if ctype <: Real || ctype <: AbstractArray{<:Real}
if istunable(p, true) && Symbolics.shape(p) != Symbolics.Unknown &&
(ctype == Real || ctype <: AbstractFloat ||
ctype <: AbstractArray{Real} ||
ctype <: AbstractArray{<:AbstractFloat})
(ctype == Real || ctype <: AbstractFloat ||
ctype <: AbstractArray{Real} ||
ctype <: AbstractArray{<:AbstractFloat})
tunable_buffers
else
constant_buffers
Expand Down Expand Up @@ -450,7 +450,7 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
()
else
(BasicSymbolic[unwrap(variable(:DEF))
for _ in 1:(ic.tunable_buffer_size.length)],)
for _ in 1:(ic.tunable_buffer_size.length)],)
end
disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
for temp in Iterators.flatten(ic.discrete_buffer_sizes))
Expand Down
10 changes: 6 additions & 4 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ function MTKParameters(
end
dep_exprs = identity.(dep_exprs)
psyms = reorder_parameters(ic, full_parameters(sys))
update_fn_exprs = build_function(dep_exprs, psyms..., expression = Val{true}, wrap_code = wrap_array_vars(sys, dep_exprs; dvs = nothing))
update_fn_exprs = build_function(dep_exprs, psyms..., expression = Val{true},
wrap_code = wrap_array_vars(sys, dep_exprs; dvs = nothing))

update_function_oop, update_function_iip = eval_or_rgf.(
update_fn_exprs; eval_expression, eval_module)
Expand Down Expand Up @@ -306,8 +307,8 @@ function SciMLStructures.replace!(::SciMLStructures.Tunable, p::MTKParameters, n
end

for (Portion, field, recurse) in [(SciMLStructures.Discrete, :discrete, 2)
(SciMLStructures.Constants, :constant, 1)
(Nonnumeric, :nonnumeric, 1)]
(SciMLStructures.Constants, :constant, 1)
(Nonnumeric, :nonnumeric, 1)]
@eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters)
as_vector = buffer_to_arraypartition(p.$field)
repack = let as_vector = as_vector, p = p
Expand Down Expand Up @@ -396,7 +397,8 @@ function SymbolicIndexingInterface.set_parameter!(
k, l... = k
if isempty(l)
if validate_size && size(val) !== size(p.discrete[i][j][k])
throw(InvalidParameterSizeException(size(p.discrete[i][j][k]), size(val)))
throw(InvalidParameterSizeException(
size(p.discrete[i][j][k]), size(val)))
end
p.discrete[i][j][k] = val
else
Expand Down

0 comments on commit b5be3e4

Please sign in to comment.