SDEs gradients issues with EnsembleProblems #765
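In the second case (remaking the problem with new
MWE 3
```julia
# Assumed setup (not shown in this comment): `prob` is the SDEProblem from the
# earlier MWEs, and these packages are loaded (the error only appears with DiffEqFlux).
using DifferentialEquations, DiffEqFlux, Zygote

function diffeq(prob, p, t)
    prob = remake(prob; tspan = (t[1], t[end]))
    sol = solve(prob, SOSRI(); saveat = t)
    return sum(sol)
end

p0 = Float32[1.25, 1.5, 1.75, 2]
gradient(p -> diffeq(prob, p, 0:0.01:1), p0)
```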
Error 3
```
ERROR: Gradient ChainRulesCore.ZeroTangent() should be a tuple
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] gradtuple1(x::ChainRulesCore.ZeroTangent)
   @ ZygoteRules ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:24
 [3] (::Zygote.var"#1794#back#155"{typeof(identity)})(Δ::ChainRulesCore.ZeroTangent)
   @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [4] Pullback
   @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:41 [inlined]
 [5] (::typeof(∂(diffeq)))(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [6] Pullback
   @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:47 [inlined]
 [7] (::typeof(∂(#5)))(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface2.jl:0
 [8] (::Zygote.var"#60#61"{typeof(∂(#5))})(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface.jl:45
 [9] gradient(f::Function, args::Vector{Float32})
   @ Zygote ~/.julia/packages/Zygote/PD12J/src/compiler/interface.jl:97
 [10] top-level scope
   @ ~/Documents/doctorado/issues/goku_net_on_SDEs/MWE2.jl:47
```
Although they have similarities, I realize that my first comment tackles two different issues. Let me know if you would prefer that I open a separate GitHub issue or keep both here.
I have seen this error before but couldn't fix it properly. I worked around it with https://github.com/SciML/DeepEquilibriumNetworks.jl/blob/ba6d66fcbdbd8bb2d39a5a27a3e4fced127aa584/experiments/src/DEQExperiments.jl#L16, but that is in no way the correct solution. It looks like Zygote is generating an incorrect backward pass.
Pin Zygote back to v0.6.43. There was a change to the `accum` derivative that broke a lot of code. I think this might be the same issue.
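A minimal sketch of pinning Zygote to v0.6.43 with Julia's package manager, assuming the active project is the environment used to run the MWEs. `Pkg.add` with a `version` keyword installs that exact release, and `Pkg.pin` keeps `Pkg.update()` from moving it afterwards:

```julia
using Pkg

# Install exactly Zygote v0.6.43 in the active environment
# (assumed to be the environment the MWEs run in).
Pkg.add(name = "Zygote", version = "0.6.43")

# Pin it so that later `Pkg.update()` calls leave it at this version.
Pkg.pin("Zygote")

# To undo the pin later and allow upgrades again:
# Pkg.free("Zygote")
```

Restart the Julia session after pinning so the downgraded Zygote is the one actually loaded.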
Hi! Thanks for your responses! Unfortunately, neither the workaround suggested by @avik-pal nor pinning Zygote to v0.6.43 worked for any of the MWEs. @ChrisRackauckas, note that, unlike the issue you mentioned, the code works without any error if I don't load DiffEqFlux, which is very strange (at least to me).
The original MWE here is solved. The other two are because the derivative w.r.t.
Strangely, the following code works fine if DiffEqFlux is not loaded, but breaks if it is.
MWE
Error
The error doesn't appear if only 2 sets of parameters are used for the ensemble. I also found another issue when I remake the problem to change the time span, even if I don't take the gradient with respect to it. The following MWE is the same as the previous one, except that the parameter vector `p0` has length 2 to avoid the previous error, a time range is added as an argument, and the problem is remade to change the time span. As in the previous case, it works without any errors if DiffEqFlux is not loaded.
MWE 2
Error 2
Project and version info