Skip to content

Commit

Permalink
Merge pull request #2240 from SciML/extrapolation
Browse files Browse the repository at this point in the history
Move extrapolation methods to an add-on library
  • Loading branch information
ChrisRackauckas authored Jun 6, 2024
2 parents 5afe786 + 4c83013 commit 8db5efe
Show file tree
Hide file tree
Showing 23 changed files with 868 additions and 768 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
- AlgConvergence_II
- AlgConvergence_III
- Downstream
- Extrapolation
- ODEInterfaceRegression
- Multithreading
version:
Expand Down
24 changes: 24 additions & 0 deletions lib/OrdinaryDiffEqExtrapolation/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name = "OrdinaryDiffEqExtrapolation"
uuid = "becaefa8-8ca2-5cf9-886d-c06f3d2bd2c4"
authors = ["Chris Rackauckas <[email protected]>", "Yingbo Ma <[email protected]>"]
version = "1.0.0"

[deps]
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"

[compat]
julia = "1.10"

[extras]
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DiffEqDevTools", "Random", "SafeTestsets", "Test"]
57 changes: 57 additions & 0 deletions lib/OrdinaryDiffEqExtrapolation/src/OrdinaryDiffEqExtrapolation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
module OrdinaryDiffEqExtrapolation

import OrdinaryDiffEq: alg_order, alg_maximum_order, get_current_adaptive_order,
get_current_alg_order, calculate_residuals!, accept_step_controller,
default_controller, beta2_default, beta1_default, gamma_default,
initialize!, perform_step!, @unpack, unwrap_alg, isthreaded,
step_accept_controller!, calculate_residuals,
OrdinaryDiffEqMutableCache, OrdinaryDiffEqConstantCache,
reset_alg_dependent_opts!, AbstractController,
step_accept_controller!, step_reject_controller!,
OrdinaryDiffEqAdaptiveAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm,
alg_cache, CompiledFloats, @threaded, stepsize_controller!, DEFAULT_PRECS,
constvalue, PolyesterThreads, Sequential, BaseThreads,
_digest_beta1_beta2, timedepentdtmin, _unwrap_val,
TimeDerivativeWrapper, UDerivativeWrapper, calc_J, _reshape, _vec,
WOperator, TimeGradientWrapper, UJacobianWrapper, build_grad_config,
build_jac_config, calc_J!, jacobian2W!, dolinsolve
using DiffEqBase, FastBroadcast, Polyester, MuladdMacro, RecursiveArrayTools, LinearSolve

macro cache(expr)
name = expr.args[2].args[1].args[1]
fields = [x for x in expr.args[3].args if typeof(x) != LineNumberNode]
cache_vars = Expr[]
jac_vars = Pair{Symbol, Expr}[]
for x in fields
if x.args[2] == :uType || x.args[2] == :rateType ||
x.args[2] == :kType || x.args[2] == :uNoUnitsType
push!(cache_vars, :(c.$(x.args[1])))
elseif x.args[2] == :DiffCacheType
push!(cache_vars, :(c.$(x.args[1]).du))
push!(cache_vars, :(c.$(x.args[1]).dual_du))
end
end
quote
$(esc(expr))
$(esc(:full_cache))(c::$name) = tuple($(cache_vars...))
end
end

include("algorithms.jl")
include("alg_utils.jl")
include("controllers.jl")
include("extrapolation_caches.jl")
include("extrapolation_perform_step.jl")

@inline function DiffEqBase.get_tmp_cache(integrator,
alg::OrdinaryDiffEqImplicitExtrapolationAlgorithm,
cache::OrdinaryDiffEqMutableCache)
(cache.tmp, cache.utilde)
end

export AitkenNeville, ExtrapolationMidpointDeuflhard, ExtrapolationMidpointHairerWanner,
ImplicitEulerExtrapolation,
ImplicitDeuflhardExtrapolation, ImplicitHairerWannerExtrapolation,
ImplicitEulerBarycentricExtrapolation

end
81 changes: 81 additions & 0 deletions lib/OrdinaryDiffEqExtrapolation/src/alg_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
alg_order(alg::AitkenNeville) = alg.init_order
alg_maximum_order(alg::ExtrapolationMidpointDeuflhard) = 2(alg.max_order + 1)

function get_current_adaptive_order(
alg::OrdinaryDiffEqExtrapolationVarOrderVarStepAlgorithm,
cache)
cache.cur_order
end
function get_current_alg_order(alg::OrdinaryDiffEqExtrapolationVarOrderVarStepAlgorithm,
cache)
cache.cur_order
end
get_current_alg_order(alg::ExtrapolationMidpointDeuflhard, cache) = 2(cache.n_curr + 1)
get_current_alg_order(alg::ImplicitDeuflhardExtrapolation, cache) = 2(cache.n_curr + 1)
get_current_adaptive_order(alg::ExtrapolationMidpointDeuflhard, cache) = 2cache.n_curr
get_current_adaptive_order(alg::ImplicitDeuflhardExtrapolation, cache) = 2cache.n_curr
get_current_alg_order(alg::ExtrapolationMidpointHairerWanner, cache) = 2(cache.n_curr + 1)
get_current_alg_order(alg::ImplicitHairerWannerExtrapolation, cache) = 2(cache.n_curr + 1)
get_current_alg_order(alg::ImplicitEulerBarycentricExtrapolation, cache) = cache.n_curr
get_current_alg_order(alg::ImplicitEulerExtrapolation, cache) = cache.n_curr + 1
get_current_adaptive_order(alg::ExtrapolationMidpointHairerWanner, cache) = 2cache.n_curr
get_current_adaptive_order(alg::ImplicitHairerWannerExtrapolation, cache) = 2cache.n_curr
get_current_adaptive_order(alg::ImplicitEulerExtrapolation, cache) = cache.n_curr - 1
function get_current_adaptive_order(
alg::ImplicitEulerBarycentricExtrapolation, cache)
cache.n_curr - 2
end

alg_maximum_order(alg::ImplicitDeuflhardExtrapolation) = 2(alg.max_order + 1)
alg_maximum_order(alg::ExtrapolationMidpointHairerWanner) = 2(alg.max_order + 1)
alg_maximum_order(alg::ImplicitHairerWannerExtrapolation) = 2(alg.max_order + 1)
alg_maximum_order(alg::ImplicitEulerExtrapolation) = 2(alg.max_order + 1)
alg_maximum_order(alg::ImplicitEulerBarycentricExtrapolation) = alg.max_order

function default_controller(
alg::Union{ExtrapolationMidpointDeuflhard,
ImplicitDeuflhardExtrapolation,
ExtrapolationMidpointHairerWanner,
ImplicitHairerWannerExtrapolation,
ImplicitEulerExtrapolation,
ImplicitEulerBarycentricExtrapolation},
cache,
qoldinit, _beta1 = nothing, _beta2 = nothing)
QT = typeof(qoldinit)
beta1, beta2 = _digest_beta1_beta2(alg, cache, Val(QT), _beta1, _beta2)
return ExtrapolationController(beta1)
end

beta2_default(alg::ExtrapolationMidpointDeuflhard) = 0 // 1
beta2_default(alg::ImplicitDeuflhardExtrapolation) = 0 // 1
beta2_default(alg::ExtrapolationMidpointHairerWanner) = 0 // 1
beta2_default(alg::ImplicitHairerWannerExtrapolation) = 0 // 1
beta2_default(alg::ImplicitEulerExtrapolation) = 0 // 1
beta2_default(alg::ImplicitEulerBarycentricExtrapolation) = 0 // 1

beta1_default(alg::ExtrapolationMidpointDeuflhard, beta2) = 1 // (2alg.init_order + 1)
beta1_default(alg::ImplicitDeuflhardExtrapolation, beta2) = 1 // (2alg.init_order + 1)
beta1_default(alg::ExtrapolationMidpointHairerWanner, beta2) = 1 // (2alg.init_order + 1)
beta1_default(alg::ImplicitHairerWannerExtrapolation, beta2) = 1 // (2alg.init_order + 1)
beta1_default(alg::ImplicitEulerExtrapolation, beta2) = 1 // (alg.init_order + 1)
beta1_default(alg::ImplicitEulerBarycentricExtrapolation, beta2) = 1 // (alg.init_order - 1)

function gamma_default(alg::ExtrapolationMidpointDeuflhard)
(1 // 4)^beta1_default(alg, beta2_default(alg))
end
function gamma_default(alg::ImplicitDeuflhardExtrapolation)
(1 // 4)^beta1_default(alg, beta2_default(alg))
end
function gamma_default(alg::ExtrapolationMidpointHairerWanner)
(65 // 100)^beta1_default(alg, beta2_default(alg))
end
function gamma_default(alg::ImplicitHairerWannerExtrapolation)
(65 // 100)^beta1_default(alg, beta2_default(alg))
end
function gamma_default(alg::ImplicitEulerExtrapolation)
(65 // 100)^beta1_default(alg, beta2_default(alg))
end

function gamma_default(alg::ImplicitEulerBarycentricExtrapolation)
(80 // 100)^beta1_default(alg, beta2_default(alg))
end
Loading

0 comments on commit 8db5efe

Please sign in to comment.