Merge pull request #59 from acfr/feature/LBDN
Feature/lbdn
nic-barbara authored May 10, 2023
2 parents e973a5e + 0bce5d3 commit 751d644
Showing 33 changed files with 828 additions and 399 deletions.
5 changes: 2 additions & 3 deletions docs/make.jl
@@ -21,9 +21,8 @@ makedocs(
"PDE Observer" => "examples/pde_obsv.md",
],
"Library" => Any[
"Lipschitz-Bounded Deep Networks" => "lib/lbdn.md",
"Recurrent Equilibrium Networks" => "lib/ren.md",
"REN Parameterisations" => "lib/ren_params.md",
"Model Wrappers" => "lib/models.md",
"Model Parameterisations" => "lib/model_params.md",
"Functions" => "lib/functions.md",
],
"API" => "api.md"
2 changes: 1 addition & 1 deletion docs/src/index.md
Expand Up @@ -32,7 +32,7 @@ Depth = 1
## Library

```@contents
Pages = ["lib/lbdn.md", "lib/ren.md", "lib/ren_params.md", "lib/functions.md"]
Pages = ["lib/models.md", "lib/model_params.md", "lib/functions.md"]
Depth = 1
```

10 changes: 5 additions & 5 deletions docs/src/introduction/layout.md
@@ -63,7 +63,7 @@ Each of these parameter types has the following collection of attributes:

- Model sizes `nu`, `nx`, `nv`, `ny` defining the number of inputs, states, neurons, and outputs (respectively).

- An instance of [`DirectParams`](@ref) containing the direct parameters of the REN, including all **trainable** parameters.
- An instance of [`DirectRENParams`](@ref) containing the direct parameters of the REN, including all **trainable** parameters.

- Other attributes used to define how the direct parameterisation should be converted to the implicit model. These parameters encode the user-tunable behavioural constraints, e.g. $\gamma$ for a Lipschitz-bounded REN (see the sketch below).
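
For instance, a Lipschitz-bounded REN parameterisation might be constructed as follows. This is a minimal sketch with hypothetical sizes; the exact constructor signature is documented with [`LipschitzRENParams`](@ref).

```julia
using RobustNeuralNetworks

nu, nx, nv, ny = 4, 10, 20, 2   # hypothetical model sizes
γ = 1.0                         # user-tunable Lipschitz bound

# Direct (trainable) parameterisation with a built-in Lipschitz constraint
ren_ps = LipschitzRENParams{Float64}(nu, nx, nv, ny, γ)
```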

@@ -77,9 +77,9 @@ An *explicit* REN model must be created to call and use the network for computation

- A static nonlinearity `nl` and model sizes `nu`, `nx`, `nv`, `ny` (same as for `AbstractRENParams`).

- An instance of `ExplicitParams` containing all REN parameters in explicit form for model evaluation (see the [`ExplicitParams`](@ref) docs for more detail).
- An instance of `ExplicitRENParams` containing all REN parameters in explicit form for model evaluation (see the [`ExplicitRENParams`](@ref) docs for more detail).

Each subtype of `AbstractRENParams` has a method [`direct_to_explicit`](@ref) associated with it that converts the `DirectParams` struct to an instance of `ExplicitParams` satisfying the specified behavioural constraints.
Each subtype of `AbstractRENParams` has a method [`direct_to_explicit`](@ref) associated with it that converts the `DirectRENParams` struct to an instance of `ExplicitRENParams` satisfying the specified behavioural constraints.
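
A minimal sketch of this conversion, assuming a contracting REN with hypothetical sizes (see [`direct_to_explicit`](@ref) for the exact signature):

```julia
using RobustNeuralNetworks

nu, nx, nv, ny = 4, 10, 20, 2
ren_ps = ContractingRENParams{Float64}(nu, nx, nv, ny)

# Convert the direct (trainable) parameters to an explicit model that
# satisfies the contraction constraint by construction
explicit = direct_to_explicit(ren_ps)
```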


#### REN Wrappers
@@ -91,7 +91,7 @@ There are three explicit REN wrappers currently implemented in this package. Each
!!! tip "REN is recommended"
We strongly recommend using `REN` to train your models with `Flux.jl`. It is the most efficient subtype of `AbstractREN` that is compatible with automatic differentiation.

- [`WrapREN`](@ref) includes both the `DirectParams` and `ExplicitParams` as part of the REN wrapper. When any of the direct parameters are changed, the explicit model can be updated by calling [`update_explicit!`](@ref). This can be useful when not using automatic differentiation to train the model. For example:
- [`WrapREN`](@ref) includes both the `DirectRENParams` and `ExplicitRENParams` as part of the REN wrapper. When any of the direct parameters are changed, the explicit model can be updated by calling [`update_explicit!`](@ref). This can be useful when not using automatic differentiation to train the model. For example:

```julia
using RobustNeuralNetworks
@@ -110,7 +110,7 @@ for k in 1:num_training_epochs
!!! warning "WrapREN incompatible with Flux.jl"
Since the explicit parameters are stored in an instance of `WrapREN`, changing them with `update_explicit!` directly mutates the model. This will cause errors if the model is to be trained with [`Flux.jl`](http://fluxml.ai/Flux.jl/stable/). Use [`REN`](@ref) or [`DiffREN`](@ref) to avoid this issue.

- [`DiffREN`](@ref) also includes `DirectParams`, but never stores the `ExplicitParams`. Instead, the explicit parameters are computed every time the model is evaluated. This is slow, but does not require creating a new object when the parameters are updated, and is still compatible with `Flux.jl`. For example:
- [`DiffREN`](@ref) also includes `DirectRENParams`, but never stores the `ExplicitRENParams`. Instead, the explicit parameters are computed every time the model is evaluated. This is slow, but does not require creating a new object when the parameters are updated, and is still compatible with `Flux.jl`. For example:

```julia
using Flux
10 changes: 0 additions & 10 deletions docs/src/lib/lbdn.md

This file was deleted.

26 changes: 26 additions & 0 deletions docs/src/lib/model_params.md
@@ -0,0 +1,26 @@
```@index
Pages = ["ren_params.md"]
```

# Model Parameterisations

## Lipschitz-Bounded Deep Networks

```@docs
AbstractLBDNParams
DenseLBDNParams
DirectLBDNParams
ExplicitLBDNParams
```

## Recurrent Equilibrium Networks

```@docs
AbstractRENParams
ContractingRENParams
DirectRENParams
ExplicitRENParams
GeneralRENParams
LipschitzRENParams
PassiveRENParams
```
22 changes: 22 additions & 0 deletions docs/src/lib/models.md
@@ -0,0 +1,22 @@
```@index
Pages = ["models.md"]
```

# Model Wrappers

## Lipschitz-Bounded Deep Networks

```@docs
AbstractLBDN
DiffLBDN
LBDN
```

## Recurrent Equilibrium Networks

```@docs
AbstractREN
DiffREN
REN
WrapREN
```
12 changes: 0 additions & 12 deletions docs/src/lib/ren.md

This file was deleted.

15 changes: 0 additions & 15 deletions docs/src/lib/ren_params.md

This file was deleted.

5 changes: 5 additions & 0 deletions examples/Project.toml
@@ -1,5 +1,10 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
RobustNeuralNetworks = "a1f18e6b-8af1-433f-a85d-2e1ee636a2b8"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
54 changes: 54 additions & 0 deletions examples/src/lbdn_mnist.jl
@@ -0,0 +1,54 @@
# TODO: This is an old demo from our previous implementation of LBDN.
# No good for the current version!

# Import packages (including our REN package)
using Flux
using Flux: onehotbatch, @epochs, crossentropy, onecold, throttle, OneHotMatrix
using MLDatasets: MNIST
using RobustNeuralNetworks
using Statistics

# Get MNIST training and test data
train_x, train_y = MNIST.traindata(Float64)
test_x, test_y = MNIST.testdata(Float64)

# Reshape as appropriate for training
train_x = Flux.flatten(train_x)
test_x = Flux.flatten(test_x)

train_y, test_y = onehotbatch(train_y, 0:9), onehotbatch(test_y, 0:9)


# Define a model with LBDN
nu = 28*28 # Inputs
nh = [60] # 1 hidden layer of size 60 (has to be in a vector)
ny = 10 # Outputs
γ = 5.0 # Lipschitz bound upper limit (must be a float)

lbfn = LBFN{Float64}(nu, nh, ny, γ)
m = Chain(lbfn, softmax)

# Could instead use a fully-connected network (see below)
#m = Chain(Dense(28*28,60, relu), Dense(60, 10), softmax)

# Define loss function, optimiser, and get params
loss(x, y) = crossentropy(m(x), y)
opt = ADAM(1e-3)
ps = Flux.params(m)

# Comparison functions
compare(y::OneHotMatrix, y′) = maximum(y′, dims = 1) .== maximum(y .* y′, dims = 1)
accuracy(x, y::OneHotMatrix) = mean(compare(y, m(x)))

# To check progress while training
progress = () -> @show(loss(train_x, train_y), accuracy(test_x, test_y)) # callback to show loss

# Train model with two different learning rates
opt = ADAM(1e-3)
@epochs 200 Flux.train!(loss, ps, [(train_x, train_y)], opt, cb = throttle(progress, 10))
opt = ADAM(1e-4)
@epochs 400 Flux.train!(loss, ps, [(train_x, train_y)], opt, cb = throttle(progress, 10))

# Show results
accuracy(train_x,train_y)
accuracy(test_x,test_y)
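
A rough sketch of how the model definition above might be ported to the current API, assuming the `DenseLBDNParams` and `DiffLBDN` constructors used in the test scripts added later in this PR; the loss, optimiser, and training loop would remain as written:

```julia
using Flux
using RobustNeuralNetworks

nu, nh, ny = 28*28, [60], 10   # same sizes as above
γ = 5.0                        # Lipschitz bound

# Direct LBDN parameterisation, wrapped so Flux.jl can train it directly
lbdn_ps = DenseLBDNParams{Float64}(nu, nh, ny, γ)
m = Chain(DiffLBDN(lbdn_ps), softmax)
ps = Flux.params(m)
```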
82 changes: 82 additions & 0 deletions examples/src/test-lbdn/lbdn_test.jl
@@ -0,0 +1,82 @@
cd(@__DIR__)
using Pkg
Pkg.activate("../../")

using Flux
using Random
using RobustNeuralNetworks
using CairoMakie
using Zygote: pullback

Random.seed!(0)

# Set up model
nu, ny = 1, 1
# nh = [10,5,5,15]
nh = fill(90,8)
γ = 1
model_ps = DenseLBDNParams{Float64}(nu, nh, ny, γ)
ps = Flux.params(model_ps)

# Function to estimate
f(x) = sin(x)+(1/N)*sin(N*x)

# Training data
N = 5
dx = 0.12
xs = 0:dx:2π
ys = f.(xs)
T = length(xs)
data = zip(xs,ys)

# Loss function
function loss(x, y)
model = LBDN(model_ps)
return Flux.mse(model([x]),[y])
end

# Callback function to show results while training
function evalcb(α)
model = LBDN(model_ps)
fit_error = sqrt(sum(loss.(xs, ys)) / length(xs))
slope = maximum(abs.(diff(model(xs'),dims=2)))/dx
@show α fit_error slope
println()
end


# Set up training loop
num_epochs = 100
lrs = [2e-3, 8e-4, 5e-4]#, 1e-4, 5e-5]
for k in eachindex(lrs)
opt = NADAM(lrs[k])
for i in 1:num_epochs

# Flux.train!(loss, ps, data, opt)

for d in data
J, back = pullback(() -> loss(d[1],d[2]), ps)
∇J = back(one(J))
Flux.update!(opt, ps, ∇J)
end

(i % 2 == 0) && evalcb(lrs[k])
end
end

# Final trained model
model = LBDN(model_ps)

# Create a figure
f1 = Figure()
ax = Axis(f1[1,1], xlabel="x", ylabel="y")

ŷ = map(x -> model([x])[1], xs)
lines!(xs, ys, label = "Data")
lines!(xs, ŷ, label = "LBDN")
axislegend(ax)
display(f1)

# Print out lower-bound on Lipschitz constant
Empirical_Lipschitz = maximum(abs.(diff(model(xs'),dims=2)))/dx
println("Empirical lower Lipschitz bound: ", round(Empirical_Lipschitz; digits=2))
65 changes: 65 additions & 0 deletions examples/src/test-lbdn/lbdn_test_diff.jl
@@ -0,0 +1,65 @@
cd(@__DIR__)
using Pkg
Pkg.activate("../../")

using CairoMakie
using Flux
using Random
using RobustNeuralNetworks

Random.seed!(0)

# Set up model
nu, ny = 1, 1
nh = [10,5,5,15]
γ = 1
model_ps = DenseLBDNParams{Float64}(nu, nh, ny, γ)
model = DiffLBDN(model_ps)
ps = Flux.params(model)

# Function to estimate
f(x) = sin(x)+(1/N)*sin(N*x)

# Training data
N = 5
dx = 0.1
xs = 0:dx:2π
ys = f.(xs)
T = length(xs)
data = zip(xs,ys)

# Loss function
loss(x,y) = Flux.mse(model([x]),[y])

# Callback function to show results while training
function evalcb(α)
fit_error = sqrt(sum(loss.(xs, ys)) / length(xs))
slope = maximum(abs.(diff(model(xs'),dims=2)))/dx
@show α fit_error slope
println()
end

# Training loop
num_epochs = [400, 200]
lrs = [2e-4, 5e-5]
for k in eachindex(lrs)
opt = ADAM(lrs[k])
for i in 1:num_epochs[k]
Flux.train!(loss, ps, data, opt)
(i % 10 == 0) && evalcb(lrs[k])
end
end

# Create a figure
f1 = Figure()
ax = Axis(f1[1,1], xlabel="x", ylabel="y")

ŷ = map(x -> model([x])[1], xs)
lines!(xs, ys, label = "Data")
lines!(xs, ŷ, label = "LBDN")
axislegend(ax)
display(f1)

# Print out lower-bound on Lipschitz constant
Empirical_Lipschitz = maximum(abs.(diff(model(xs'),dims=2)))/dx
println("Empirical lower Lipschitz bound: ", round(Empirical_Lipschitz; digits=2))
6 changes: 3 additions & 3 deletions notes/nic_notes.md
@@ -16,11 +16,11 @@ This document is a work-in-progress. Use it to start our documentation of the `RobustNeuralNetworks.jl` package.
- Currently no method to construct a REN without specifying type. Seems good to force it
- Treatment of `D22`:
- `D22` can be included directly as a trainable parameter in the `OutputLayer`, but is **NOT** by default. Have to set `D22_trainable = true` on construction
- `D22` can be parameterised by free parameters `(X3,Y3,Z3)` by setting `D22_free` in `DirectParams`
- `D22` can be parameterised by free parameters `(X3,Y3,Z3)` by setting `D22_free` in `DirectRENParams`
- Added \alpha_bar to `ContractingRENParams` (as an example). Will use this to set contraction rate in explicit params construction

### Bugs:
- Cholesky initialisation of `DirectParams`
- Cholesky initialisation of `DirectRENParams`

### Functionality
- How to include different initialisation methods?
@@ -29,6 +29,6 @@ This document is a work-in-progress. Use it to start our documentation of the `RobustNeuralNetworks.jl` package.
### Conversion questions:
- `D22` terms in `output.jl` initialised as 0. Good idea in general?
- `sample_ff_ren` had `randn(2nx + nv, 2nx + nv) / sqrt(2nx + nv)`, not `.../ sqrt(2*(2nx + nv))`. Why?
- Why have `bx_scale` and `bv_scale` in constructor for `DirectParams`?
- Why have `bx_scale` and `bv_scale` in constructor for `DirectRENParams`?
- Constructor for `implicit_ff_cell` has zeros for some params, random for others. Why?
- Construction of S1 (now `Y1`) in `ffREN.jl` line 130/136 divides by 2 an extra time. Not necessary, right? Does this change anything?
