Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add optional caching #247

Merged
merged 9 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions examples/chemistry/brusselator_implicit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
using ACSets
using CombinatorialSpaces
using DiagrammaticEquations
using Decapodes
using MLStyle
using OrdinaryDiffEq
using Symbolics
using Krylov
using LinearSolve
using LinearAlgebra
using CairoMakie
using ComponentArrays
using GeometryBasics: Point2, Point3
Point2D = Point2{Float64}
Point3D = Point3{Float64}

using Logging: global_logger
using TerminalLoggers: TerminalLogger
global_logger(TerminalLogger())

Brusselator = @decapode begin
(U, V)::Form0
U2V::Form0
(U̇, V̇)::Form0

(α)::Constant
F::Parameter

U2V == U .^ 2 .* V

U̇ == 1 + U2V - (4.4 * U) + (α * Δ(U)) + F
V̇ == (3.4 * U) - U2V + (α * Δ(V))
∂ₜ(U) == U̇
∂ₜ(V) == V̇
end

x = 1
dx = 0.005
s = triangulated_grid(x,x,dx,dx,Point3D);
sd = EmbeddedDeltaDualComplex2D{Bool,Float64,Point2D}(s);
subdivide_duals!(sd, Circumcenter());

U = map(sd[:point]) do (_,y)
22 * (y * (1-y))^(3/2)
end

V = map(sd[:point]) do (x,_)
27 * (x * (1-x))^(3/2)
end

F₁ = map(sd[:point]) do (x,y)
(x-0.3)^2 + (y-0.6)^2 ≤ (0.1)^2 ? 5.0 : 0.0
end

F₂ = zeros(nv(sd))

constants_and_parameters = (
α = 0.001,
F = t -> t ≥ 1.1 ? F₁ : F₂)

sim = evalsim(Brusselator, can_prealloc=false)
fₘ = sim(sd, nothing, DiagonalHodge())

u₀ = ComponentArray(U=U, V=V)

du₀ = copy(u₀)
jac_sparsity = Symbolics.jacobian_sparsity((du, u) -> fₘ(du, u, constants_and_parameters, 0.0), du₀, u₀)
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved

f = ODEFunction(fₘ; jac_prototype = float.(jac_sparsity))
# tₑ = 11.5
tₑ = 20
@info("Solving")
prob = ODEProblem(f, u₀, (0, tₑ), constants_and_parameters)
soln = solve(prob, FBDF(linsolve = KLUFactorization()), progress=true, progress_steps=1, saveat=0.1);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leave a comment if you expect this to take longer than other solvers and that this is just for demonstration purposes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've shrunk it down a bit and it only takes 22 seconds to solve.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the relative time to the other solver

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quite a bit longer, but I want to showcase how to use implicit methods so I've tagged this as demonstration code.

@info("Done")
soln.stats

function save_dynamics(save_file_name)
time = Observable(0.0)
u = @lift(soln($time).U)
f = Figure()
ax = CairoMakie.Axis(f[1,1], title = @lift("Brusselator U Concentration at Time $($time)"))
gmsh = mesh!(ax, s, color=u, colormap=:jet, colorrange=extrema(soln(0.0).U))
Colorbar(f[1,2], gmsh)
timestamps = range(0, tₑ, step=1e-1)
record(f, save_file_name, timestamps; framerate = 30) do t
time[] = t
end
end
save_dynamics("brusselator_U.gif")
75 changes: 35 additions & 40 deletions src/simulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ end
# ! WARNING: Do not pass this an inplace function without setting equality to :.=
Base.Expr(c::UnaryCall) = begin
operator = c.operator
if(c.equality == :.=)
if c.equality == :.=
# TODO: Generalize to inplacable functions
if(operator == add_inplace_stub(:⋆₁⁻¹)) # Since inverse hodge Geo is a solver
if operator == add_inplace_stub(:⋆₁⁻¹) # Since inverse hodge Geo is a solver
Expr(:call, c.operator, c.output, c.input)
elseif(operator == :.-)
elseif operator == :.-
Expr(c.equality, c.output, Expr(:call, operator, c.input))
else # TODO: Add check that this operator is a matrix
Expr(:call, :mul!, c.output, operator, c.input)
Expand All @@ -60,7 +60,7 @@ end
# ! WARNING: Do not pass this an inplace function without setting equality to :.=, vice versa
Base.Expr(c::BinaryCall) = begin
# These operators can be done in-place
if(c.equality == :.= && get_stub(c.operator) == GENSIM_INPLACE_STUB)
if c.equality == :.= && get_stub(c.operator) == GENSIM_INPLACE_STUB
return Expr(:call, c.operator, c.output, c.input1, c.input2)
end
return Expr(c.equality, c.output, Expr(:call, c.operator, c.input1, c.input2))
Expand Down Expand Up @@ -140,11 +140,11 @@ end

#= function get_form_number(d::SummationDecapode, var_id::Int)
type = d[var_id, :type]
if(type == :Form0)
if type == :Form0
return 0
elseif(type == :Form1)
elseif type == :Form1
return 1
elseif(type == :Form2)
elseif type == :Form2
return 2
end
return -1
Expand Down Expand Up @@ -200,11 +200,11 @@ function get_stub(var_name::Symbol)
var_str = String(var_name)
idx = findfirst("_", var_str)

if(isnothing(idx))
if isnothing(idx)
return NO_STUB_RETURN
end

if(first(idx) == 1)
if first(idx) == 1
throw(InvalidStubException(var_name))
end

Expand Down Expand Up @@ -376,7 +376,7 @@ Function that compiles the computation body. `d` is the input Decapode, `inputs`
in-place methods, `dimension` is the dimension of the problem (usually 1 or 2), `stateeltype` is the type of the state elements
(usually Float32 or Float64) and `code_target` determines what architecture the code is compiled for (either CPU or CUDA).
"""
function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget)
function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Vector{AllocVecCall}, optimizable_dec_operators::Set{Symbol}, dimension::Int, stateeltype::DataType, code_target::AbstractGenerationTarget, can_prealloc::Bool)
# Get the Vars of the inputs (probably state Vars).
visited_Var = falses(nparts(d, :Var))

Expand Down Expand Up @@ -412,19 +412,15 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve
tname = d[t, :name]

# TODO: Check to see if this is a DEC operator
if(operator in optimizable_dec_operators)
# push!(dec_matrices, operator)
if(is_form(d, t))
if can_prealloc && is_form(d, t)
if operator in optimizable_dec_operators
equality = PROMOTE_ARITHMETIC_MAP[equality]
operator = add_stub(GENSIM_INPLACE_STUB, operator)

push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target))
end
elseif(operator == :(-) || operator == :.-)
if(is_form(d, t))

elseif operator == :(-) || operator == :.-
equality = PROMOTE_ARITHMETIC_MAP[equality]
operator = PROMOTE_ARITHMETIC_MAP[operator]

push!(alloc_vectors, AllocVecCall(tname, d[t, :type], dimension, stateeltype, code_target))
end
end
Expand All @@ -448,39 +444,39 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve
equality = :(=)

# TODO: Check to make sure that this logic never breaks
if(is_form(d, r))
if(operator == :(+) || operator == :(-) || operator == :.+ || operator == :.-)
if can_prealloc && is_form(d, r)
if operator == :(+) || operator == :(-) || operator == :.+ || operator == :.-
operator = PROMOTE_ARITHMETIC_MAP[operator]
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))

# TODO: Do we want to support the ability of a user to use the backslash operator?
elseif(operator == :(*) || operator == :(/) || operator == :.* || operator == :./)
elseif operator == :(*) || operator == :(/) || operator == :.* || operator == :./
# ! WARNING: This part may break if we add more compiler types that have different
# ! operations for basic and broadcast modes, e.g. matrix multiplication vs broadcast
if(!is_infer(d, arg1) && !is_infer(d, arg2))
if !is_infer(d, arg1) && !is_infer(d, arg2)
operator = PROMOTE_ARITHMETIC_MAP[operator]
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
end
elseif(operator in optimizable_dec_operators)
elseif operator in optimizable_dec_operators
operator = add_stub(GENSIM_INPLACE_STUB, operator)
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
end
end

# TODO: Clean this in another PR (with a @match maybe).
if(operator == :(*))
if operator == :(*)
operator = PROMOTE_ARITHMETIC_MAP[operator]
end
if(operator == :(-))
if operator == :(-)
operator = PROMOTE_ARITHMETIC_MAP[operator]
end
if(operator == :(/))
if operator == :(/)
operator = PROMOTE_ARITHMETIC_MAP[operator]
end
if(operator == :(^))
if operator == :(^)
operator = PROMOTE_ARITHMETIC_MAP[operator]
end

Expand All @@ -503,7 +499,7 @@ function compile(d::SummationDecapode, inputs::Vector{Symbol}, alloc_vectors::Ve
equality = :(=)

# If result is a known form, broadcast addition
if(is_form(d, r))
if can_prealloc && is_form(d, r)
operator = PROMOTE_ARITHMETIC_MAP[operator]
equality = PROMOTE_ARITHMETIC_MAP[equality]
push!(alloc_vectors, AllocVecCall(rname, d[r, :type], dimension, stateeltype, code_target))
Expand Down Expand Up @@ -563,7 +559,7 @@ to the compiler.
"""
function resolve_types_compiler!(d::SummationDecapode)
d[:type] = map(d[:type]) do x
if(x == :Constant || x == :Parameter)
if x == :Constant || x == :Parameter
return :infer
end
return x
Expand Down Expand Up @@ -604,7 +600,7 @@ Collects all DEC operators that are concrete matrices.
"""
function init_dec_matrices!(d::SummationDecapode, dec_matrices::Vector{Symbol}, optimizable_dec_operators::Set{Symbol})
for op_name in vcat(d[:op1], d[:op2])
if(op_name in optimizable_dec_operators)
if op_name in optimizable_dec_operators
push!(dec_matrices, op_name)
end
end
Expand All @@ -629,7 +625,7 @@ function link_contract_operators(d::SummationDecapode, con_dec_operators::Set{Sy
compute_key = join(computation, " * ")

computation_name = get(compute_to_name, compute_key, :Error)
if(computation_name == :Error)
if computation_name == :Error
computation_name = add_stub(Symbol("GenSim-ConMat"), Symbol(curr_id))
get!(compute_to_name, compute_key, computation_name)
push!(con_dec_operators, computation_name)
Expand All @@ -656,7 +652,7 @@ function hook_LCO_inplace(computation_name::Symbol, computation::Vector{Symbol},
end

function generate_parentheses_multiply(list)
if(length(list) == 1)
if length(list) == 1
return list[1]
else
return Expr(:call, :*, generate_parentheses_multiply(list[1:end-1]), list[end])
Expand Down Expand Up @@ -688,7 +684,7 @@ state variables and literals in the Decapode.
Optional keyword arguments are `dimension`, which is the dimension of the problem and defaults to 2D, `stateeltype`, which is the element type of the state forms and
defaults to Float64 and `code_target`, which is the intended architecture target for the generated code, defaulting to regular CPU compatible code.
"""
function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget())
function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget(), can_prealloc::Bool = true)
GeorgeR227 marked this conversation as resolved.
Show resolved Hide resolved

(1 <= dimension <= 2) ||
throw(UnsupportedDimensionException(dimension))
Expand Down Expand Up @@ -736,7 +732,7 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
union!(optimizable_dec_operators, contracted_dec_operators, extra_dec_operators)

# Compilation of the simulation
equations = compile(gen_d, input_vars, alloc_vectors, optimizable_dec_operators, dimension, stateeltype, code_target)
equations = compile(gen_d, input_vars, alloc_vectors, optimizable_dec_operators, dimension, stateeltype, code_target, can_prealloc)
data = post_process_vector_allocs(alloc_vectors, code_target)

func_defs = compile_env(gen_d, dec_matrices, contracted_dec_operators, code_target)
Expand All @@ -758,20 +754,19 @@ function gensim(user_d::SummationDecapode, input_vars::Vector{Symbol}; dimension
end
end

gather_inputs(d::SummationDecapode) = vcat(infer_state_names(d), d[incident(d, :Literal, :type), :name])

gensim(c::Collage; dimension::Int=2) =
gensim(collate(c); dimension=dimension)

""" function gensim(d::SummationDecapode; dimension::Int=2)

Generate a simulation function from the given Decapode. The returned function can then be combined with a mesh and a function describing function mappings to return a simulator to be passed to `solve`.
"""
gensim(d::SummationDecapode; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget()) =
gensim(d, vcat(infer_state_names(d), d[incident(d, :Literal, :type), :name]), dimension=dimension, stateeltype=stateeltype, code_target=code_target)
gensim(d::SummationDecapode; kwargs...) =
gensim(d, gather_inputs(d); kwargs...)

evalsim(d::SummationDecapode; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget()) =
eval(gensim(d, dimension=dimension, stateeltype=stateeltype, code_target=code_target))
evalsim(d::SummationDecapode, input_vars::Vector{Symbol}; dimension::Int=2, stateeltype::DataType = Float64, code_target::AbstractGenerationTarget = CPUTarget()) =
eval(gensim(d, input_vars, dimension=dimension, stateeltype=stateeltype, code_target=code_target))
evalsim(args...; kwargs...) = eval(gensim(args...; kwargs...))

"""
function find_unreachable_tvars(d)
Expand Down
Loading
Loading