Skip to content

Commit

Permalink
Merge pull request #2911 from BenChung/improve-events
Browse files Browse the repository at this point in the history
Support more of the SciMLBase events API
  • Loading branch information
ChrisRackauckas authored Aug 2, 2024
2 parents 3cef655 + 7053b34 commit 55e8ad6
Show file tree
Hide file tree
Showing 2 changed files with 433 additions and 41 deletions.
215 changes: 174 additions & 41 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,24 +76,62 @@ end
#################################### continuous events #####################################

const NULL_AFFECT = Equation[]
"""
SymbolicContinuousCallback(eqs::Vector{Equation}, affect, affect_neg, rootfind)
A [`ContinuousCallback`](@ref SciMLBase.ContinuousCallback) specified symbolically. Takes a vector of equations `eq`
as well as the positive-edge `affect` and negative-edge `affect_neg` that apply when *any* of `eq` are satisfied.
By default `affect_neg = affect`; to only get rising edges specify `affect_neg = nothing`.
Assume without loss of generality that the equation is of the form `c(u,p,t) ~ 0`; we denote the integrator state as `i.u`.
For compactness, we define `prev_sign = sign(c(u[t-1], p[t-1], t-1))` and `cur_sign = sign(c(u[t], p[t], t))`.
A condition edge will be detected and the callback will be invoked iff `prev_sign * cur_sign <= 0`.
Inter-sample condition activation is not guaranteed; for example if we use the dirac delta function as `c` to insert a
sharp discontinuity between integrator steps (which in this example would not normally be identified by adaptivity) then the condition is not
guaranteed to be triggered.
Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used
is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). Multiple callbacks in the same system with different `rootfind` operations will be resolved
into separate VectorContinuousCallbacks in the enumeration order of `SciMLBase.RootfindOpt`, which may cause some callbacks to not fire if several become
active at the same instant. See the `SciMLBase` documentation for more information on the semantic rules.
The positive edge `affect` will be triggered iff an edge is detected and if `prev_sign < 0`; similarly, `affect_neg` will be
triggered iff an edge is detected `prev_sign > 0`.
Affects (i.e. `affect` and `affect_neg`) can be specified as either:
* A list of equations that should be applied when the callback is triggered (e.g. `x ~ 3, y ~ 7`) which must be of the form `unknown ~ observed value` where each `unknown` appears only once. Equations will be applied in the order that they appear in the vector; parameters and state updates will become immediately visible to following equations.
* A tuple `(f!, unknowns, read_parameters, modified_parameters, ctx)`, where:
+ `f!` is a function with signature `(integ, u, p, ctx)` that is called with the integrator, a state *index* vector `u` derived from `unknowns`, a parameter *index* vector `p` derived from `read_parameters`, and the `ctx` that was given at construction time. Note that `ctx` is aliased between instances.
+ `unknowns` is a vector of symbolic unknown variables and optionally their aliases (e.g. if the model was defined with `@variables x(t)` then a valid value for `unknowns` would be `[x]`). A variable can be aliased with a pair `x => :y`. The indices of these `unknowns` will be passed to `f!` in `u` in a named tuple; in the earlier example, if we pass `[x]` as `unknowns` then `f!` can access `x` as `integ.u[u.x]`. If no alias is specified the name of the index will be the symbol version of the variable name.
+ `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`.
+ `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition.
+ `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem.
"""
struct SymbolicContinuousCallback
eqs::Vector{Equation}
affect::Union{Vector{Equation}, FunctionalAffect}
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT)
new(eqs, make_affect(affect))
affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing}
rootfind::SciMLBase.RootfindOpt
function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT,
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
new(eqs, make_affect(affect), make_affect(affect_neg), rootfind)
end # Default affect to nothing
end
make_affect(affect) = affect
make_affect(affect::Tuple) = FunctionalAffect(affect...)
make_affect(affect::NamedTuple) = FunctionalAffect(; affect...)

function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback)
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)
isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) &&
isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)
end
Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs)
function Base.hash(cb::SymbolicContinuousCallback, s::UInt)
s = foldr(hash, cb.eqs, init = s)
cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s)
s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) :
hash(cb.affect_neg, s)
hash(cb.rootfind, s)
end

to_equation_vector(eq::Equation) = [eq]
Expand All @@ -108,6 +146,14 @@ function SymbolicContinuousCallback(args...)
end # wrap eq in vector
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough
function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT;
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
SymbolicContinuousCallback(eqs=[eqs], affect=affect, affect_neg=affect_neg, rootfind=rootfind)
end
function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT;
affect_neg = affect, rootfind = SciMLBase.LeftRootFind)
SymbolicContinuousCallback(eqs=eqs, affect=affect, affect_neg=affect_neg, rootfind=rootfind)
end

SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
Expand All @@ -130,12 +176,20 @@ function affects(cbs::Vector{SymbolicContinuousCallback})
mapreduce(affects, vcat, cbs, init = Equation[])
end

affect_negs(cb::SymbolicContinuousCallback) = cb.affect_neg
function affect_negs(cbs::Vector{SymbolicContinuousCallback})
mapreduce(affect_negs, vcat, cbs, init = Equation[])
end

namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af]
namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s)
namespace_affects(::Nothing, s) = nothing

function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback
SymbolicContinuousCallback(namespace_equation.(equations(cb), (s,)),
namespace_affects(affects(cb), s))
SymbolicContinuousCallback(
namespace_equation.(equations(cb), (s,)),
namespace_affects(affects(cb), s),
namespace_affects(affect_negs(cb), s))
end

"""
Expand All @@ -159,7 +213,7 @@ function continuous_events(sys::AbstractSystem)
filter(!isempty, cbs)
end

#################################### continuous events #####################################
#################################### discrete events #####################################

struct SymbolicDiscreteCallback
# condition can be one of:
Expand Down Expand Up @@ -462,12 +516,38 @@ function generate_rootfinding_callback(sys::AbstractODESystem, dvs = unknowns(sy
isempty(cbs) && return nothing
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
"""
Generate a single rootfinding callback; this happens if there is only one equation in `cbs` passed to
generate_rootfinding_callback and thus we can produce a ContinuousCallback instead of a VectorContinuousCallback.
"""
function generate_single_rootfinding_callback(
eq, cb, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
if !isequal(eq.lhs, 0)
eq = 0 ~ eq.lhs - eq.rhs
end

rf_oop, rf_ip = generate_custom_function(
sys, [eq.rhs], dvs, ps; expression = Val{false}, kwargs...)
affect_function = compile_affect_fn(cb, sys, dvs, ps, kwargs)
cond = function (u, t, integ)
if DiffEqBase.isinplace(integ.sol.prob)
tmp, = DiffEqBase.get_tmp_cache(integ)
rf_ip(tmp, u, parameter_values(integ), t)
tmp[1]
else
rf_oop(u, parameter_values(integ), t)
end
end
return ContinuousCallback(
cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind)
end

function generate_vector_rootfinding_callback(
cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
# fuse equations to create VectorContinuousCallback
eqs = reduce(vcat, eqs)
# rewrite all equations as 0 ~ interesting stuff
Expand All @@ -477,45 +557,99 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow
end

rhss = map(x -> x.rhs, eqs)
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
_, rf_ip = generate_custom_function(
sys, rhss, dvs, ps; expression = Val{false}, kwargs...)

affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn(
cb,
sys,
dvs,
ps,
kwargs)
for cb in cbs]
cond = function (out, u, t, integ)
rf_ip(out, u, parameter_values(integ), t)
end

rf_oop, rf_ip = generate_custom_function(sys, rhss, dvs, ps; expression = Val{false},
kwargs...)
# since there may be different number of conditions and affects,
# we build a map that translates the condition eq. number to the affect number
eq_ind2affect = reduce(vcat,
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
@assert length(eq_ind2affect) == length(eqs)
@assert maximum(eq_ind2affect) == length(affect_functions)

affect_functions = map(cbs) do cb # Keep affect function separate
eq_aff = affects(cb)
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
affect_functions[eq_ind2affect[eq_ind]].affect(integ)
end
end

if length(eqs) == 1
cond = function (u, t, integ)
if DiffEqBase.isinplace(integ.sol.prob)
tmp, = DiffEqBase.get_tmp_cache(integ)
rf_ip(tmp, u, parameter_values(integ), t)
tmp[1]
else
rf_oop(u, parameter_values(integ), t)
affect_neg = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
affect_neg = affect_functions[eq_ind2affect[eq_ind]].affect_neg
if isnothing(affect_neg)
return # skip if the neg function doesn't exist - don't want to split this into a separate VCC because that'd break ordering
end
affect_neg(integ)
end
ContinuousCallback(cond, affect_functions[])
end
return VectorContinuousCallback(
cond, affect, affect_neg, length(eqs), rootfind = rootfind)
end

"""
Compile a single continuous callback affect function(s).
"""
function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs)
eq_aff = affects(cb)
eq_neg_aff = affect_negs(cb)
affect = compile_affect(eq_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
if eq_neg_aff === eq_aff
affect_neg = affect
elseif isnothing(eq_neg_aff)
affect_neg = nothing
else
cond = function (out, u, t, integ)
rf_ip(out, u, parameter_values(integ), t)
affect_neg = compile_affect(
eq_neg_aff, sys, dvs, ps; expression = Val{false}, kwargs...)
end
(affect = affect, affect_neg = affect_neg)
end

function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys),
ps = full_parameters(sys); kwargs...)
eqs = map(cb -> flatten_equations(cb.eqs), cbs)
num_eqs = length.(eqs)
total_eqs = sum(num_eqs)
(isempty(eqs) || total_eqs == 0) && return nothing
if total_eqs == 1
# find the callback with only one eq
cb_ind = findfirst(>(0), num_eqs)
if isnothing(cb_ind)
error("Inconsistent state in affect compilation; one equation but no callback with equations?")
end
cb = cbs[cb_ind]
return generate_single_rootfinding_callback(cb.eqs[], cb, sys, dvs, ps; kwargs...)
end

# since there may be different number of conditions and affects,
# we build a map that translates the condition eq. number to the affect number
eq_ind2affect = reduce(vcat,
[fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
@assert length(eq_ind2affect) == length(eqs)
@assert maximum(eq_ind2affect) == length(affect_functions)
# group the cbs by what rootfind op they use
# groupby would be very useful here, but alas
cb_classes = Dict{
@NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}()
for cb in cbs
push!(
get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)),
cb)
end

affect = let affect_functions = affect_functions, eq_ind2affect = eq_ind2affect
function (integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
affect_functions[eq_ind2affect[eq_ind]](integ)
end
end
VectorContinuousCallback(cond, affect, length(eqs))
# generate the callbacks out; we sort by the equivalence class to ensure a deterministic preference order
compiled_callbacks = map(collect(pairs(sort!(
OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class)
return generate_vector_rootfinding_callback(
cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...)
end
if length(compiled_callbacks) == 1
return compiled_callbacks[]
else
return CallbackSet(compiled_callbacks...)
end
end

Expand All @@ -529,7 +663,6 @@ function compile_user_affect(affect::FunctionalAffect, sys, dvs, ps; kwargs...)
ps_ind = Dict(reverse(en) for en in enumerate(ps))
p_inds = map(sym -> ps_ind[sym], parameters(affect))
end

# HACK: filter out eliminated symbols. Not clear this is the right thing to do
# (MTK should keep these symbols)
u = filter(x -> !isnothing(x[2]), collect(zip(unknowns_syms(affect), v_inds))) |>
Expand Down
Loading

0 comments on commit 55e8ad6

Please sign in to comment.