End of Year Housekeeping #948

Merged: 9 commits, Dec 14, 2023
4 changes: 3 additions & 1 deletion .gitignore
@@ -2,4 +2,6 @@
*.jl.*.cov
*.jl.mem
Manifest.toml
/docs/build/
/docs/build/
test/Manifest.toml
test/gpu/Manifest.toml
37 changes: 5 additions & 32 deletions Project.toml
@@ -34,15 +34,14 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ADTypes = "0.1, 0.2"
@@ -61,52 +60,26 @@ FunctionProperties = "0.1"
FunctionWrappersWrappers = "0.1"
Functors = "0.4"
GPUArraysCore = "0.1"
LinearAlgebra = "<0.0.1, 1"
LinearSolve = "2"
Markdown = "<0.0.1, 1"
OrdinaryDiffEq = "6.19.1"
Parameters = "0.12"
PreallocationTools = "0.4.4"
QuadGK = "2.1"
Random = "<0.0.1, 1"
RandomNumbers = "1.5.3"
RecursiveArrayTools = "2.4.2, 3"
Reexport = "0.2, 1.0"
ReverseDiff = "1.9"
SciMLBase = "1.66.0, 2"
SciMLOperators = "0.1, 0.2, 0.3"
SimpleNonlinearSolve = "0.1.8"
SparseDiffTools = "2.5"
StaticArrays = "1.8.0"
StaticArraysCore = "1.4"
Statistics = "1"
StochasticDiffEq = "6.20"
Tracker = "0.2"
TruncatedStacktraces = "1.2"
Zygote = "0.6"
ZygoteRules = "0.2"
julia = "1.9"

[extras]
AlgebraicMultigrid = "2169fc97-5a83-5252-b627-83903c6c433c"
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationFlux = "253f991c-a7b2-45f8-8852-8b9a9df78a86"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLNLSolve = "e9a6253c-8580-4d32-9898-8661bb511710"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "ComponentArrays", "Calculus", "Distributed", "DelayDiffEq", "Optimization", "OptimizationFlux", "OptimizationOptimJL", "Flux", "ReverseDiff", "SafeTestsets", "StaticArrays", "Test", "Random", "Pkg", "SteadyStateDiffEq", "NLsolve", "NonlinearSolve", "SparseArrays", "SciMLNLSolve", "OptimizationOptimisers", "Functors", "Lux"]
38 changes: 11 additions & 27 deletions docs/Project.toml
@@ -1,23 +1,16 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Member commented:

We should keep one tutorial with Flux, kind of like we do with SimpleChains. It should clearly be a second-class citizen though, there mostly to keep alive an example of how to do something with Flux, as a "you can do things with Flux too!", but it should also carry a warning saying that Lux is greatly preferred.
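
A hypothetical sketch (not part of this PR) of what such a second-class Flux example could look like, assuming Flux's `destructure` API; the names `dudt_flux` and `flux_node` are illustrative only. The tutorial page itself would also carry a Documenter `!!! warning` admonition stating that Lux is greatly preferred.

```julia
using Flux, OrdinaryDiffEq

# Illustrative Flux network for a neural ODE; Lux remains the preferred library.
dudt_flux = Flux.Chain(x -> x .^ 3, Flux.Dense(2 => 50, tanh), Flux.Dense(50 => 2))
p, re = Flux.destructure(dudt_flux)    # flatten the parameters into a single vector

# Rebuild the network from the flat parameter vector inside the ODE right-hand side.
flux_node(u, p, t) = re(p)(u)

u0 = Float32[2.0, 0.0]
prob = ODEProblem(flux_node, u0, (0.0f0, 1.5f0), p)
sol = solve(prob, Tsit5(); saveat = 0.1f0)
```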

ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationNLopt = "4e6fcdb7-1186-4e1f-a706-475e75c168bb"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
@@ -26,36 +19,28 @@ OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SimpleChains = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "4, 5"
Calculus = "0.5"
ComponentArrays = "0.15"
DataInterpolations = "3.10, 4"
DiffEqBase = "6.106"
# DiffEqBase = "6.106"
DiffEqCallbacks = "2.24"
DiffEqFlux = "1.52, 2"
DiffEqFlux = "3"
DiffEqNoiseProcess = "5.14"
DifferentialEquations = "7.6"
Documenter = "1"
Flux = "0.13, 0.14"
ForwardDiff = "0.10"
GraphNeuralNetworks = "0.5, 0.6"
IterTools = "1.4"
Lux = "0.5"
MLDatasets = "0.7"
NNlib = "0.9"
Optimisers = "0.3"
# GraphNeuralNetworks = "0.5, 0.6"
# IterTools = "1.4"
Lux = "0.5.7"
# MLDatasets = "0.7"
Optimization = "3.9"
OptimizationNLopt = "0.1"
OptimizationOptimJL = "0.1"
Expand All @@ -67,9 +52,8 @@ QuadGK = "2.6"
RecursiveArrayTools = "2.32"
ReverseDiff = "1.14"
SciMLSensitivity = "7.11"
SimpleChains = "0.4"
StaticArrays = "1.5"
# SimpleChains = "0.4"
Statistics = "1"
StochasticDiffEq = "6.56"
Tracker = "0.2"
Zygote = "0.6"
Zygote = "0.6"
147 changes: 35 additions & 112 deletions docs/src/Benchmark.md
@@ -6,9 +6,9 @@ From our [recent papers](https://arxiv.org/abs/1812.01892), it's clear that `Enz
especially when the program is set up to be fully non-allocating mutating functions. Thus for all benchmarking,
especially with PDEs, this should be done. Neural network libraries don't make use of mutation effectively
[except for SimpleChains.jl](https://julialang.org/blog/2022/04/simple-chains/), so we recommend creating a
neural ODE / universal ODE with `ZygoteVJP` and Flux first, but then check the correctness by moving the
neural ODE / universal ODE with `ZygoteVJP` and Lux first, but then check the correctness by moving the
implementation over to SimpleChains and if possible `EnzymeVJP`. This can be an order of magnitude improvement
(or more) in many situations over all the previous benchmarks using Zygote and Flux, and thus it's
(or more) in many situations over all the previous benchmarks using Zygote and Lux, and thus it's
highly recommended in scenarios that require performance.
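
As a minimal sketch, hypothetical and not part of this diff: a fully in-place, non-allocating right-hand side differentiated with `EnzymeVJP`, the pattern recommended above. The Lotka-Volterra system here is only a stand-in.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Mutating, allocation-free right-hand side: the case where Enzyme shines.
function lv!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lv!, u0, (0.0, 10.0), p)

# Differentiate a scalar loss of the solution using an Enzyme-based VJP.
function loss(p)
    sol = solve(prob, Tsit5(); p = p, saveat = 0.1,
        sensealg = InterpolatingAdjoint(autojacvec = EnzymeVJP()))
    return sum(abs2, Array(sol))
end

Zygote.gradient(loss, p)
```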

## Vs Torchdiffeq 1 million and less ODEs
@@ -33,12 +33,12 @@ at this time.
Quick summary:

- `BacksolveAdjoint` can be the fastest (but use with caution!); about 25% faster
- Using `ZygoteVJP` is faster than other vjp choices with FastDense due to the overloads
- Using `ZygoteVJP` is faster than other vjp choices for larger neural networks
- `ReverseDiffVJP(compile = true)` works well for small Lux neural networks

```julia
using DiffEqFlux,
OrdinaryDiffEq, Flux, Optim, Plots, SciMLSensitivity,
Zygote, BenchmarkTools, Random
OrdinaryDiffEq, Lux, SciMLSensitivity, Zygote, BenchmarkTools, Random, ComponentArrays

u0 = Float32[2.0; 0.0]
datasize = 30
@@ -53,116 +53,39 @@ end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

dudt2 = FastChain((x, p) -> x .^ 3,
FastDense(2, 50, tanh),
FastDense(50, 2))
dudt2 = Chain(x -> x .^ 3, Dense(2, 50, tanh), Dense(50, 2))
Random.seed!(100)
p = initial_params(dudt2)

prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)

function loss_neuralode(p)
pred = Array(prob_neuralode(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode, p)
# 2.709 ms (56506 allocations: 6.62 MiB)

prob_neuralode_interpolating = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))

function loss_neuralode_interpolating(p)
pred = Array(prob_neuralode_interpolating(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_interpolating, p)
# 5.501 ms (103835 allocations: 2.57 MiB)

prob_neuralode_interpolating_zygote = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = InterpolatingAdjoint(autojacvec = ZygoteVJP()))

function loss_neuralode_interpolating_zygote(p)
pred = Array(prob_neuralode_interpolating_zygote(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_interpolating_zygote, p)
# 2.899 ms (56150 allocations: 6.61 MiB)

prob_neuralode_backsolve = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)))

function loss_neuralode_backsolve(p)
pred = Array(prob_neuralode_backsolve(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_backsolve, p)
# 4.871 ms (85855 allocations: 2.20 MiB)

prob_neuralode_quad = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)))

function loss_neuralode_quad(p)
pred = Array(prob_neuralode_quad(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_quad, p)
# 11.748 ms (79549 allocations: 3.87 MiB)

prob_neuralode_backsolve_tracker = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = BacksolveAdjoint(autojacvec = TrackerVJP()))

function loss_neuralode_backsolve_tracker(p)
pred = Array(prob_neuralode_backsolve_tracker(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_backsolve_tracker, p)
# 27.604 ms (186143 allocations: 12.22 MiB)

prob_neuralode_backsolve_zygote = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))

function loss_neuralode_backsolve_zygote(p)
pred = Array(prob_neuralode_backsolve_zygote(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end

@btime Zygote.gradient(loss_neuralode_backsolve_zygote, p)
# 2.091 ms (49883 allocations: 6.28 MiB)

prob_neuralode_backsolve_false = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = BacksolveAdjoint(autojacvec = ReverseDiffVJP(false)))

function loss_neuralode_backsolve_false(p)
pred = Array(prob_neuralode_backsolve_false(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
for sensealg in (InterpolatingAdjoint(autojacvec = ZygoteVJP()),
InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)),
BacksolveAdjoint(autojacvec = ReverseDiffVJP(true)),
BacksolveAdjoint(autojacvec = ZygoteVJP()),
BacksolveAdjoint(autojacvec = ReverseDiffVJP(false)),
BacksolveAdjoint(autojacvec = TrackerVJP()),
QuadratureAdjoint(autojacvec = ReverseDiffVJP(true)),
TrackerAdjoint())
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(); saveat = tsteps,
sensealg = sensealg)
ps, st = Lux.setup(Random.default_rng(), prob_neuralode)
ps = ComponentArray(ps)

loss_neuralode = function (u0, p, st)
pred = Array(first(prob_neuralode(u0, p, st)))
loss = sum(abs2, ode_data .- pred)
return loss
end

t = @belapsed Zygote.gradient($loss_neuralode, $u0, $ps, $st)
println("$(sensealg) took $(t)s")
end

@btime Zygote.gradient(loss_neuralode_backsolve_false, p)
# 4.822 ms (9956 allocations: 1.03 MiB)

prob_neuralode_tracker = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps,
sensealg = TrackerAdjoint())

function loss_neuralode_tracker(p)
pred = Array(prob_neuralode_tracker(u0, p))
loss = sum(abs2, ode_data .- pred)
return loss
end
# InterpolatingAdjoint{0, true, Val{:central}, ZygoteVJP}(ZygoteVJP(false), false, false) took 0.029134224s
# InterpolatingAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), false, false) took 0.001657377s
# BacksolveAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), true, false) took 0.002477057s
# BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}(ZygoteVJP(false), true, false) took 0.031533335s
# BacksolveAdjoint{0, true, Val{:central}, ReverseDiffVJP{false}}(ReverseDiffVJP{false}(), true, false) took 0.004605386s
# BacksolveAdjoint{0, true, Val{:central}, TrackerVJP}(TrackerVJP(false), true, false) took 0.044568018s
# QuadratureAdjoint{0, true, Val{:central}, ReverseDiffVJP{true}}(ReverseDiffVJP{true}(), 1.0e-6, 0.001) took 0.002489559s
# TrackerAdjoint() took 0.003759097s

@btime Zygote.gradient(loss_neuralode_tracker, p)
# 12.614 ms (76346 allocations: 3.12 MiB)
```
8 changes: 4 additions & 4 deletions docs/src/examples/dae/physical_constraints.md
@@ -10,7 +10,7 @@ terms must add to one. An example of this is as follows:

```@example dae
using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
DifferentialEquations, Plots
OrdinaryDiffEq, Plots

using Random
rng = Random.default_rng()
@@ -74,7 +74,7 @@ result_stiff = Optimization.solve(optprob, NLopt.LD_LBFGS(), maxiters = 100)

```@example dae2
using Lux, ComponentArrays, DiffEqFlux, Optimization, OptimizationNLopt,
DifferentialEquations, Plots
OrdinaryDiffEq, Plots

using Random
rng = Random.default_rng()
@@ -133,8 +133,8 @@ Because this is a DAE, we need to make sure to use a **compatible solver**.
### Neural Network Layers

Next, we create our layers using `Lux.Chain`. We use this instead of `Flux.Chain` because it
is more suited to SciML applications (similarly for
`Lux.Dense`). The input to our network will be the initial conditions fed in as `u₀`.
is more suited to SciML applications (similarly for `Lux.Dense`). The input to our network
will be the initial conditions fed in as `u₀`.

```@example dae2
nn_dudt2 = Lux.Chain(Lux.Dense(3, 64, tanh),
7 changes: 3 additions & 4 deletions docs/src/examples/dde/delay_diffeq.md
@@ -5,8 +5,8 @@ supported. For example, we can build a layer with a delay differential equation
like:

```@example dde
using DifferentialEquations, Optimization, SciMLSensitivity,
OptimizationPolyalgorithms
using OrdinaryDiffEq, Optimization, SciMLSensitivity, OptimizationPolyalgorithms,
DelayDiffEq

# Define the same LV equation, but including a delay parameter
function delay_lotka_volterra!(du, u, h, p, t)
@@ -32,8 +32,7 @@ prob_dde = DDEProblem(delay_lotka_volterra!, u0, h, (0.0, 10.0),

function predict_dde(p)
return Array(solve(prob_dde, MethodOfSteps(Tsit5()),
u0 = u0, p = p, saveat = 0.1,
sensealg = ReverseDiffAdjoint()))
u0 = u0, p = p, saveat = 0.1, sensealg = ReverseDiffAdjoint()))
end

loss_dde(p) = sum(abs2, x - 1 for x in predict_dde(p))