Commit: Update tutorials

avik-pal committed Dec 11, 2023
1 parent acea56c commit ce78c55
Showing 24 changed files with 265 additions and 349 deletions.
78 changes: 28 additions & 50 deletions docs/Project.toml
@@ -1,23 +1,15 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
@@ -26,50 +18,36 @@ OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "4, 5"
Calculus = "0.5"
ComponentArrays = "0.15"
DataInterpolations = "3.10, 4"
DiffEqBase = "6.106"
DiffEqCallbacks = "2.24"
DiffEqFlux = "1.52, 2"
DiffEqNoiseProcess = "5.14"
DifferentialEquations = "7.6"
Documenter = "1"
Flux = "0.13, 0.14"
ForwardDiff = "0.10"
GraphNeuralNetworks = "0.5, 0.6"
IterTools = "1.4"
Lux = "0.5"
MLDatasets = "0.7"
NNlib = "0.9"
Optimisers = "0.3"
Optimization = "3.9"
OptimizationNLopt = "0.1"
OptimizationOptimJL = "0.1"
OptimizationOptimisers = "0.1"
OptimizationPolyalgorithms = "0.1"
OrdinaryDiffEq = "6.31"
Plots = "1.36"
QuadGK = "2.6"
RecursiveArrayTools = "2.32"
ReverseDiff = "1.14"
SciMLSensitivity = "7.11"
SimpleChains = "0.4"
StaticArrays = "1.5"
Statistics = "1"
StochasticDiffEq = "6.56"
Tracker = "0.2"
Zygote = "0.6"
Calculus = "0.5.1"
ComponentArrays = "0.15.5"
DataInterpolations = "4.5.0"
DelayDiffEq = "5.43.2"
DelimitedFiles = "1.9.1"
DiffEqCallbacks = "2.34.0"
DiffEqFlux = "3.0.0"
DiffEqNoiseProcess = "5.19.0"
ForwardDiff = "0.10.36"
Lux = "0.5.10"
LuxCUDA = "0.3.1"
Optimization = "3.19.3"
OptimizationNLopt = "0.1.8"
OptimizationOptimJL = "0.1.14"
OptimizationOptimisers = "0.1.6"
OptimizationPolyalgorithms = "0.1.2"
OrdinaryDiffEq = "6.60.0"
Plots = "1.39.0"
QuadGK = "2.9.1"
RecursiveArrayTools = "2.4.2"
ReverseDiff = "1.15.1"
Statistics = "1.11.0"
StochasticDiffEq = "6.63.2"
Tracker = "0.2.30"
Zygote = "0.6.67"
147 changes: 35 additions & 112 deletions docs/src/Benchmark.md
@@ -6,9 +6,9 @@ From our [recent papers](https://arxiv.org/abs/1812.01892), it's clear that `EnzymeVJP` is the fastest,
especially when the program is set up to use fully non-allocating, mutating functions. Thus all benchmarking,
especially with PDEs, should be done this way. Neural network libraries don't make use of mutation effectively
[except for SimpleChains.jl](https://julialang.org/blog/2022/04/simple-chains/), so we recommend creating a
neural ODE / universal ODE with `ZygoteVJP` and Flux first, but then check the correctness by moving the
neural ODE / universal ODE with `ZygoteVJP` and Lux first, but then check the correctness by moving the
implementation over to SimpleChains and if possible `EnzymeVJP`. This can be an order of magnitude improvement
(or more) in many situations over all the previous benchmarks using Zygote and Flux, and thus it's
(or more) in many situations over all the previous benchmarks using Zygote and Lux, and thus it's
highly recommended in scenarios that require performance.
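
As a rough illustration of this workflow (a minimal sketch, not part of the commit: the ODE, loss, and solver settings here are assumptions), swapping the `autojacvec` is a one-line change once the right-hand side is written as an in-place, non-allocating function:

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# In-place, non-allocating RHS: the form `EnzymeVJP` handles best.
function lv!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lv!, u0, (0.0, 10.0), p)

loss(p, vjp) = sum(abs2,
    Array(solve(prob, Tsit5(); p = p, saveat = 0.1,
        sensealg = InterpolatingAdjoint(autojacvec = vjp))))

# Check correctness with a tape-based VJP first, then switch to EnzymeVJP.
Zygote.gradient(p -> loss(p, ReverseDiffVJP(true)), p)
Zygote.gradient(p -> loss(p, EnzymeVJP()), p)
```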

## Vs Torchdiffeq 1 million and less ODEs
@@ -33,12 +33,12 @@ at this time.
Quick summary:

- `BacksolveAdjoint` can be the fastest (but use with caution!); about 25% faster
- Using `ZygoteVJP` is faster than other vjp choices with FastDense due to the overloads
- Using `ZygoteVJP` is faster than other vjp choices for larger neural networks
- `ReverseDiffVJP(compile = true)` works well for small Lux neural networks

```julia
using DiffEqFlux,
OrdinaryDiffEq, Flux, Optim, Plots, SciMLSensitivity,
Zygote, BenchmarkTools, Random
OrdinaryDiffEq, Lux, SciMLSensitivity, Zygote, BenchmarkTools, Random, ComponentArrays

u0 = Float32[2.0; 0.0]
datasize = 30
@@ -53,116 +53,39 @@ end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = FastChain((x, p) -> x .^ 3,
FastDense(2, 50, tanh),
FastDense(50, 2))
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
Random.seed!(100)
p = initial_params(dudt2)

prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

function loss_neuralode(p)
pred = Array(prob_neuralode(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode, p)
# 2.709 ms (56506 allocations: 6.62 MiB)

prob_neuralode_interpolating = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))

function loss_neuralode_interpolating(p)
pred = Array(prob_neuralode_interpolating(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_interpolating, p)
# 5.501 ms (103835 allocations: 2.57 MiB)

prob_neuralode_interpolating_zygote = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))

function loss_neuralode_interpolating_zygote(p)
pred = Array(prob_neuralode_interpolating_zygote(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_interpolating_zygote, p)
# 2.899 ms (56150 allocations: 6.61 MiB)

prob_neuralode_backsolve = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)))

function loss_neuralode_backsolve(p)
pred = Array(prob_neuralode_backsolve(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_backsolve, p)
# 4.871 ms (85855 allocations: 2.20 MiB)

prob_neuralode_quad = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)))

function loss_neuralode_quad(p)
pred = Array(prob_neuralode_quad(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_quad, p)
# 11.748 ms (79549 allocations: 3.87 MiB)

prob_neuralode_backsolve_tracker = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = BacksolveAdjoint(autojacvec = TrackerVJP()))

function loss_neuralode_backsolve_tracker(p)
pred = Array(prob_neuralode_backsolve_tracker(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_backsolve_tracker, p)
# 27.604 ms (186143 allocations: 12.22 MiB)

prob_neuralode_backsolve_zygote = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))

function loss_neuralode_backsolve_zygote(p)
pred = Array(prob_neuralode_backsolve_zygote(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_backsolve_zygote, p)
# 2.091 ms (49883 allocations: 6.28 MiB)

prob_neuralode_backsolve_false = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(false)))

function loss_neuralode_backsolve_false(p)
pred = Array(prob_neuralode_backsolve_false(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
for sensealg in (InterpolatingAdjoint(autojacvec = ZygoteVJP()),
InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)),
BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)),
BacksolveAdjoint(autojacvec = ZygoteVJP()),
BacksolveAdjoint(autojacvec = ReverseDiffVJP(false)),
BacksolveAdjoint(autojacvec = TrackerVJP()),
QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)),
TrackerAdjoint())
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps,
sensealg = sensealg)
ps, st = Lux.setup(Random.default_rng(), prob_neuralode)
ps = ComponentArray(ps)

loss_neuralode = function (u0, p, st)
pred = Array(first(prob_neuralode(u0, p, st)))
loss = sum(abs2, ode_data .- pred)
return loss
end

t = @belapsed Zygote.gradient($loss_neuralode, $u0, $ps, $st)
println("$(sensealg) took $(t)s")
end

@btime Zygote.gradient(loss_neuralode_backsolve_false, p)
# 4.822 ms (9956 allocations: 1.03 MiB)

prob_neuralode_tracker = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = TrackerAdjoint())

function loss_neuralode_tracker(p)
pred = Array(prob_neuralode_tracker(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
# InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}(ZygoteVJP(false), false, false) took 0.029134224s
# InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), false, false) took 0.001657377s
# BacksolveAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), true, false) took 0.002477057s
# BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}(ZygoteVJP(false), true, false) took 0.031533335s
# BacksolveAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}}(ReverseDiffVJP{false}(), true, false) took 0.004605386s
# BacksolveAdjoint{0, true, Val{:central}, TrackerVJP}(TrackerVJP(false), true, false) took 0.044568018s
# QuadratureAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), 1.0e-6, 0.001) took 0.002489559s
# TrackerAdjoint() took 0.003759097s

@btime Zygote.gradient(loss_neuralode_tracker, p)
# 12.614 ms (76346 allocations: 3.12 MiB)
```
8 changes: 4 additions & 4 deletions docs/src/examples/dae/physical_constraints.md
@@ -10,7 +10,7 @@ terms must add to one. An example of this is as follows:

```@example dae
using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
DifferentialEquations, Plots
OrdinaryDiffEq, Plots
using Random
rng = Random.default_rng()
```

@@ -74,7 +74,7 @@ result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)

```@example dae2
using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
DifferentialEquations, Plots
OrdinaryDiffEq, Plots
using Random
rng = Random.default_rng()
```

@@ -133,8 +133,8 @@ Because this is a DAE, we need to make sure to use a **compatible solver**.
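
For reference, here is a minimal sketch (not part of this commit) of what solver compatibility means here: a mass-matrix DAE, where the zero row of the mass matrix encodes the algebraic "terms sum to one" constraint, solved with a mass-matrix-aware stiff solver such as `Rodas5`:

```julia
using OrdinaryDiffEq

function rober!(du, u, p, t)
    y₁, y₂, y₃ = u
    du[1] = -0.04y₁ + 1.0e4 * y₂ * y₃
    du[2] = 0.04y₁ - 1.0e4 * y₂ * y₃ - 3.0e7 * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1.0   # algebraic constraint: the fractions sum to one
end

M = [1.0 0.0 0.0
     0.0 1.0 0.0
     0.0 0.0 0.0]                # zero row marks the algebraic equation
prob = ODEProblem(ODEFunction(rober!, mass_matrix = M), [1.0, 0.0, 0.0], (0.0, 1.0e5))
sol = solve(prob, Rodas5())      # Rodas5 handles singular mass matrices
```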
### Neural Network Layers

Next, we create our layers using `Lux.Chain`. We use this instead of `Flux.Chain` because it
is more suited to SciML applications (similarly for
`Lux.Dense`). The input to our network will be the initial conditions fed in as `u₀`.
is more suited to SciML applications (similarly for `Lux.Dense`). The input to our network
will be the initial conditions fed in as `u₀`.

```@example dae2
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
# … remainder of this block truncated in the diff view
```
7 changes: 3 additions & 4 deletions docs/src/examples/dde/delay_diffeq.md
@@ -5,8 +5,8 @@ supported. For example, we can build a layer with a delay differential equation
like:

```@example dde
using DifferentialEquations, Optimization, SciMLSensitivity,
OptimizationPolyalgorithms
using OrdinaryDiffEq, Optimization, SciMLSensitivity, OptimizationPolyalgorithms,
DelayDiffEq
# Define the same LV equation, but including a delay parameter
function delay_lotka_volterra!(du, u, h, p, t)
@@ -32,8 +32,7 @@ prob_dde = DDEProblem(delay_lotka_volterra!, u0, h, (0.0, 10.0),
function predict_dde(p)
return Array(solve(prob_dde, MethodOfSteps(Tsit5()),
u0 = u0, p = p, saveat = 0.1,
sensealg = ReverseDiffAdjoint()))
u0 = u0, p = p, saveat = 0.1, sensealg = ReverseDiffAdjoint()))
end
loss_dde(p) = sum(abs2, x - 1 for x in predict_dde(p))
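
# Hedged sketch, not shown in this excerpt: such a loss is typically minimized
# with Optimization.jl. Here `p` is assumed to be the parameter vector passed
# to `prob_dde` above, `PolyOpt()` comes from OptimizationPolyalgorithms, and
# `Optimization.AutoZygote()` assumes Zygote is loaded.
optf = OptimizationFunction((p, _) -> loss_dde(p), Optimization.AutoZygote())
res = Optimization.solve(OptimizationProblem(optf, p), PolyOpt(), maxiters = 300)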
```
3 changes: 2 additions & 1 deletion docs/src/examples/hybrid_jump/bouncing_ball.md
@@ -8,7 +8,8 @@ data. Assume we have data for the ball's height after 15 seconds. Let's
first start by implementing the ODE:

```@example bouncing_ball
using Optimization, OptimizationPolyalgorithms, SciMLSensitivity, DifferentialEquations
using Optimization,
OptimizationPolyalgorithms, SciMLSensitivity, OrdinaryDiffEq, DiffEqCallbacks
function f(du, u, p, t)
du[1] = u[2]
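du[2] = -p[1]   # hedged completion; the diff is truncated here, constant gravity assumed
end

# Hedged sketch, assumed rather than shown in this commit: the bounce itself is
# usually a ContinuousCallback (from DiffEqCallbacks) that reflects the velocity,
# scaled by a restitution coefficient, whenever the height crosses zero.
condition(u, t, integrator) = u[1]
affect!(integrator) = (integrator.u[2] = -integrator.p[2] * integrator.u[2])
cb = ContinuousCallback(condition, affect!)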
```
[Diffs for the remaining 19 changed files are not shown.]
