EnsembleProblem with timepoint_meanvar errors due to internal mutation #446

Open
CharlesRSmith44 opened this issue Nov 18, 2020 · 6 comments

@CharlesRSmith44

Hi,

Is there a way to use EnsembleProblem to solve SDEs together with Flux neural networks? Right now, when I combine them, I get the following error as soon as I try to train the network:

Info: Epoch 1
└ @ Main /home/ec2-user/.julia/packages/Flux/Fj3bt/src/optimise/train.jl:121
Mutating arrays is not supported

Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#459#460")(::Nothing) at /home/ec2-user/.julia/packages/Zygote/1GXzF/src/lib/array.jl:67
[3] (::Zygote.var"#1009#back#461"{Zygote.var"#459#460"})(::Nothing) at /home/ec2-user/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] materialize! at ./broadcast.jl:826 [inlined]

Here is the code:


### Packages
import Pkg; Pkg.build("DifferentialEquations")
import Pkg; Pkg.add("DifferentialEquations")

using DifferentialEquations
using Flux
using DifferentialEquations.EnsembleAnalysis
using IterTools: ncycle
using Flux: @epochs
using DiffEqSensitivity

# Input parameters
T = 1.0f0
tspan = (0.0f0,T)
m=Float32[1.0]
v=Float32[1.0]
true_y0 = exp(T)*m[1] + (exp(T)-1)*( sqrt(  2*v[1]/(1-exp(-2*T)) )+1 )
true_Z = sqrt(2*v[1]/(1-exp(-2*T)))
u0 = Float32[true_y0 + 0.25] # initial guess
Z = Float32[true_Z + 0.25]
ps = Flux.params(Z,u0) # or Flux.params(p1,u0) if you want to also optimize over all u0 parameters

function drift_nn(u,p,t)
    [-(u[1] + p[1] + 1.0f0)]
end

function stoch_nn(u,p,t)
    [p[1]]
end

probSDE_nn = SDEProblem{false}(drift_nn,stoch_nn,u0,tspan)
ensemble_prob = EnsembleProblem(probSDE_nn)
function yT_nn(p)
    sim = solve(ensemble_prob,EM(),trajectories=1000,p = p[1][1],u0=p[2],dt = 0.005)
end

function  loss_nn(m,v)
    m_val,v_val = timepoint_meanvar(yT_nn(ps),1.0)
    loss = (m_val[1]-m[1])^2 + (v_val[1]-v[1])^2
end

yT_nn(ps)

cb2 = function() #callback function to observe training
  display(loss_nn(m,v))
end

cb2()

Z_pre = deepcopy(Z)
u0_pre = deepcopy(u0)

# train model on the same data num_cycle times
num_cycles = 1
data = ncycle([(m, v)], num_cycles)

opt = ADAM(0.01)
@epochs 5 Flux.train!(loss_nn, ps , data,  opt, cb=cb2);




Thank you!

@CharlesRSmith44
Author

CharlesRSmith44 commented Nov 20, 2020

Alternatively, I can try to parallelize the SDE paths by using a pmap or map loop, but that seems to run into similar issues:

Example code:



### Using Packages
using Tables, DataFrames, Debugger
using Zygote, Statistics, Test, Tracker
using ModelingToolkit
using Flux, DifferentialEquations
using Flux: @epochs
using Random
using IterTools: ncycle
using DiffEqSensitivity
using DiffEqFlux
using Distributed

### User inputs
T = 0.01f0
tspan = (0.0f0,T)
m=Float32[1.0]
v=Float32[50.0]
true_y0 = exp(T)*m[1] + (exp(T)-1)*( sqrt(  2*v[1]/(1-exp(-2*T)) )+1 )
true_Z = sqrt(2*v[1]/(1-exp(-2*T)))
u0 = Float32[true_y0 + 0.25] # initial guess
Z = Float32[true_Z + 0.25]
ps = Flux.params(Z,u0) |> gpu # or Flux.params(p1,u0) if you want to also optimize over all u0 parameters

function drift_nn(u,p,t)
    [-(u[1] + p[1] + 1.0f0)]
end

function stoch_nn(u,p,t)
    [p[1]]
end

probSDE_nn = SDEProblem{false}(drift_nn,stoch_nn,u0,tspan,ps)

num_sim = 100 # number of paths to draw

function yT_nn(u0,p)
    #sensealg=TrackerAdjoint()
    #sensealg = ReverseDiffAdjoint()
    [Array(solve(probSDE_nn,EM(),dt=0.005,p=Z,u0=u0,save_start=false,saveat=T,save_noise=false))[end] for j=1:num_sim]
end

### using pmap
@everywhere  function yT_nn2(u0,Z)
   (solve(probSDE_nn,EM(),dt=0.005,p=Z,u0=u0,save_start=false,saveat=T,save_noise=false))[end]
end

function compute_paths(u0,Z)
    u0_vals = fill(u0, num_sim) #u0_vals = repeat(u0,num_sim)
    p_vals = fill(Z, num_sim)
    paths = map(yT_nn2, u0_vals,p_vals) #pmap(yT_nn2,u0_vals,p_vals)
end

function loss_nn(m,v)
    paths = compute_paths(u0,Z)
    loss = (mean(paths)[1]-m[1])^2 + (var(paths)[1]-v[1])^2
end

compute_paths(u0, Z)

cb2 = function() #callback function to observe training
  display(loss_nn(m,v))
end
cb2()
Z_pre = deepcopy(Z)
u0_pre = deepcopy(u0)

# train model on the same data num_cycle times
num_cycles = 1
data = ncycle([(m, v)], num_cycles)

# train
num_sim = 1000
opt = ADAM(0.01)
@epochs 5 Flux.train!(loss_nn, ps , data,  opt, cb=cb2);


The exact error is: MethodError: Cannot convert an object of type Float64 to an object of type Array{Float32,1}

@ChrisRackauckas
Member

https://diffeqflux.sciml.ai/dev/examples/optimization_sde/#Example-1:-Fitting-Data-with-SDEs-1

That example uses EnsembleThreads without a problem, so we know the ensemble interface itself is not the issue. What your example points to is that timepoint_meanvar mutates arrays internally, and that's where the adjoint is erroring. If you don't use that function, it works.
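As a rough illustration of avoiding timepoint_meanvar, the mean and variance at one saved time can be taken directly from the solution array with Statistics. This is only a sketch: `meanvar_at` is a hypothetical helper, not part of DifferentialEquations.jl, and the stand-in data assumes the (state, time, trajectory) layout that `Array(sim)` is indexed with elsewhere in this thread.

```julia
using Statistics

# Hypothetical helper (not a DifferentialEquations.jl function): mean and
# variance across trajectories at one saved time index, written without any
# in-place array updates.
# Assumption: `arr` is laid out as (state, time, trajectory).
function meanvar_at(arr::AbstractArray{<:Real,3}, tidx::Integer)
    slice = arr[:, tidx, :]            # state × trajectory copy at that time
    m = vec(mean(slice, dims=2))       # mean over trajectories, one entry per state
    v = vec(var(slice, dims=2))        # sample variance over trajectories
    return m, v
end

# Stand-in data: 1 state, 2 saved times, 4 trajectories.
# Column-major reshape puts the values 1, 2, 3, 4 at the second time index.
arr = reshape(Float64[0, 1, 0, 2, 0, 3, 0, 4], 1, 2, 4)
m, v = meanvar_at(arr, 2)
```

The helper returns one mean and one variance per state component, so a loss can index `m[1]` and `v[1]` the same way the original code indexed the output of timepoint_meanvar.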

ChrisRackauckas changed the title from "EnsembleProblem with Mutating Arrays" to "EnsembleProblem with timepoint_meanvar errors due to internal mutation" on Nov 20, 2020
@CharlesRSmith44
Author

I copied the example you gave and it works just fine, but when I apply it to my problem I still seem to get errors:
MethodError: no method matching (::var"#46#47")(::Array{Float64,1}, ::Float64)


### Packages
import Pkg; Pkg.build("DifferentialEquations")
import Pkg; Pkg.add("DifferentialEquations")
import Pkg; Pkg.add("IterTools")

using DifferentialEquations
using Flux
using DifferentialEquations.EnsembleAnalysis
using IterTools: ncycle
using Flux: @epochs
using DiffEqSensitivity
using Statistics
using DiffEqFlux

# Input parameters
T = 1.0f0
tspan = (0.0f0,T)
m=1.0
v=1.0
true_y0 = exp(T)*m + (exp(T)-1)*( sqrt(  2*v/(1-exp(-2*T)) )+1 )
true_Z = sqrt(2*v/(1-exp(-2*T)))
u0 = [true_y0 + 0.25] # initial guess
Z = true_Z + 0.25
p = [Z]

function drift_nn!(du,u,p,t)
    du[1] = - u[1] - p[1] - 1.0f0
end

function stoch_nn!(du,u,p,t)
    du[1] = p[1]
end

probSDE_nn = SDEProblem(drift_nn!,stoch_nn!,u0,tspan,p)
ensemble_prob = EnsembleProblem(probSDE_nn)
sol = solve(ensemble_prob,EM(),saveat = T,trajectories=1000, dt=0.005)

function loss_2(p)
    tmp_prob = remake(probSDE_nn,p=p)
    ensemble_prob = EnsembleProblem(tmp_prob)
    sim = solve(ensemble_prob,EM(),saveat = T, trajectories=1000, dt=0.005)
    arraysol = Array(sim)
    loss = (m - mean(arraysol,dims=3)[2])^2 + (v - var(arraysol,dims=3)[2])^2
end

cb2 = function() #callback function to observe training
  display(loss_2(p))
end

cb2()

Z_pre = deepcopy(Z)
u0_pre = deepcopy(u0)

# train model on the same data num_cycle times
opt = ADAM(0.01)
@time res = DiffEqFlux.sciml_train(loss_2,p,opt, maxiters = 5, cb=cb2)

@ChrisRackauckas
Member

The error message was telling you that you didn't have enough arguments in your callback:

using DifferentialEquations
using Flux
using DifferentialEquations.EnsembleAnalysis
using IterTools: ncycle
using Flux: @epochs
using DiffEqSensitivity
using Statistics
using DiffEqFlux

# Input parameters
T = 1.0f0
tspan = (0.0f0,T)
m=1.0
v=1.0
true_y0 = exp(T)*m + (exp(T)-1)*( sqrt(  2*v/(1-exp(-2*T)) )+1 )
true_Z = sqrt(2*v/(1-exp(-2*T)))
u0 = [true_y0 + 0.25] # initial guess
Z = true_Z + 0.25
p = [Z]

function drift_nn!(du,u,p,t)
    du[1] = - u[1] - p[1] - 1.0f0
end

function stoch_nn!(du,u,p,t)
    du[1] = p[1]
end

probSDE_nn = SDEProblem(drift_nn!,stoch_nn!,u0,tspan,p)
ensemble_prob = EnsembleProblem(probSDE_nn)
sol = solve(ensemble_prob,EM(),saveat = T,trajectories=1000, dt=0.005)

function loss_2(p)
    tmp_prob = remake(probSDE_nn,p=p)
    ensemble_prob = EnsembleProblem(tmp_prob)
    sim = solve(ensemble_prob,EM(),saveat = T, trajectories=1000, dt=0.005)
    arraysol = Array(sim)
    loss = (m - mean(arraysol,dims=3)[2])^2 + (v - var(arraysol,dims=3)[2])^2
end

cb2 = function(p,l) # callback to observe training; the MethodError above shows it is called as (parameters, loss)
  display(loss_2(p))
  false # returning false lets training continue
end

Z_pre = deepcopy(Z)
u0_pre = deepcopy(u0)

# train model on the same data num_cycle times
opt = ADAM(0.01)
@time res = DiffEqFlux.sciml_train(loss_2,p,opt, maxiters = 5, cb=cb2)
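The MethodError quoted earlier, `no method matching (::var"#46#47")(::Array{Float64,1}, ::Float64)`, can be reproduced without any solver. The sketch below is an assumption-laden toy: `toy_train` and `toy_loss` are hypothetical stand-ins, not sciml_train, and they only mimic a training loop that calls the callback with (parameters, loss).

```julia
# Hypothetical stand-in for a training loop. It is NOT sciml_train; it only
# mimics calling the callback with (parameters, loss) on every iteration.
function toy_train(loss, p, cb; maxiters=3)
    for _ in 1:maxiters
        l = loss(p)
        cb(p, l) && break        # a callback that returns true stops early
    end
    return p
end

toy_loss(p) = sum(abs2, p)

cb_zero = () -> false            # takes no arguments, like the failing cb2
cb_two  = (p, l) -> false        # accepts (parameters, loss)

toy_train(toy_loss, [1.0, 2.0], cb_two)   # runs all iterations without error

# Calling the zero-argument callback with two arguments throws a MethodError.
failed = try
    toy_train(toy_loss, [1.0, 2.0], cb_zero)
    false
catch err
    err isa MethodError
end
```

Here `failed` ends up `true` because the zero-argument closure has no method for an `(Array, Float64)` call, which is the same shape of failure as in the report.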

@CharlesRSmith44
Author

Awesome, thank you so much!

One last question: how do I get the sciml_train function to update the initial position of the SDE (u0 in the example)? My first attempt trains, but the u0/p[2] parameter doesn't update.


### Packages
import Pkg; Pkg.build("DifferentialEquations")
import Pkg; Pkg.add("DifferentialEquations")
import Pkg; Pkg.add("IterTools")

using DifferentialEquations
using Flux
using DifferentialEquations.EnsembleAnalysis
using IterTools: ncycle
using Flux: @epochs
using DiffEqSensitivity
using Statistics
using DiffEqFlux

# Input parameters
T = 1.0f0
tspan = (0.0f0,T)
m=1.0
v=1.0
true_y0 = exp(T)*m + (exp(T)-1)*( sqrt(  2*v/(1-exp(-2*T)) )+1 )
true_Z = sqrt(2*v/(1-exp(-2*T)))
u0 = true_y0 + 0.25 # initial guess
Z = true_Z + 0.25
p = [Z, u0]

function drift_nn!(du,u,p,t)
    du[1] = - u[1] - p[1] - 1.0f0
end

function stoch_nn!(du,u,p,t)
    du[1] = p[1]
end

probSDE_nn = SDEProblem(drift_nn!,stoch_nn!,[p[2]],tspan,p)
ensemble_prob = EnsembleProblem(probSDE_nn)
sol = solve(ensemble_prob,EM(),saveat = T,trajectories=1000, dt=0.005)

function loss_2(p)
    tmp_prob = remake(probSDE_nn,p=p)
    ensemble_prob = EnsembleProblem(tmp_prob)
    sim = solve(ensemble_prob,EM(),saveat = T, trajectories=1000, dt=0.005)
    arrsol = Array(sim)
    loss = (m - mean(arrsol,dims=3)[2])^2 + (v - var(arrsol,dims=3)[2])^2
end

Z_pre = deepcopy(Z)
u0_pre = deepcopy(u0)

# train model on the same data num_cycle times
opt = ADAM(0.01)
@time res = DiffEqFlux.sciml_train(loss_2,p,opt, maxiters = 5)



@ChrisRackauckas
Member

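It's because you're not using p[2] inside the loss function. The initial condition is fixed once, outside the loss function, when the problem is built, so changes to p[2] never reach the solver. Rebuild u0 from p inside the loss: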

function loss_2(p)
    tmp_prob = remake(probSDE_nn,p=p,u0=[p[2]])
    ...
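To see why the initial condition has to be read from p inside the loss, here is a minimal sketch that swaps the DifferentialEquations.jl solver for a hand-written Euler-Maruyama loop over the same drift and diffusion. The names `em_endpoint` and `loss_sketch`, the fixed seed, and the default arguments are illustrative assumptions, not code from this thread.

```julia
using Random, Statistics

# Hand-written Euler-Maruyama for du = -(u + Z + 1) dt + Z dW.
# This is a stand-in for solve(..., EM()), not the library solver.
function em_endpoint(u0, Z, T, dt, rng)
    u = u0
    for _ in 1:round(Int, T / dt)
        u += -(u + Z + 1) * dt + Z * sqrt(dt) * randn(rng)
    end
    return u
end

# Hypothetical loss: both Z and the initial condition come from p, so a
# change in p[2] changes the simulated endpoints and therefore the loss.
function loss_sketch(p; m=1.0, v=1.0, T=1.0, dt=0.005, n=1000, seed=0)
    rng = MersenneTwister(seed)      # fixed seed so repeated calls are comparable
    Z, u0 = p[1], p[2]
    ends = [em_endpoint(u0, Z, T, dt, rng) for _ in 1:n]
    return (m - mean(ends))^2 + (v - var(ends))^2
end

l_a = loss_sketch([1.0, 0.0])
l_b = loss_sketch([1.0, 5.0])        # same Z and noise, different initial condition
```

Because the noise is identical in both calls, any difference between `l_a` and `l_b` comes from p[2] alone. If u0 were captured outside the loss, as in the question, the two values would coincide and the optimizer would see no gradient with respect to p[2].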
