diff --git a/Project.toml b/Project.toml index 998e974eb..71eb74c64 100644 --- a/Project.toml +++ b/Project.toml @@ -34,9 +34,17 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [extensions] DistributionsExt = "Distributions" +MeasurementsExt = "Measurements" +MonteCarloMeasurementsExt = "MonteCarloMeasurements" +GeneralizedGeneratedExt = "GeneralizedGenerated" +UnitfulExt = "Unitful" [compat] ArrayInterfaceCore = "0.1.26" @@ -67,14 +75,18 @@ julia = "1.6" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" +Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" +MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Distributed", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "SafeTestsets", "Statistics", "Test", "Distributions"] +test = ["Distributed", "GeneralizedGenerated", "Measurements", "MonteCarloMeasurements", "Unitful", "LabelledArrays", "ForwardDiff", "InteractiveUtils", "Plots", "Pkg", "Random", "SafeTestsets", "Statistics", "Test", "Distributions"] diff --git a/ext/GeneralizedGeneratedExt.jl b/ext/GeneralizedGeneratedExt.jl new file mode 100644 index 000000000..472a40f49 --- /dev/null +++ b/ext/GeneralizedGeneratedExt.jl @@ -0,0 +1,10 @@ +module GeneralizedGeneratedExt + +using DiffEqBase +isdefined(Base, :get_extension) ? (using GeneralizedGenerated) : (using ..GeneralizedGenerated) + +function SciMLBase.numargs(::GeneralizedGenerated.RuntimeFn{Args}) where {Args} + GeneralizedGenerated.from_type(Args) |> length +end + +end diff --git a/ext/MeasurementsExt.jl b/ext/MeasurementsExt.jl new file mode 100644 index 000000000..ce35a23aa --- /dev/null +++ b/ext/MeasurementsExt.jl @@ -0,0 +1,34 @@ +module MeasurementsExt + +using DiffEqBase +import DiffEqBase: value +isdefined(Base, :get_extension) ? (using Measurements) : (using ..Measurements) + +function DiffEqBase.promote_u0(u0::AbstractArray{<:Measurements.Measurement}, + p::AbstractArray{<:Measurements.Measurement}, t0) + u0 +end +DiffEqBase.promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = eltype(p).(u0) + +value(x::Type{Measurements.Measurement{T}}) where {T} = T +value(x::Measurements.Measurement) = Measurements.value(x) + +@inline DiffEqBase.fastpow(x::Measurements.Measurement, y::Measurements.Measurement) = x^y + +# Support adaptive steps should be errorless +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Measurements.Measurement, N + }, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Measurements.Measurement, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Measurements.Measurement, t) + abs(Measurements.value(u)) +end + +end diff --git a/ext/MonteCarloMeasurementsExt.jl b/ext/MonteCarloMeasurementsExt.jl new file mode 100644 index 000000000..a6787f935 --- /dev/null +++ b/ext/MonteCarloMeasurementsExt.jl @@ -0,0 +1,47 @@ +module MonteCarloMeasurementsExt + +using DiffEqBase +import DiffEqBase: value +isdefined(Base, :get_extension) ? (using MonteCarloMeasurements) : (using ..MonteCarloMeasurements) + +function DiffEqBase.promote_u0(u0::AbstractArray{<:MonteCarloMeasurements.AbstractParticles + }, + p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, + t0) + u0 +end +function DiffEqBase.promote_u0(u0, + p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, + t0) + eltype(p).(u0) +end + +DiffEqBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T +DiffEqBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) + +@inline function DiffEqBase.fastpow(x::MonteCarloMeasurements.AbstractParticles, + y::MonteCarloMeasurements.AbstractParticles) + x^y +end + +# Support adaptive steps should be errorless +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{ + <:MonteCarloMeasurements.AbstractParticles, + N}, t) where {N} + sqrt(mean(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((value(x) for x in u), Iterators.repeated(t)))) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{ + <:MonteCarloMeasurements.AbstractParticles, + N}, + t::AbstractArray{ + <:MonteCarloMeasurements.AbstractParticles, + N}) where {N} + sqrt(mean(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((value(x) for x in u), Iterators.repeated(value.(t))))) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::MonteCarloMeasurements.AbstractParticles, t) + abs(value(u)) +end + +end diff --git a/ext/UnitfulExt.jl b/ext/UnitfulExt.jl new file mode 100644 index 000000000..6cf69719e --- /dev/null +++ b/ext/UnitfulExt.jl @@ -0,0 +1,26 @@ +module UnitfulExt + +using DiffEqBase +import DiffEqBase: value +isdefined(Base, :get_extension) ? (using Unitful) : (using ..Unitful) + +# Support adaptive errors should be errorless for exponentiation +value(x::Type{Unitful.AbstractQuantity{T, D, U}}) where {T, D, U} = T +value(x::Unitful.AbstractQuantity) = x.val +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Unitful.AbstractQuantity, N + }, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Unitful.AbstractQuantity, N}, + t) where {N} + sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), + zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) +end +@inline DiffEqBase.ODE_DEFAULT_NORM(u::Unitful.AbstractQuantity, t) = abs(value(u)) +@inline function DiffEqBase.UNITLESS_ABS2(x::Unitful.AbstractQuantity) + real(abs2(x) / oneunit(x) * oneunit(x)) +end + +end diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index 590fb8a75..7d5b0775e 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -1,6 +1,10 @@ module DiffEqBase -using Requires, ArrayInterfaceCore +if !isdefined(Base, :get_extension) + using Requires +end + +using ArrayInterfaceCore using StaticArrays # data arrays diff --git a/src/init.jl b/src/init.jl index 4fa6c75d9..a4bf2f501 100644 --- a/src/init.jl +++ b/src/init.jl @@ -15,100 +15,21 @@ function SciMLBase.tmap(args...) end function __init__() - @require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin - function promote_u0(u0::AbstractArray{<:Measurements.Measurement}, - p::AbstractArray{<:Measurements.Measurement}, t0) - u0 + @static if !isdefined(Base, :get_extension) + @require Measurements="eff96d63-e80a-5855-80a2-b1b0885c5ab7" begin + include("../ext/MeasurementsExt.jl") end - promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = eltype(p).(u0) - value(x::Type{Measurements.Measurement{T}}) where {T} = T - value(x::Measurements.Measurement) = Measurements.value(x) - - @inline fastpow(x::Measurements.Measurement, y::Measurements.Measurement) = x^y - - # Support adaptive steps should be errorless - @inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Measurements.Measurement, N}, - t) where {N} - sqrt(sum(x -> ODE_DEFAULT_NORM(x[1], x[2]), - zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) - end - @inline function ODE_DEFAULT_NORM(u::Array{<:Measurements.Measurement, N}, - t) where {N} - sqrt(sum(x -> ODE_DEFAULT_NORM(x[1], x[2]), - zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) + @require MonteCarloMeasurements="0987c9cc-fe09-11e8-30f0-b96dd679fdca" begin + include("../ext/MonteCarloMeasurementsExt.jl") end - @inline function ODE_DEFAULT_NORM(u::Measurements.Measurement, t) - abs(Measurements.value(u)) - end - end - - @require MonteCarloMeasurements="0987c9cc-fe09-11e8-30f0-b96dd679fdca" begin - function promote_u0(u0::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) - u0 - end - function promote_u0(u0, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) - eltype(p).(u0) - end - - value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where {T, N} = T - value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) - @inline function fastpow(x::MonteCarloMeasurements.AbstractParticles, - y::MonteCarloMeasurements.AbstractParticles) - x^y + @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" begin + include("../ext/UnitfulExt.jl") end - # Support adaptive steps should be errorless - @inline function ODE_DEFAULT_NORM(u::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - N}, t) where {N} - sqrt(mean(x -> ODE_DEFAULT_NORM(x[1], x[2]), - zip((value(x) for x in u), Iterators.repeated(t)))) - end - @inline function ODE_DEFAULT_NORM(u::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - N}, - t::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - N}) where {N} - sqrt(mean(x -> ODE_DEFAULT_NORM(x[1], x[2]), - zip((value(x) for x in u), Iterators.repeated(value.(t))))) - end - @inline function ODE_DEFAULT_NORM(u::MonteCarloMeasurements.AbstractParticles, t) - abs(value(u)) - end - end - - @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" begin - # Support adaptive errors should be errorless for exponentiation - value(x::Type{Unitful.AbstractQuantity{T, D, U}}) where {T, D, U} = T - value(x::Unitful.AbstractQuantity) = x.val - @inline function ODE_DEFAULT_NORM(u::AbstractArray{<:Unitful.AbstractQuantity, N}, - t) where {N} - sqrt(sum(x -> ODE_DEFAULT_NORM(x[1], x[2]), - zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) - end - @inline function ODE_DEFAULT_NORM(u::Array{<:Unitful.AbstractQuantity, N}, - t) where {N} - sqrt(sum(x -> ODE_DEFAULT_NORM(x[1], x[2]), - zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) - end - @inline ODE_DEFAULT_NORM(u::Unitful.AbstractQuantity, t) = abs(value(u)) - @inline function UNITLESS_ABS2(x::Unitful.AbstractQuantity) - real(abs2(x) / oneunit(x) * oneunit(x)) + @require GeneralizedGenerated="6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" begin + include("../ext/GeneralizedGeneratedExt.jl") end end - - @require GeneralizedGenerated="6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb" begin function SciMLBase.numargs(::GeneralizedGenerated.RuntimeFn{ - Args - }) where { - Args - } - GeneralizedGenerated.from_type(Args) |> length - end end end