Allow NullParameters in more adjoint method dispatches #433
A few steps here. First of all, since you only need the solution at a single time point, you can index it directly and write the loss as:

```julia
function loss_1(S, f, f_z)
    rows = size(S, 1)
    loss = 0.0f0
    for i = 1:rows
        t = S[i, 1]
        x = S[i, 2]
        z = compute_z(x, f, f_z; saveat = [t])
        z1, z2 = z[1][1], z[1][2]
        u_t = dh(z1) * f(z1)
        u_x = dh(z1) * z2
        loss += (u_t + u_x)^2
    end
    return loss / rows
end
```

That solves your issue. However, something else tricky comes up. You're now trying to differentiate w.r.t. parameters and the initial condition, but there are no parameters. This seems to hit an untested edge case, but it was easy to fix. With that PR it all works pretty fast, but you can make it even faster by forcing forward mode.
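A sketch of what "forcing forward mode" could look like here, assuming `compute_z` forwards keyword arguments to an inner `solve` call (`prob` and `Tsit5` are placeholder names, not from the thread):

```julia
using OrdinaryDiffEq, DiffEqSensitivity

# Hypothetical: pass ForwardDiffSensitivity through to the inner solve call
# so the gradient is computed with forward-mode AD instead of an adjoint.
z = solve(prob, Tsit5(); saveat = [t], sensealg = ForwardDiffSensitivity())
```

Forward mode tends to win when the number of things you differentiate with respect to (here just a small initial condition) is small.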
First of all, thank you very much for the prompt response! As you suggested, I'm now indexing the solution directly. If you don't mind, though, I'd like to clarify some of your comments:

```julia
solve(prob)                                      # OK
solve(prob; sensealg=BacksolveAdjoint())         # OK
solve(prob; sensealg=QuadratureAdjoint())        # OK
solve(prob; sensealg=ZygoteAdjoint())            # MethodError: no method matching push!(::DataStructures.BinaryMinHeap{Float32})
solve(prob; sensealg=SensitivityADPassThrough()) # MethodError: no method matching push!(::DataStructures.BinaryMinHeap{Float32})
solve(prob; sensealg=ForwardSensitivity())       # MethodError: no method matching length(::SciMLBase.NullParameters)
solve(prob; sensealg=ForwardDiffSensitivity())   # MethodError: no method matching seed_duals(::SciMLBase.NullParameters,
solve(prob; sensealg=InterpolatingAdjoint())     # BoundsError: attempt to access 1-element Vector{Float32} at index [0]
solve(prob; sensealg=ReverseDiffAdjoint())       # MethodError: no method matching similar(::SciMLBase.NullParameters)
solve(prob; sensealg=TrackerAdjoint())           # MethodError: no method matching param(::SciMLBase.NullParameters)
```
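As an aside (my own workaround, not something suggested in this thread): several of these `NullParameters` errors come from methods calling `length(p)` or `similar(p)` on the default parameter object, so giving the problem an explicit, possibly empty, parameter array can sidestep them on versions without the fix:

```julia
using OrdinaryDiffEq

# Toy ODE with no real parameters; p is accepted but unused.
rhs(u, p, t) = -u

# Passing Float32[] instead of leaving the default NullParameters()
# gives adjoint methods a real array to call length/similar on.
prob = ODEProblem(rhs, Float32[1.0], (0.0f0, 1.0f0), Float32[])
sol = solve(prob, Tsit5())
```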
I'm going to split this off into its own issue in DiffEqSensitivity about that specific topic (point 4).
You can do both parameters and initial condition. It'll work it out automatically. Just do stuff like:

```julia
function f(theta)
    u0 = theta[1:n]
    p = theta[n+1:end]
    _prob = remake(prob, u0 = u0, p = p)
    sol = solve(_prob, alg)
    sum(abs2, sol - data)
end
```
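Differentiating such a packed objective could then look like the following sketch (Zygote is one AD choice; `u0_guess`, `p_guess`, and the names inside `f` are assumed from the snippet above):

```julia
using Zygote

# Pack initial condition and parameters into one vector...
theta0 = vcat(u0_guess, p_guess)

# ...and get the gradient w.r.t. everything at once.
g = Zygote.gradient(f, theta0)[1]
```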
Your ODE has no parameters. That's fine; that's an issue on our end, to make the methods robust to not having any. Most of the time people do have parameters, so this is a less tested area, though it is tested for many cases: https://github.com/SciML/DiffEqSensitivity.jl/blob/master/test/null_parameters.jl
It'll do so automatically now, as of this week. After seeing your code, I realized we should just make a smart polyalgorithm solve your problem so that you don't have to do any work here.
Yes. The thing is that only BacksolveAdjoint, ForwardDiffSensitivity, and QuadratureAdjoint make sense when you have no parameters: InterpolatingAdjoint is strictly worse than QuadratureAdjoint in that case. So the default algorithm knows to avoid it, and I guess we never tested what happens if a user specifically asks for it in this case. Turns out it fails, so we should have better behavior here (likely: automatically switch to QuadratureAdjoint).
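Until an automatic switch exists, explicitly requesting one of the choices that does make sense without parameters looks like this (`prob` and `alg` are placeholders for your problem and solver):

```julia
using DiffEqSensitivity

# QuadratureAdjoint is one of the sensealgs that handles the
# no-parameters case; BacksolveAdjoint and ForwardDiffSensitivity
# are the other reasonable picks.
sol = solve(prob, alg; sensealg = QuadratureAdjoint())
```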
Once again, thanks for your feedback! Things are clearer to me now. I'm glad I could help you identify these unknown problematic scenarios and improve the library. Keep up the good work!
We are experimenting with some models/architectures inspired by the NODE model. Given a point (t, x), the idea is to solve an ODE system whose definition uses a neural network (and also its derivative) and whose initial condition is x. Then, we take the ODE's solution z and evaluate z(t). In order to train the neural net, we sample some points (t, x) and for each of them compute a loss function which also solves the ODE system and evaluates z(t).
We managed to get the model working with the Optim library using the Nelder-Mead optimization method which does not require gradients, but the convergence is relatively slow. We are now trying to implement the model with DiffEqFlux to see if we get better performance with the gradient descent method. Another benefit of using DiffEqFlux is that we get automatic differentiation for "free", which is obviously required by SGD.
Unfortunately, we are getting weird errors deep in the AD code, and we are having a hard time trying to overcome them. At this point, I really don't understand whether the problem is in my usage of DiffEqFlux or whether there is some bug in the library (or one of its dependencies).
I provided some sample code below for you to be able to reproduce the errors. When I run the Optim version (`run_optimization_optim(16)`) everything is fine, but I get "MethodError: no method matching fast_materialize(::Vector{Float32})" deep inside Zygote code when running the DiffEqFlux version (`run_optimization_flux()`). I also tried using different sensitivity algorithms, but then I get "MethodError: no method matching push!(::DataStructures.BinaryMinHeap{Float32})". Maybe you could provide some insight on what might be the issue here. Thanks!