Feed known DE in neural DE #122

Closed
ric-cioffi opened this issue Jan 28, 2020 · 3 comments

Comments

@ric-cioffi
Contributor

First of all, thank you for all the work that went into this!
I've been trying to get the gist of DiffEqFlux, and I'm having a hard time with a seemingly trivial problem: I'm trying to solve a neural SDE in which one of the inputs of the neural network is the solution to a given DE. Basically, I have two series (Z, K), generated jointly, where Z evolves independently of K. What I want is for the neural network to learn the evolution of K given exact knowledge of the evolution of Z.

In the following MWE I use sensealg = TrackerAdjoint() because the actual problem is an SDE, which, if I understood the docs correctly, leaves that as my only option.

using Plots
using Flux, DiffEqFlux, DifferentialEquations, DiffEqSensitivity
using DiffEqFlux.Tracker

Z̄ = 1.5f0
K̄ = 3.0f0
u0 = [Z̄; K̄]

mp       = Float32[0.2, 0.0]
datasize = 30
t_min    = 1.0f0
t_max    = 10.0f0
t_span   = (t_min, t_max)
t_range   = range(t_span[1], t_span[2], length = datasize)

training_t_span = (0.9*t_min, 1.1*t_max)

dZ_d(Z, K, t) = -0.1f0*(Z̄ - Z)^3 + 0.2f0*(Z̄ - Z) + cos(t)/((t - t_min) + 1)
dK_d(Z, K, t) = 2.0f0*Z^3 - 0.1f0*K^3                                    

dZ_n(Z, K, t) = mp[1].*Z
dK_n(Z, K, t) = mp[2].*K

# True SDE
function trueODE!(du, u, p, t)
    Z, K = u
    du[1], du[2] = dZ_d(Z, K, t), dK_d(Z, K, t)
end
function true_noise!(du, u, p, t)
    Z, K = u
    du[1], du[2] = dZ_n(Z, K, t), dK_n(Z, K, t)
end
prob_true = SDEProblem{true}(trueODE!, true_noise!, u0, training_t_span)
training_sol = solve(prob_true, SOSRI(); dense = true, reltol = 1e-1, abstol = 1e-1)
training_data = training_sol(t_range)

nn_drift = Chain(x -> (x).^3,
                Dense(2, 16, tanh),
                Dense(16, 1))

p_d, re_drift = Flux.destructure(nn_drift)
ps = Flux.params(p_d)

function sdeN(par)
    function dudt(u, p, t)
        Z = training_sol(t, idxs = 1) |> Float32
        n_input = [Z, u] #|> Tracker.collect

        du = re_drift(p)(n_input)[1]
        return du #|> Tracker.collect
    end

    prob = ODEProblem{false}(dudt, u0[2], t_span, par)
    return Array(concrete_solve(prob, SOSRI(), u0[2], par; sensealg = TrackerAdjoint(), saveat = t_range, reltol = 1e-1, abstol = 1e-1)) #|> Tracker.collect
end

predict_sdeN() = sdeN(p_d)
predict_sdeN()

opt = ADAM(1e-3)
loss_flux() = sum(abs2, training_data[2, :] .- predict_sdeN())

function cb(K_pred)
    K_train = training_data[2, :]

    println("Current loss: $(sum(abs2, K_train .- K_pred))")

    pl = plot(t_range, K_pred, lw = 5, label = "K prediction", xlabel = "Time")
    scatter!(pl, t_range, K_train, label = "K data", title = "Neural ODE: Training")
    display(plot(pl))
end
cb_flux() = cb(predict_sdeN())
Flux.train!(loss_flux, ps, Iterators.repeated((), 100), opt, cb = cb_flux)

However, this gives me the following MethodError (I've left out the stacktrace, which is very unwieldy):

ERROR: MethodError: no method matching Float32(::Tracker.TrackedReal{Float64})

I tried playing around with Tracker.collect and with different sensealg choices, but different errors pop up at different points, and I'm not sure what I'm doing wrong or where.

@ric-cioffi
Contributor Author

So, I figured out that the problem doesn't really have anything to do with feeding a known series into the neural DE.
There are in fact two separate issues that I don't fully understand:

  1. The first concerns defining u0 as a scalar rather than an array (I have found a workaround, but I'd like to understand why the example doesn't work).
  2. The second, more important one concerns using sensealg = TrackerAdjoint(), with which my problem doesn't work at all (I want to run a neural SDE, which is why I need it).

I know this is a lot to read, but any help would be very much appreciated.

This first part (setup) is fine:

using DifferentialEquations, DiffEqFlux, DiffEqSensitivity, Flux

u0 = 1/2

datasize = 30
t_span = (t_min, t_max) = (0.0, 1.0)
t_range = range(t_min, t_max, length = datasize)
dt = (t_max - t_min)/datasize

true_ode(u, p, t) = 1.01*u

prob = ODEProblem{false}(true_ode, u0, t_span)
sol = solve(prob; dense = true)
sim = sol(t_range)

nn = Chain(Dense(1, 16, softplus),
           Dense(16, 1))

p_nn, re_nn = Flux.destructure(nn)

As mentioned above, the first problem comes from using u0 as a scalar rather than an array.
What I find especially weird is that I can solve the neural ODE, but I cannot train the network:

function odeN(par)
    function dudt(u, p, t)
        return re_nn(p)([u[1]])[1]
    end
    prob = ODEProblem{false}(dudt, u0, t_span, par)
    return Array(concrete_solve(prob, Euler(), u0, par; dt = dt, saveat = t_range))
end
loss = θ -> sum(abs2, sim .- odeN(θ)[:])
l = loss(p_nn)
cb = (θ, l) -> println("Current loss:", l)

# this works:
loss(p_nn) 

# this doesn't work:
DiffEqFlux.sciml_train!(loss, p_nn, ADAM(0.1), maxiters = 100, cb = cb)

ERROR: MethodError: no method matching similar(::Float64, ::Int64)
Closest candidates are:
  similar(::JuliaInterpreter.Compiled, ::Any) at /Users/Riccardo/.julia/packages/JuliaInterpreter/oKkmr/src/types.jl:7
  similar(::Array{T,N} where N, ::Int64) where T at array.jl:333
  similar(::BitArray, ::Int64...) at bitarray.jl:342
  ...
Stacktrace:
 [1] #ODEAdjointProblem#50(::Array{Float64,1}, ::CallbackSet{Tuple{},Tuple{}}, ::Float64, ::Float64, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}}, ::typeof(ODEAdjointProblem), ::ODESolution{Float64,1,Array{Float64,1},Nothing,Nothing,Array{Float64,1},Array{Array{Float64,1},1},ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float64,1},Array{Array{Float64,1},1},OrdinaryDiffEq.EulerConstantCache},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central}}, ::DiffEqSensitivity.var"#df#62"{Array{Float64,1}}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at /Users/Riccardo/.julia/packages/DiffEqSensitivity/9ybQN/src/local_sensitivity/interpolating_adjoint.jl:102
 [2] (::DiffEqSensitivity.var"#kw##ODEAdjointProblem")(::NamedTuple{(:checkpoints, :abstol, :reltol),Tuple{Array{Float64,1},Float64,Float64}}, ::typeof(ODEAdjointProblem), ::ODESolution{Float64,1,Array{Float64,1},Nothing,Nothing,Array{Float64,1},Array{Array{Float64,1},1},ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float64,1},Array{Array{Float64,1},1},OrdinaryDiffEq.EulerConstantCache},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central}}, ::Function, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at ./none:0
 [3] #_adjoint_sensitivities#13(::Float64, ::Float64, ::Array{Float64,1}, ::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:dt,),Tuple{Float64}}}, ::typeof(DiffEqSensitivity._adjoint_sensitivities), ::ODESolution{Float64,1,Array{Float64,1},Nothing,Nothing,Array{Float64,1},Array{Array{Float64,1},1},ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float64,1},Array{Array{Float64,1},1},OrdinaryDiffEq.EulerConstantCache},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central}}, ::Euler, ::DiffEqSensitivity.var"#df#62"{Array{Float64,1}}, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at /Users/Riccardo/.julia/packages/DiffEqSensitivity/9ybQN/src/local_sensitivity/sensitivity_interface.jl:13
 [4] (::DiffEqSensitivity.var"#kw##_adjoint_sensitivities")(::NamedTuple{(:dt,),Tuple{Float64}}, ::typeof(DiffEqSensitivity._adjoint_sensitivities), ::ODESolution{Float64,1,Array{Float64,1},Nothing,Nothing,Array{Float64,1},Array{Array{Float64,1},1},ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float64,1},Array{Array{Float64,1},1},OrdinaryDiffEq.EulerConstantCache},DiffEqBase.DEStats}, ::InterpolatingAdjoint{0,true,Val{:central}}, ::Euler, ::Function, ::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Nothing) at ./none:0 (repeats 2 times)
 [5] #adjoint_sensitivities#12(::InterpolatingAdjoint{0,true,Val{:central}}, ::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:dt,),Tuple{Float64}}}, ::typeof(adjoint_sensitivities), ::ODESolution{Float64,1,Array{Float64,1},Nothing,Nothing,Array{Float64,1},Array{Array{Float64,1},1},ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float64,1},Array{Array{Float64,1},1},OrdinaryDiffEq.EulerConstantCache},DiffEqBase.DEStats}, ::Euler, ::Vararg{Any,N} where N) at /Users/Riccardo/.julia/packages/DiffEqSensitivity/9ybQN/src/local_sensitivity/sensitivity_interface.jl:6
 [6] (::DiffEqSensitivity.var"#kw##adjoint_sensitivities")(::NamedTuple{(:sensealg, :dt),Tuple{InterpolatingAdjoint{0,true,Val{:central}},Float64}}, ::typeof(adjoint_sensitivities), ::ODESolution{Float64,1,Array{Float64,1},Nothing,Nothing,Array{Float64,1},Array{Array{Float64,1},1},ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,OrdinaryDiffEq.InterpolationData{ODEFunction{false,var"#dudt#8",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{Float64,1},Array{Float64,1},Array{Array{Float64,1},1},OrdinaryDiffEq.EulerConstantCache},DiffEqBase.DEStats}, ::Euler, ::Vararg{Any,N} where N) at ./none:0
 [7] (::DiffEqSensitivity.var"#adjoint_sensitivity_backpass#61"{Euler,InterpolatingAdjoint{0,true,Val{:central}},Float64,Array{Float32,1},Tuple{}})(::Array{Float64,1}) at /Users/Riccardo/.julia/packages/DiffEqSensitivity/9ybQN/src/local_sensitivity/concrete_solve.jl:69
 [8] #553#back at /Users/Riccardo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:55 [inlined]
 [9] odeN at ./REPL[13]:6 [inlined]
 [10] (::typeof(∂(odeN)))(::Array{Float64,1}) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [11] (::typeof(∂(#9)))(::Float64) at ./REPL[14]:1
 [12] #16 at /Users/Riccardo/.julia/packages/DiffEqFlux/ltHvf/src/train.jl:24 [inlined]
 [13] (::typeof(∂(λ)))(::Float64) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [14] (::Zygote.var"#46#47"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:101
 [15] gradient(::Function, ::Zygote.Params) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:47
 [16] macro expansion at /Users/Riccardo/.julia/packages/DiffEqFlux/ltHvf/src/train.jl:23 [inlined]
 [17] macro expansion at /Users/Riccardo/.julia/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
 [18] #sciml_train!#13(::Function, ::Int64, ::typeof(DiffEqFlux.sciml_train!), ::var"#9#10", ::Array{Float32,1}, ::ADAM) at /Users/Riccardo/.julia/packages/DiffEqFlux/ltHvf/src/train.jl:22
 [19] (::DiffEqFlux.var"#kw##sciml_train!")(::NamedTuple{(:maxiters, :cb),Tuple{Int64,var"#11#12"}}, ::typeof(DiffEqFlux.sciml_train!), ::Function, ::Array{Float32,1}, ::ADAM) at ./none:0
 [20] top-level scope at REPL[18]:2

I can easily work around this by wrapping u0 in an array in the concrete_solve call in the return statement of odeN(par) and removing the final [1] indexing from the return value of dudt(u, p, t) (as I do in the next example, with a different sensealg).
However, it is not clear to me why this raises an error, especially given that odeN(p_nn) works just fine.
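For reference, here is a minimal Base-only sketch of what the `similar(::Float64, ::Int64)` MethodError is complaining about. It does not load DiffEqSensitivity, and it only assumes, from frame [1] of the stacktrace, that ODEAdjointProblem calls `similar` on the scalar state. `similar` allocates a new container shaped like its argument, which works for an array but has no meaning for a plain Float64.

```julia
# Base-only sketch (no DiffEq packages loaded).
# `similar` allocates a new, uninitialized container shaped like its argument.
u0_scalar = 0.5      # scalar state, as in the failing example
u0_array  = [0.5]    # one-element Vector, as in the workaround

# Works: a Vector{Float64} of length 3 (contents are uninitialized).
buf = similar(u0_array, 3)

# Fails: a plain Float64 cannot be used to allocate a buffer,
# so Julia raises a MethodError, as in the adjoint stacktrace.
scalar_fails = try
    similar(u0_scalar, 3)
    false
catch err
    err isa MethodError
end

println(typeof(buf), " ", length(buf), " ", scalar_fails)
```

This is consistent with the workaround of passing [u0] instead of u0: the one-element Vector gives the adjoint something it can allocate buffers from.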

The second, more important problem is that if I use sensealg = TrackerAdjoint() in odeN, I get an error about TrackedArray.
Here too, what I find weird is that odeN works, but I cannot train the network:

function odeN(par)
    function dudt(u, p, t)
        return re_nn(p)([u[1]])
    end
    prob = ODEProblem{false}(dudt, u0, t_span, par)
    return Array(concrete_solve(prob, Euler(), [u0], par; sensealg = TrackerAdjoint(), dt = dt, saveat = t_range))
end
loss = θ -> sum(abs2, sim .- odeN(θ)[:])
l = loss(p_nn)
cb = (θ, l) -> println("Current loss:", l)

# this works:
loss(p_nn) 

# this doesn't work:
DiffEqFlux.sciml_train!(loss, p_nn, ADAM(0.1), maxiters = 100, cb = cb)

ERROR: Not implemented: convert TrackedArray{…,Array{Tracker.TrackedReal{Float64},1}} to TrackedArray{…,Array{Float64,1}}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] convert(::Type{TrackedArray{…,Array{Float64,1}}}, ::TrackedArray{…,Array{Tracker.TrackedReal{Float64},1}}) at /Users/Riccardo/.julia/packages/Tracker/cpxco/src/lib/array.jl:38
 [3] setproperty!(::OrdinaryDiffEq.ODEIntegrator{Euler,false,TrackedArray{…,Array{Float64,1}},Float64,TrackedArray{…,Array{Float32,1}},Float64,Float64,Float64,Array{TrackedArray{…,Array{Float64,1}},1},ODESolution{Tracker.TrackedReal{Float64},2,Array{TrackedArray{…,Array{Float64,1}},1},Nothing,Nothing,Array{Float64,1},Array{Array{TrackedArray{…,Array{Float64,1}},1},1},ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float64,Float64},false,TrackedArray{…,Array{Float32,1}},ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,OrdinaryDiffEq.InterpolationData{ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{TrackedArray{…,Array{Float64,1}},1},Array{Float64,1},Array{Array{TrackedArray{…,Array{Float64,1}},1},1},OrdinaryDiffEq.EulerConstantCache},DiffEqBase.DEStats},ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.EulerConstantCache,OrdinaryDiffEq.DEOptions{Tracker.TrackedReal{Float64},Tracker.TrackedReal{Float64},Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Array{Float64,1},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float64,1}},TrackedArray{…,Array{Float64,1}},Tracker.TrackedReal{Float64},Nothing}, ::Symbol, ::TrackedArray{…,Array{Tracker.TrackedReal{Float64},1}}) at ./Base.jl:21
 [4] initialize!(::OrdinaryDiffEq.ODEIntegrator{Euler,false,TrackedArray{…,Array{Float64,1}},Float64,TrackedArray{…,Array{Float32,1}},Float64,Float64,Float64,Array{TrackedArray{…,Array{Float64,1}},1},ODESolution{Tracker.TrackedReal{Float64},2,Array{TrackedArray{…,Array{Float64,1}},1},Nothing,Nothing,Array{Float64,1},Array{Array{TrackedArray{…,Array{Float64,1}},1},1},ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float64,Float64},false,TrackedArray{…,Array{Float32,1}},ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,OrdinaryDiffEq.InterpolationData{ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Array{TrackedArray{…,Array{Float64,1}},1},Array{Float64,1},Array{Array{TrackedArray{…,Array{Float64,1}},1},1},OrdinaryDiffEq.EulerConstantCache},DiffEqBase.DEStats},ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},OrdinaryDiffEq.EulerConstantCache,OrdinaryDiffEq.DEOptions{Tracker.TrackedReal{Float64},Tracker.TrackedReal{Float64},Float64,Float64,typeof(DiffEqBase.ODE_DEFAULT_NORM),typeof(LinearAlgebra.opnorm),CallbackSet{Tuple{},Tuple{}},typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN),typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE),typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK),DataStructures.BinaryHeap{Float64,DataStructures.LessThan},DataStructures.BinaryHeap{Float64,DataStructures.LessThan},Nothing,Nothing,Int64,Array{Float64,1},StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}},Array{Float64,1}},TrackedArray{…,Array{Float64,1}},Tracker.TrackedReal{Float64},Nothing}, ::OrdinaryDiffEq.EulerConstantCache) at /Users/Riccardo/.julia/packages/OrdinaryDiffEq/nV9bA/src/perform_step/fixed_timestep_perform_step.jl:45
 [5] #__init#328(::StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}, ::Array{Float64,1}, ::Array{Float64,1}, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::Nothing, ::Bool, ::Bool, ::Float64, ::Float64, ::Float64, ::Bool, ::Bool, ::Rational{Int64}, ::Nothing, ::Nothing, ::Rational{Int64}, ::Int64, ::Int64, ::Int64, ::Rational{Int64}, ::Bool, ::Int64, ::Nothing, ::Nothing, ::Int64, ::typeof(DiffEqBase.ODE_DEFAULT_NORM), ::typeof(LinearAlgebra.opnorm), ::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), ::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Int64, ::String, ::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float64,Float64},false,TrackedArray{…,Array{Float32,1}},ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}, ::Euler, ::Array{TrackedArray{…,Array{Float64,1}},1}, ::Array{Float64,1}, ::Array{Any,1}, ::Type{Val{true}}) at /Users/Riccardo/.julia/packages/OrdinaryDiffEq/nV9bA/src/solve.jl:383
 [6] (::DiffEqBase.var"#kw##__init")(::NamedTuple{(:dt, :saveat),Tuple{Float64,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}, ::typeof(DiffEqBase.__init), ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float64,Float64},false,TrackedArray{…,Array{Float32,1}},ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}, ::Euler, ::Array{TrackedArray{…,Array{Float64,1}},1}, ::Array{Float64,1}, ::Array{Any,1}, ::Type{Val{true}}) at ./none:0 (repeats 4 times)
 [7] #__solve#327 at /Users/Riccardo/.julia/packages/OrdinaryDiffEq/nV9bA/src/solve.jl:4 [inlined]
 [8] #__solve at ./none:0 [inlined]
 [9] #solve_call#442(::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:dt, :saveat),Tuple{Float64,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}}, ::typeof(DiffEqBase.solve_call), ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float64,Float64},false,TrackedArray{…,Array{Float32,1}},ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}, ::Euler) at /Users/Riccardo/.julia/packages/DiffEqBase/YIwj5/src/solve.jl:40
 [10] #solve_call at ./none:0 [inlined]
 [11] #solve#443 at /Users/Riccardo/.julia/packages/DiffEqBase/YIwj5/src/solve.jl:57 [inlined]
 [12] (::DiffEqBase.var"#kw##solve")(::NamedTuple{(:dt, :saveat),Tuple{Float64,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}, ::typeof(solve), ::ODEProblem{TrackedArray{…,Array{Float64,1}},Tuple{Float64,Float64},false,TrackedArray{…,Array{Float32,1}},ODEFunction{false,DiffEqSensitivity.var"#_f#84"{ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}},LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}, ::Euler) at ./none:0
 [13] (::DiffEqSensitivity.var"#tracker_adjoint_forwardpass#83"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:dt, :saveat),Tuple{Float64,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}},ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,Tuple{}})(::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float32,1}}) at /Users/Riccardo/.julia/packages/DiffEqSensitivity/9ybQN/src/local_sensitivity/concrete_solve.jl:172
 [14] #20 at /Users/Riccardo/.julia/packages/Tracker/cpxco/src/back.jl:148 [inlined]
 [15] forward(::Tracker.var"#20#22"{DiffEqSensitivity.var"#tracker_adjoint_forwardpass#83"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:dt, :saveat),Tuple{Float64,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}},ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem},Euler,Tuple{}}}, ::Tracker.Params) at /Users/Riccardo/.julia/packages/Tracker/cpxco/src/back.jl:135
 [16] forward(::Function, ::Array{Float64,1}, ::Array{Float32,1}) at /Users/Riccardo/.julia/packages/Tracker/cpxco/src/back.jl:148
 [17] #_concrete_solve_adjoint#81(::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:dt, :saveat),Tuple{Float64,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}}, ::typeof(DiffEqBase._concrete_solve_adjoint), ::ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}, ::Euler, ::TrackerAdjoint, ::Array{Float64,1}, ::Array{Float32,1}) at /Users/Riccardo/.julia/packages/DiffEqSensitivity/9ybQN/src/local_sensitivity/concrete_solve.jl:182
 [18] #_concrete_solve_adjoint at ./none:0 [inlined]
 [19] #adjoint#452 at /Users/Riccardo/.julia/packages/DiffEqBase/YIwj5/src/solve.jl:210 [inlined]
 [20] #adjoint at ./none:0 [inlined]
 [21] _pullback(::Zygote.Context, ::DiffEqBase.var"#kw##concrete_solve", ::NamedTuple{(:sensealg, :dt, :saveat),Tuple{TrackerAdjoint,Float64,StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}}}, ::typeof(concrete_solve), ::ODEProblem{Float64,Tuple{Float64,Float64},false,Array{Float32,1},ODEFunction{false,var"#dudt#18",LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{,Tuple{}}},DiffEqBase.StandardODEProblem}, ::Euler, ::Array{Float64,1}, ::Array{Float32,1}) at /Users/Riccardo/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:53
 [22] odeN at ./REPL[25]:6 [inlined]
 [23] _pullback(::Zygote.Context, ::typeof(odeN), ::Array{Float32,1}) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [24] #19 at ./REPL[26]:1 [inlined]
 [25] _pullback(::Zygote.Context, ::var"#19#20", ::Array{Float32,1}) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [26] #16 at /Users/Riccardo/.julia/packages/DiffEqFlux/ltHvf/src/train.jl:24 [inlined]
 [27] _pullback(::Zygote.Context, ::DiffEqFlux.var"#16#22"{var"#19#20",Array{Float32,1}}) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface2.jl:0
 [28] pullback(::Function, ::Zygote.Params) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:96
 [29] gradient(::Function, ::Zygote.Params) at /Users/Riccardo/.julia/packages/Zygote/tJj2w/src/compiler/interface.jl:46
 [30] macro expansion at /Users/Riccardo/.julia/packages/DiffEqFlux/ltHvf/src/train.jl:23 [inlined]
 [31] macro expansion at /Users/Riccardo/.julia/packages/Juno/oLB1d/src/progress.jl:134 [inlined]
 [32] #sciml_train!#13(::Function, ::Int64, ::typeof(DiffEqFlux.sciml_train!), ::var"#19#20", ::Array{Float32,1}, ::ADAM) at /Users/Riccardo/.julia/packages/DiffEqFlux/ltHvf/src/train.jl:22
 [33] (::DiffEqFlux.var"#kw##sciml_train!")(::NamedTuple{(:maxiters, :cb),Tuple{Int64,var"#21#22"}}, ::typeof(DiffEqFlux.sciml_train!), ::Function, ::Array{Float32,1}, ::ADAM) at ./none:0
 [34] top-level scope at REPL[30]:2

By the way, I'm on Julia 1.3, and packages are: [email protected], [email protected], [email protected], [email protected].

@ChrisRackauckas
Member

ChrisRackauckas commented Feb 8, 2020

Hey,
I haven't been ignoring this; I was hoping the interface changes would help with some of it, and then it took time to dig through your issue. So sorry for the late response! Now that the new interface is stabilized (pun intended, since it also works better with highly stiff equations 👍), it's time to tackle this.

> What I find especially weird is that I can solve the neural ODE, but I cannot train the network.

Right now the adjoint requires in-place differential equations (Xref: SciML/SciMLSensitivity.jl#113). This can be fixed, but using a Number or a StaticArray in a problem where an adjoint makes sense is an extreme edge case. That doesn't mean it never makes sense, but it's why it hasn't been prioritized. The fact that an out-of-place f(u,p,t) with arrays works with adjoints is a nice hack that's useful for neural ODEs, but it only works with mutable objects (like arrays).
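The difference between the two right-hand-side forms can be sketched in plain Julia, without any DiffEq packages; `f_oop` and `f_iip!` are hypothetical names used only for illustration. The in-place form writes into a preallocated buffer `du`, which needs a mutable container such as an array, while the out-of-place form returns a new value and so also accepts a plain Number.

```julia
# Base-only sketch; `f_oop` and `f_iip!` are hypothetical illustration names.

# Out-of-place: f(u, p, t) returns a new derivative value.
f_oop(u, p, t) = 1.01 .* u

# In-place: f!(du, u, p, t) overwrites the preallocated buffer `du`.
function f_iip!(du, u, p, t)
    du .= 1.01 .* u
    return nothing
end

u_vec = [0.5]
du = similar(u_vec)               # allocating the buffer needs an array
f_iip!(du, u_vec, nothing, 0.0)   # du now holds 1.01 * 0.5

# Out-of-place also works for a plain Number.
d_scalar = f_oop(0.5, nothing, 0.0)

# A Float64 cannot act as `du`: it is immutable, so there is nothing to write into.
scalar_buffer_fails = try
    f_iip!(0.0, 0.5, nothing, 0.0)
    false
catch
    true
end
```

The last call fails for the reason given above: mutation, and therefore the in-place path the adjoint relies on, needs an array.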

The second one... I gotta run it to see.

@ChrisRackauckas
Member

Tracker just has funny semantics sometimes. That's why we are trying to phase it out, though right now you do indeed need it for SDEs. You just have to be careful: scalar indexing produces a TrackedReal, so [u[1]] (an Array of TrackedReals) is not the same as u (a TrackedArray) here. The next issue is that some things were Float32 and others were Float64, which makes Tracker mad. Working code for this on the latest versions is as follows:

using OrdinaryDiffEq, DiffEqFlux, DiffEqSensitivity, Flux

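u0 = Float32(1/2)   # Float32 state, matching the Float32 parameters from Flux.destructure

datasize = 30
t_span = (t_min, t_max) = (0f0, 1f0)   # Float32 time span as well
t_range = range(t_min, t_max, length = datasize)
dt = (t_max - t_min)/datasize

true_ode(u, p, t) = Float32(1.01)*u   # Float32 coefficient, so the data stays Float32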

prob = ODEProblem{false}(true_ode, u0, t_span)
sol = solve(prob, Tsit5(); dense = true)
sim = sol(t_range)

nn = Chain(Dense(1, 16, softplus),
           Dense(16, 1))

p_nn, re_nn = Flux.destructure(nn)

function odeN(par)
    function dudt(u, p, t)
        return re_nn(p)(u)   # pass u itself, not [u[1]], so Tracker keeps a TrackedArray
    end
    prob = ODEProblem{false}(dudt, u0, t_span, par)
    return Array(concrete_solve(prob, Euler(), [u0], par; sensealg = TrackerAdjoint(), dt = dt, saveat = t_range))
end
loss = θ -> sum(abs2, sim .- odeN(θ)[:])
l = loss(p_nn)
cb = (θ, l) -> (println("Current loss:", l); false)
loss(p_nn)
DiffEqFlux.sciml_train(loss, p_nn, ADAM(0.1), maxiters = 300, cb = cb)
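Both fixes can be seen in isolation with a small Base-only sketch. Tracker is not loaded, and the `view` below is only a hypothetical stand-in for a TrackedArray. Mixing a Float64 literal into Float32 data silently promotes to Float64, and `[u[1]]` builds a brand-new Vector instead of preserving the type of `u`'s container.

```julia
# Base-only sketch (Tracker is not loaded); `u` stands in for the state vector.
u = Float32[0.5]

mixed = 1.01 .* u            # a Float64 literal promotes the result to Float64
kept  = Float32(1.01) .* u   # a Float32 coefficient keeps Float32, as in true_ode above

t_min, t_max = 0f0, 1f0      # Float32 time span, as in the fixed code
t_range32 = range(t_min, t_max, length = 30)
dt32 = (t_max - t_min) / 30  # Float32 divided by an Int stays Float32

# [u[1]] builds a new Vector from the element; it does not copy u's container type.
# `v` is a hypothetical stand-in for a TrackedArray: a container that is not a Vector.
v = view(u, 1:1)
rebuilt = [v[1]]

println(eltype(mixed), " ", eltype(kept), " ", eltype(t_range32), " ", typeof(dt32))
println(typeof(rebuilt), " vs ", typeof(v))
```

With Tracker, the same rebuilding turns a TrackedArray into an Array of TrackedReals, which is the conversion the "Not implemented: convert TrackedArray" error refuses to perform.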

Let me know if you need any more help. Now to go implement SDE adjoints so we can stop recommending TrackerAdjoint 👍


2 participants