Using StochasticDiffEq with custom types #576

Closed
apkille opened this issue Aug 1, 2024 · 0 comments · Fixed by #579
apkille commented Aug 1, 2024

Describe the bug 🐞

An SDEProblem cannot be solved with a custom array type (say CustomArray) that is not a subtype of AbstractArray, because DiffEqBase.__init divides one CustomArray by another and therefore requires a method Base.:(/)(x::CustomArray, y::CustomArray).

Minimal Reproducible Example 👇

I am modifying the test example https://github.com/SciML/StochasticDiffEq.jl/blob/master/test/noindex_tests.jl so that the custom type is not a subtype of AbstractArray, which requires a few more methods to be defined:

using StochasticDiffEq, OrdinaryDiffEq, LinearAlgebra, RecursiveArrayTools

struct CustomArray{T, N}
    x::Array{T, N}
end
Base.size(x::CustomArray) = size(x.x)
Base.axes(x::CustomArray) = axes(x.x)
Base.ndims(x::CustomArray) = ndims(x.x)
Base.ndims(::Type{<:CustomArray{T,N}}) where {T,N} = N
Base.zero(x::CustomArray) = CustomArray(zero(x.x))
Base.zero(::Type{<:CustomArray{T,N}}) where {T,N} = CustomArray(zero(Array{T,N}))
Base.similar(x::CustomArray, dims::Union{Integer, AbstractUnitRange}...) = CustomArray(similar(x.x, dims...))
Base.copyto!(x::CustomArray, y::CustomArray) = CustomArray(copyto!(x.x, y.x))
Base.copy(x::CustomArray) = CustomArray(copy(x.x))
Base.length(x::CustomArray) = length(x.x)
Base.isempty(x::CustomArray) = isempty(x.x)
Base.eltype(x::CustomArray) = eltype(x.x)
Base.fill!(x::CustomArray, y) = CustomArray(fill!(x.x, y))
Base.getindex(x::CustomArray, i) = getindex(x.x, i)
Base.setindex!(x::CustomArray, v, idx) = setindex!(x.x, v, idx)
Base.mapreduce(f, op, x::CustomArray; kwargs...) = mapreduce(f, op, x.x; kwargs...)
Base.any(f::Function, x::CustomArray; kwargs...) = any(f, x.x; kwargs...)
Base.all(f::Function, x::CustomArray; kwargs...) = all(f, x.x; kwargs...)
Base.:(==)(x::CustomArray, y::CustomArray) = x.x == y.x
Base.:(*)(x::Number, y::CustomArray) = CustomArray(x*y.x)
Base.:(/)(x::CustomArray, y::Number) = CustomArray(x.x/y)
LinearAlgebra.norm(x::CustomArray) = norm(x.x)

struct CustomStyle{N} <: Broadcast.BroadcastStyle where {N} end
CustomStyle(::Val{N}) where N = CustomStyle{N}()
CustomStyle{M}(::Val{N}) where {N,M} = CustomStyle{N}()
Base.BroadcastStyle(::Type{<:CustomArray{T,N}}) where {T,N} = CustomStyle{N}()
Broadcast.BroadcastStyle(::CustomStyle{N}, ::Broadcast.DefaultArrayStyle{0}) where {N} = CustomStyle{N}()
Base.similar(bc::Base.Broadcast.Broadcasted{CustomStyle{N}}, ::Type{ElType}) where {N, ElType} = CustomArray(similar(Array{ElType, N}, axes(bc)))
Base.Broadcast._broadcast_getindex(x::CustomArray, i) = x.x[i]
Base.Broadcast.extrude(x::CustomArray) = x
Base.Broadcast.broadcastable(x::CustomArray) = x

@inline function Base.copyto!(dest::CustomArray, bc::Base.Broadcast.Broadcasted{<:CustomStyle})
    axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
    bc′ = Base.Broadcast.preprocess(dest, bc)
    dest′ = dest.x
    @simd for I in 1:length(dest′)
        @inbounds dest′[I] = bc′[I]
    end
    return dest
end
@inline function Base.copy(bc::Base.Broadcast.Broadcasted{<:CustomStyle})
    bcf = Broadcast.flatten(bc)
    x = find_x(bcf)
    data = zeros(eltype(x), size(x))
    @inbounds @simd for I in 1:length(x)
        data[I] = bcf[I]
    end
    return CustomArray(data)
end
find_x(bc::Broadcast.Broadcasted) = find_x(bc.args)
find_x(args::Tuple) = find_x(find_x(args[1]), Base.tail(args))
find_x(x) = x
find_x(::Any, rest) = find_x(rest)
find_x(x::CustomArray, rest) = x.x

RecursiveArrayTools.recursive_unitless_bottom_eltype(x::CustomArray) = eltype(x)
RecursiveArrayTools.recursivecopy!(dest::CustomArray, src::CustomArray) = copyto!(dest, src)
RecursiveArrayTools.recursivecopy(x::CustomArray) = copy(x)
RecursiveArrayTools.recursivefill!(x::CustomArray, a) = fill!(x, a)

Base.show_vector(io::IO, x::CustomArray) = Base.show_vector(io, x.x)

Base.show(io::IO, x::CustomArray) = (print(io, "CustomArray");show(io, x.x))
function Base.show(io::IO, ::MIME"text/plain", x::CustomArray)
    println(io, Base.summary(x), ":")
    Base.print_array(io, x.x)
end

An ODEProblem can be solved with this type, but an SDEProblem cannot:

ca0 = CustomArray(ones(10))
prob = SDEProblem((du, u, p, t)->copyto!(du, u),(du, u, p, t)->copyto!(du, u), ca0, (0.0,1.0))
sol = solve(prob, EM(), dt=1//2^4)

Error & Stacktrace ⚠️

ERROR: MethodError: no method matching /(::CustomArray{Float64, 1}, ::CustomArray{Float64, 1})

Closest candidates are:
  /(::ChainRulesCore.NotImplemented, ::Any)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/I1EbV/src/tangent_types/notimplemented.jl:42
  /(::Any, ::ChainRulesCore.NotImplemented)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/I1EbV/src/tangent_types/notimplemented.jl:43
  /(::ChainRulesCore.AbstractZero, ::Any)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/I1EbV/src/tangent_types/abstract_zero.jl:31
  ...

Stacktrace:
 [1] __init(_prob::SDEProblem{…}, alg::EM{…}, timeseries_init::Vector{…}, ts_init::Vector{…}, ks_init::Type, recompile::Type{…}; saveat::Tuple{}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_noise::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Rational{…}, adaptive::Bool, gamma::Int64, abstol::Nothing, reltol::Nothing, qmin::Int64, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta2::Nothing, beta1::Nothing, qoldinit::Int64, controller::Nothing, fullnormalize::Bool, failfactor::Int64, delta::Rational{…}, maxiters::Int64, dtmax::Float64, dtmin::Float64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, force_dtmin::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, initialize_integrator::Bool, seed::UInt64, alias_u0::Bool, alias_jumps::Bool, kwargs::@Kwargs{})
   @ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/M3bKo/src/solve.jl:286
 [2] __solve(prob::SDEProblem{…}, alg::EM{…}, timeseries::Vector{…}, ts::Vector{…}, ks::Nothing, recompile::Type{…}; kwargs::@Kwargs{})
   @ StochasticDiffEq ~/.julia/packages/StochasticDiffEq/M3bKo/src/solve.jl:6
 [3] solve_call(_prob::SDEProblem{…}, args::EM{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:612
 [4] solve_up(prob::SDEProblem{…}, sensealg::Nothing, u0::CustomArray{…}, p::SciMLBase.NullParameters, args::EM{…}; kwargs::@Kwargs{})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1080
 [5] solve(prob::SDEProblem{…}, args::EM{…}; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{})
   @ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003
 [6] top-level scope
   @ REPL[4]:1
Some type information was truncated. Use `show(err)` to see complete types.
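
The MethodError above names the missing method directly. As a possible stopgap, one could define an elementwise quotient between two wrappers. This is only a sketch under assumptions: the elementwise semantics are a guess at what the initialization code wants, the stripped-down CustomArray here is a stand-in for the full definition in the example, and whether this is sufficient for solve to proceed has not been checked against StochasticDiffEq internals.

```julia
# Minimal stand-in for the wrapper type from the report (the full example
# defines many more methods; only what this sketch needs is repeated here).
struct CustomArray{T, N}
    x::Array{T, N}
end
Base.zero(x::CustomArray) = CustomArray(zero(x.x))

# Hypothetical workaround: the method the MethodError reports as missing.
# Elementwise division is an assumption, not a documented requirement.
Base.:(/)(x::CustomArray, y::CustomArray) = CustomArray(x.x ./ y.x)

ca0 = CustomArray(ones(10))
q = ca0 / ca0   # no longer throws a MethodError for this signature
```

Defining / between two array-like objects is unusual (for plain matrices it means a right-division solve), so a fix on the library side that avoids requiring this method would be preferable to asking every custom type to add it.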

Environment (please complete the following information):

  • Output of using Pkg; Pkg.status()
Status `~/Documents/Julia Packages/QuantumJulia/Project.toml`
  [4c88cf16] Aqua v0.8.7
  [4fba245c] ArrayInterface v7.14.0
  [6e4b80f9] BenchmarkTools v1.5.0
  [0c46a032] DifferentialEquations v7.13.0
  [ffbed154] DocStringExtensions v0.9.3
  [e30172f5] Documenter v1.5.0
  [daee34ce] DocumenterCitations v1.3.3
  [7a1cc6ca] FFTW v1.8.0
  [7034ab61] FastBroadcast v0.3.5
  [1a297f60] FillArrays v1.11.0
  [f6369f11] ForwardDiff v0.10.36
⌃ [e9467ef8] GLMakie v0.9.11
  [c3a54625] JET v0.9.7
  [8ac3fa9e] LRUCache v1.6.1
  [23fbe1c1] Latexify v0.16.4
  [16fef848] LiveServer v1.3.1
  [1914dd2f] MacroTools v0.5.13
  [f9640e96] MultiScaleArrays v1.12.0
  [1dea7af3] OrdinaryDiffEq v6.87.0
  [e4faabce] PProf v3.1.0
  [32113eaa] PkgBenchmark v0.2.12
  [d330b81b] PyPlot v2.11.5
  [0525e862] QuantumClifford v0.9.7
  [5717a53b] QuantumInterface v0.3.4 `QuantumInterface.jl`
  [6e0679c1] QuantumOptics v1.1.1 `QuantumOptics.jl`
  [4f57444f] QuantumOpticsBase v0.5.1 `QuantumOpticsBase.jl`
  [efa7fd63] QuantumSymbolics v0.3.4 `QuantumSymbolics.jl`
  [2576dda1] RandomMatrices v0.5.5
  [731186ca] RecursiveArrayTools v3.26.0
  [295af30f] Revise v3.5.17
  [1bc83da4] SafeTestsets v0.1.0
  [2913bbd2] StatsBase v0.34.3
  [789caeaf] StochasticDiffEq v6.67.0
  [5e0ebb24] Strided v2.1.0
  [4db3bf67] StridedViews v0.3.1
⌅ [d1185830] SymbolicUtils v2.1.2
  [0c5d862f] Symbolics v5.34.0
  [ade2ca70] Dates
  [37e2e46d] LinearAlgebra
  [9abbd945] Profile
  [2f01184e] SparseArrays v1.10.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`

Additional context

For context, I am working to define a broadcast interface for QuantumOptics.jl types (which wrap around arrays and basis information) that integrates with SciML.
