solve and concrete_solve crash when called through sciml_train #153
Closed
I managed to clean up the example:
using DiffEqFlux
using OrdinaryDiffEq
using Flux
using Optim
function f!(du,u,p,t)
    du .= p.*u
end
N = 10
p = 1
globalx = zeros(N)
x = ones(N)
prob = ODEProblem(f!, globalx , (0.,1.), p)
function predict(x,p)
    x1 = concrete_solve(prob, Tsit5(), x .* collect(1:N), p, save_everystep=false)[end]
    println(typeof(x1)) # direct call: Array{Float64,1}, call through sciml_train: Float64
    println(x1[:])
    return x1
end
function loss_adjoint(p)
    prediction = predict(x,p)
    loss = sum(abs2, prediction)
    loss, prediction
end
loss_adjoint(-30.)
res = DiffEqFlux.sciml_train(loss_adjoint, [1.], BFGS(initial_stepnorm=0.01))

Zygote seems to handle the output of
x1 = Array(concrete_solve(prob, Tsit5(), x .* collect(1:N), p, save_everystep=false))[:,end]
Is this expected/desired behaviour? Is it documented somewhere why this happens?
ChrisRackauckas added a commit to SciML/SciMLSensitivity.jl that referenced this issue Feb 13, 2020
The adjoints should all commit to returning the same DiffEqArray as the forward pass, otherwise the semantics get confusing and different during the gradient passes. Fixes SciML/DiffEqFlux.jl#153
Undesirable and confusing! Fixed here: SciML/SciMLSensitivity.jl#194
Thanks for the report and sorry for the late response.
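The type change reported in this thread (a vector on a direct call, a single Float64 through sciml_train) is what one would see if `[end]` were applied to a plain matrix instead of a time-indexed solution. The thread does not confirm that this was the exact mechanism, so the following is only a plain-Julia sketch of that assumption. It uses no SciML packages, and `as_matrix` and `as_states` are hypothetical stand-ins for two representations of the same result.

```julia
# Plain-Julia illustration; no SciML packages are involved.
# Hypothetical stand-ins for a result with 2 states and 3 saved time points.

# States along rows, time points along columns (a plain matrix).
as_matrix = [1.0 2.0 3.0;
             4.0 5.0 6.0]

# One state vector per time point (indexed by time, like a solution object).
as_states = [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]

last_linear = as_matrix[end]      # linear indexing: one Float64, not a state vector
last_column = as_matrix[:, end]   # last time point as a vector
last_state  = as_states[end]      # last time point as a vector

println(typeof(last_linear))
println(last_column == last_state)
```

Selecting the last column explicitly, as in the `Array(...)[:,end]` line quoted above, does not depend on how `[end]` is interpreted, which may be why that form behaved consistently.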
This is just some simple example code for the problem. When calling predict directly, x1 is an Array{Float64,2}; when called through sciml_train it is a Float64, and I get an index error in the line below. I have no idea why this happens.
When I use solve instead of concrete_solve by replacing the relevant line, Julia crashes altogether. This is the error message:

I am on Julia 1.4, but the problem has been observed on 1.3 as well.