Gradient by zygote of scalar expectation in batch mode #49

Open
ArnoStrouwen opened this issue Mar 25, 2021 · 5 comments

Comments

@ArnoStrouwen
Member

using ForwardDiff
using Zygote
using OrdinaryDiffEq
using DiffEqUncertainty
using DiffEqSensitivity
using Distributions
using Cubature

function f!(du,u,p,t)
    du[1] = p[1]*u[1] - p[2]*u[1]*u[2] #prey
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2] #predator
end

tspan = (0.0,10.0)
u0 = [1.0;1.0]
p = [1.5,1.0,3.0,1.0]
prob = ODEProblem(f!,u0,tspan,p,sensealg=InterpolatingAdjoint())
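g(sol) = sol[1,end]  # observable: prey population at the final time

p1_3 = [1.5,3.0]  # the parameters we differentiate with respect to (p[1] and p[3])
function testf!(p1_3)
    # p[2] is fixed at 1.0; p[4] is uncertain (truncated normal)
    p_dist = [p1_3[1],1.0,p1_3[2],truncated(Normal(1.0,.1),.6, 1.4)]
    # u0[1] is fixed; u0[2] is uncertain (uniform)
    u0_dist = [1.0, Uniform(0.8, 1.1)]
    # Koopman expectation of g, integrated with CubatureJLp in batch mode (batch=32)
    expectation(g, prob, u0_dist, p_dist, Koopman(), Tsit5(), quadalg=CubatureJLp(),batch=32)[1]
end
testf!(p1_3)
ForwardDiff.gradient(testf!,p1_3)  # forward-mode gradient
Zygote.gradient(testf!,p1_3)       # reverse-mode gradient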
@agerlach
Contributor

agerlach commented Mar 25, 2021

I am able to run this, but it looks like the gradients don't match. They do match for batch=0. Is that what you observed? From Slack I got the impression that Zygote was erroring.

@ArnoStrouwen
Member Author

ERROR: MethodError: no method matching RecursiveArrayTools.VectorOfArray(::Tuple{Float64})
status `~/Dropbox/julia/small projects/dynamic experimental design tutorial/Project.toml`
  [667455a9] Cubature v1.5.1
  [aae7a2af] DiffEqFlux v1.34.1
  [41bf760c] DiffEqSensitivity v6.43.1
  [ef61062a] DiffEqUncertainty v1.8.0
  [31c24e10] Distributions v0.24.15
  [6a86dc24] FiniteDiff v2.8.0
  [587475ba] Flux v0.11.6
  [f6369f11] ForwardDiff v0.10.17
  [429524aa] Optim v1.3.0
  [1dea7af3] OrdinaryDiffEq v5.52.2
  [91a5bcdd] Plots v1.11.0
  [37e2e3b7] ReverseDiff v1.7.0
  [e88e6eb3] Zygote v0.6.6

@ChrisRackauckas
Member

Interesting. This doesn't look the same as SciML/Integrals.jl#49. I wonder if we can isolate this to just Quadrature.jl.

@agerlach
Contributor

@ArnoStrouwen Is this in a fresh REPL? Something doesn't seem right, because RecursiveArrayTools isn't in Quadrature or DiffEqUncertainty. I just searched the repos to make sure.

This runs for me with:

  [667455a9] Cubature v1.5.1
  [41bf760c] DiffEqSensitivity v6.43.1
  [ef61062a] DiffEqUncertainty v1.8.0
  [31c24e10] Distributions v0.24.15
  [6a86dc24] FiniteDiff v2.8.0
  [f6369f11] ForwardDiff v0.10.17
  [1dea7af3] OrdinaryDiffEq v5.52.2
  [e88e6eb3] Zygote v0.6.6

But I get the following gradients:

ForwardDiff: 1.8264648587604257, 0.5292345598328614
Zygote:      7.146998823274198, 1.8884006857029283
FiniteDiff:  1.6546890304376016, 3.702681873199691
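To make this kind of comparison repeatable, here is a minimal sketch that computes forward-mode, reverse-mode, and finite-difference gradients side by side and reports whether they agree. It assumes only the standard entry points `ForwardDiff.gradient`, `Zygote.gradient`, and `FiniteDiff.finite_difference_gradient`. The helper `check_gradients` and its `rtol` default are hypothetical, not part of any of these packages.

```julia
using ForwardDiff, Zygote, FiniteDiff

# Hypothetical helper: gradient of a scalar function `f` at `x` computed with
# forward mode, reverse mode and finite differences, plus a flag saying whether
# all three agree to within the relative tolerance `rtol`.
function check_gradients(f, x::AbstractVector; rtol = 1e-3)
    g_fwd = ForwardDiff.gradient(f, x)                   # forward-mode AD
    g_rev = Zygote.gradient(f, x)[1]                     # reverse-mode AD (Zygote returns a tuple)
    g_fd  = FiniteDiff.finite_difference_gradient(f, x)  # finite differences
    agree = isapprox(g_fwd, g_rev; rtol = rtol) &&
            isapprox(g_fwd, g_fd; rtol = rtol)
    return (forward = g_fwd, reverse = g_rev, finite = g_fd, agree = agree)
end

# Sanity check on a function with a known gradient, 2x:
res = check_gradients(x -> sum(abs2, x), [1.0, 2.0])
```

For this issue it would be called as `check_gradients(testf!, p1_3)`, once with `batch=32` and once with `batch=0` inside `testf!`. The tolerance is deliberately loose because finite differences taken through an adaptive quadrature and an ODE solve can be noisy.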

@ArnoStrouwen
Member Author

It might be something Visual Studio Code loads, since I get the same results as you when I run it from the terminal.


3 participants