Skip to content

Commit

Permalink
Merge pull request #492 from rmsrosa/rode_solver
Browse files Browse the repository at this point in the history
RODE solver RandomHeun()
  • Loading branch information
ChrisRackauckas authored Aug 14, 2022
2 parents 41400cc + cb575ce commit 6b63618
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 7 deletions.
2 changes: 2 additions & 0 deletions src/StochasticDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ using DocStringExtensions
StochasticDiffEqRODECompositeAlgorithm

export RandomEM

export RandomHeun

export IteratedIntegralApprox, IICommutative, IILevyArea

Expand Down
1 change: 1 addition & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ alg_order(alg::IIF1Mil) = 1 // 1
alg_order(alg::EulerHeun) = 1 // 2
alg_order(alg::LambaEulerHeun) = 1 // 2
alg_order(alg::RandomEM) = 1 // 2
alg_order(alg::RandomHeun) = 1 // 2
alg_order(alg::SimplifiedEM) = 1 // 2
alg_order(alg::RKMil) = 1 // 1
alg_order(alg::RKMilCommute) = 1 // 1
Expand Down
2 changes: 2 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,8 @@ end

struct RandomEM <: StochasticDiffEqRODEAlgorithm end

struct RandomHeun <: StochasticDiffEqRODEAlgorithm end

const SplitSDEAlgorithms = Union{IIF1M,IIF2M,IIF1Mil,SKenCarp,SplitEM}

@doc raw"""
Expand Down
17 changes: 17 additions & 0 deletions src/caches/basic_method_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,23 @@ function alg_cache(alg::RandomEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prot
RandomEMCache(u,uprev,tmp,rtmp)
end

struct RandomHeunConstantCache <: StochasticDiffEqConstantCache end
@cache struct RandomHeunCache{uType,rateType,randType} <: StochasticDiffEqMutableCache
u::uType
uprev::uType
tmp::uType
rtmp1::rateType
rtmp2::rateType
wtmp::randType
end

alg_cache(alg::RandomHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{false}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits} = RandomHeunConstantCache()

function alg_cache(alg::RandomHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,jump_rate_prototype,::Type{uEltypeNoUnits},::Type{uBottomEltypeNoUnits},::Type{tTypeNoUnits},uprev,f,t,dt,::Type{Val{true}}) where {uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits}
tmp = zero(u); rtmp1 = zero(rate_prototype); rtmp2 = zero(rate_prototype); wtmp = zero(ΔW)
RandomHeunCache(u,uprev,tmp,rtmp1,rtmp2,wtmp)
end

struct SimplifiedEMConstantCache <: StochasticDiffEqConstantCache end
@cache struct SimplifiedEMCache{randType,uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
u::uType
Expand Down
23 changes: 23 additions & 0 deletions src/perform_step/low_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,29 @@ end
@.. u = uprev + dt * rtmp
end

@muladd function perform_step!(integrator,cache::RandomHeunConstantCache)
@unpack t,dt,uprev,u,W,p,f = integrator
ftmp = integrator.f(uprev,p,t,W.curW)
tmp = @.. uprev + dt * ftmp
wtmp = @.. W.curW + W.dW
u = uprev .+ (dt/2) .* (ftmp .+ integrator.f(tmp,p,t+dt, wtmp))
integrator.u = u
end

@muladd function perform_step!(integrator,cache::RandomHeunCache)
@unpack tmp, rtmp1, rtmp2, wtmp = cache
@unpack t,dt,uprev,u,W,p,f = integrator
integrator.f(rtmp1,uprev,p,t,W.curW)
@.. tmp = uprev + dt * rtmp1
if W.dW isa Number
wtmp = W.curW + W.dW
else
@.. wtmp = W.curW + W.dW
end
integrator.f(rtmp2,tmp,p,t+dt,wtmp)
@.. u = uprev + (dt/2) * (rtmp1 + rtmp2)
end

# weak approximation EM
@muladd function perform_step!(integrator,cache::SimplifiedEMConstantCache)
@unpack t,dt,uprev,u,W,p,f = integrator
Expand Down
31 changes: 24 additions & 7 deletions test/rode_linear_tests.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,44 @@
using StochasticDiffEq
using StochasticDiffEq, DiffEqNoiseProcess, Random, Test
Random.seed!(100)

f(u,p,t,W) = 1.01u.+0.87u.*W
u0 = 1.00
tspan = (0.0,1.0)
prob = RODEProblem(f,u0,tspan)
sol = solve(prob,RandomEM(),dt=1/100)
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
sol2 = solve(prob,RandomHeun(),dt=1/100)
@test abs(sol[end] - sol2[end]) < 0.1


f(du,u,p,t,W) = (du.=1.01u.+0.87u.*W)
u0 = ones(4)
prob = RODEProblem(f,u0,tspan)
sol = solve(prob,RandomEM(),dt=1/100)
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
sol2 = solve(prob,RandomHeun(),dt=1/100)
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])

f(u,p,t,W) = 2u*sin(W)
u0 = 1.00
tspan = (0.0,5.0)
prob = RODEProblem{false}(f,u0,tspan)
sol = solve(prob,RandomEM(),dt=1/100)
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
prob = RODEProblem{false}(f,u0,tspan, noise=NoiseWrapper(sol.W))
sol2 = solve(prob,RandomHeun(),dt=1/100)
@test abs(sol[end]-sol2[end]) < 0.1

function f(du,u,p,t,W)
du[1] = 2u[1]*sin(W[1] - W[2])
du[2] = -2u[2]*cos(W[1] + W[2])
end
u0 = [1.00;1.00]
tspan = (0.0,5.0)
tspan = (0.0,4.0)
prob = RODEProblem(f,u0,tspan)
sol = solve(prob,RandomEM(),dt=1/100)
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
sol2 = solve(prob,RandomHeun(),dt=1/100)
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])

function f(du,u,p,t,W)
du[1] = -2W[3]*u[1]*sin(W[1] - W[2])
Expand All @@ -33,4 +47,7 @@ end
u0 = [1.00;1.00]
tspan = (0.0,5.0)
prob = RODEProblem(f,u0,tspan,rand_prototype=zeros(3))
sol = solve(prob,RandomEM(),dt=1/100)
sol = solve(prob,RandomEM(),dt=1/100, save_noise=true)
prob = RODEProblem(f,u0,tspan, noise=NoiseWrapper(sol.W))
sol2 = solve(prob,RandomHeun(),dt=1/100)
@test sum(abs,sol[end]-sol2[end]) < 0.1 * sum(abs, sol[end])

0 comments on commit 6b63618

Please sign in to comment.