-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #59 from acfr/feature/LBDN
Feature/lbdn
- Loading branch information
Showing
33 changed files
with
828 additions
and
399 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
```@index | ||
Pages = ["ren_params.md"] | ||
``` | ||
|
||
# Model Parameterisations | ||
|
||
## Lipschitz-Bounded Deep Networks | ||
|
||
```@docs | ||
AbstractLBDNParams | ||
DenseLBDNParams | ||
DirectLBDNParams | ||
ExplicitLBDNParams | ||
``` | ||
|
||
## Recurrent Equilibrium Networks | ||
|
||
```@docs | ||
AbstractRENParams | ||
ContractingRENParams | ||
DirectRENParams | ||
ExplicitRENParams | ||
GeneralRENParams | ||
LipschitzRENParams | ||
PassiveRENParams | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
```@index | ||
Pages = ["models.md"] | ||
``` | ||
|
||
# Model Wrappers | ||
|
||
## Lipschitz-Bounded Deep Networks | ||
|
||
```@docs | ||
AbstractLBDN | ||
DiffLBDN | ||
LBDN | ||
``` | ||
|
||
## Recurrent Equilibrium Networks | ||
|
||
```@docs | ||
AbstractREN | ||
DiffREN | ||
REN | ||
WrapREN | ||
``` |
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,10 @@ | ||
[deps] | ||
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" | ||
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" | ||
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" | ||
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" | ||
RobustNeuralNetworks = "a1f18e6b-8af1-433f-a85d-2e1ee636a2b8" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# TODO: This is an old demo from our previous implementation of LBDN. | ||
# No good for the current version! | ||
|
||
# Import packages (including our REN package) | ||
using Flux | ||
using Flux:onehotbatch, @epochs, crossentropy,onecold,throttle, OneHotMatrix | ||
using MLDatasets: MNIST | ||
using RobustNeuralNetworks | ||
using Statistics | ||
|
||
# Get MNIST training and test data | ||
train_x, train_y = MNIST.traindata(Float64) | ||
test_x, test_y = MNIST.testdata(Float64) | ||
|
||
# Reshape as appropriate for training | ||
train_x = Flux.flatten(train_x) | ||
test_x = Flux.flatten(test_x) | ||
|
||
train_y, test_y = onehotbatch(train_y, 0:9), onehotbatch(test_y, 0:9) | ||
|
||
|
||
# Define a model with LBDN | ||
nu = 28*28 # Inputs | ||
nh = [60] # 1 hidden layer of size 60 (has to be in a vector) | ||
ny = 10 # Outputs | ||
γ = 5.0 # Lipschitz bound upper limit (must be a float) | ||
|
||
lbfn = LBFN{Float64}(nu, nh, ny, γ) | ||
m = Chain(lbfn, softmax) | ||
|
||
# Could instead use a fully-connected network (see below) | ||
#m = Chain(Dense(28*28,60, relu), Dense(60, 10), softmax) | ||
|
||
# Define loss function, optimiser, and get params | ||
loss(x, y) = crossentropy(m(x), y) | ||
opt = ADAM(1e-3) | ||
ps = Flux.params(m) | ||
|
||
# Comparison functions | ||
compare(y::OneHotMatrix, y′) = maximum(y′, dims = 1) .== maximum(y .* y′, dims = 1) | ||
accuracy(x, y::OneHotMatrix) = mean(compare(y, m(x))) | ||
|
||
# To check progress while training | ||
progress = () -> @show(loss(train_x, train_y), accuracy(test_x, test_y) ) # callback to show loss | ||
|
||
# Train model with two different learning rates | ||
opt = ADAM(1e-3) | ||
@epochs 200 Flux.train!(loss, ps,[(train_x,train_y)], opt, cb = throttle(progress, 10)) | ||
opt = ADAM(1e-4) | ||
@epochs 400 Flux.train!(loss, ps,[(train_x,train_y)], opt, cb = throttle(progress, 10)) | ||
|
||
# Show results | ||
accuracy(train_x,train_y) | ||
accuracy(test_x,test_y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
cd(@__DIR__) | ||
using Pkg | ||
Pkg.activate("../../") | ||
|
||
using Flux | ||
using Random | ||
using RobustNeuralNetworks | ||
using CairoMakie | ||
using Zygote: pullback | ||
|
||
Random.seed!(0) | ||
|
||
# Set up model | ||
nu, ny = 1, 1 | ||
# nh = [10,5,5,15] | ||
nh = fill(90,8) | ||
γ = 1 | ||
model_ps = DenseLBDNParams{Float64}(nu, nh, ny, γ) | ||
ps = Flux.params(model_ps) | ||
|
||
# Function to estimate | ||
f(x) = sin(x)+(1/N)*sin(N*x) | ||
|
||
# Training data | ||
N = 5 | ||
dx = 0.12 | ||
xs = 0:dx:2π | ||
ys = f.(xs) | ||
T = length(xs) | ||
data = zip(xs,ys) | ||
|
||
# Loss function | ||
function loss(x, y) | ||
model = LBDN(model_ps) | ||
return Flux.mse(model([x]),[y]) | ||
end | ||
|
||
# Callback function to show results while training | ||
function evalcb(α) | ||
model = LBDN(model_ps) | ||
fit_error = sqrt(sum(loss.(xs, ys)) / length(xs)) | ||
slope = maximum(abs.(diff(model(xs'),dims=2)))/dx | ||
@show α fit_error slope | ||
println() | ||
end | ||
|
||
|
||
# Set up training loop | ||
num_epochs = 100 | ||
lrs = [2e-3, 8e-4, 5e-4]#, 1e-4, 5e-5] | ||
for k in eachindex(lrs) | ||
opt = NADAM(lrs[k]) | ||
for i in 1:num_epochs | ||
|
||
# Flux.train!(loss, ps, data, opt) | ||
|
||
for d in data | ||
J, back = pullback(() -> loss(d[1],d[2]), ps) | ||
∇J = back(one(J)) | ||
Flux.update!(opt, ps, ∇J) | ||
end | ||
|
||
(i % 2 == 0) && evalcb(lrs[k]) | ||
end | ||
end | ||
|
||
# Final trained model | ||
model = LBDN(model_ps) | ||
|
||
# Create a figure | ||
f1 = Figure() | ||
ax = Axis(f1[1,1], xlabel="x", ylabel="y") | ||
|
||
ŷ = map(x -> model([x])[1], xs) | ||
lines!(xs, ys, label = "Data") | ||
lines!(xs, ŷ, label = "LBDN") | ||
axislegend(ax) | ||
display(f1) | ||
|
||
# Print out lower-bound on Lipschitz constant | ||
Empirical_Lipschitz = maximum(abs.(diff(model(xs'),dims=2)))/dx | ||
println("Empirical lower Lipschitz bound: ", round(Empirical_Lipschitz; digits=2)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
cd(@__DIR__) | ||
using Pkg | ||
Pkg.activate("../../") | ||
|
||
using CairoMakie | ||
using Flux | ||
using Random | ||
using RobustNeuralNetworks | ||
|
||
Random.seed!(0) | ||
|
||
# Set up model | ||
nu, ny = 1, 1 | ||
nh = [10,5,5,15] | ||
γ = 1 | ||
model_ps = DenseLBDNParams{Float64}(nu, nh, ny, γ) | ||
model = DiffLBDN(model_ps) | ||
ps = Flux.params(model) | ||
|
||
# Function to estimate | ||
f(x) = sin(x)+(1/N)*sin(N*x) | ||
|
||
# Training data | ||
N = 5 | ||
dx = 0.1 | ||
xs = 0:dx:2π | ||
ys = f.(xs) | ||
T = length(xs) | ||
data = zip(xs,ys) | ||
|
||
# Loss function | ||
loss(x,y) = Flux.mse(model([x]),[y]) | ||
|
||
# Callback function to show results while training | ||
function evalcb(α) | ||
fit_error = sqrt(sum(loss.(xs, ys)) / length(xs)) | ||
slope = maximum(abs.(diff(model(xs'),dims=2)))/dx | ||
@show α fit_error slope | ||
println() | ||
end | ||
|
||
# Training loop | ||
num_epochs = [400, 200] | ||
lrs = [2e-4, 5e-5] | ||
for k in eachindex(lrs) | ||
opt = ADAM(lrs[k]) | ||
for i in 1:num_epochs[k] | ||
Flux.train!(loss, ps, data, opt) | ||
(i % 10 == 0) && evalcb(lrs[k]) | ||
end | ||
end | ||
|
||
# Create a figure | ||
f1 = Figure() | ||
ax = Axis(f1[1,1], xlabel="x", ylabel="y") | ||
|
||
ŷ = map(x -> model([x])[1], xs) | ||
lines!(xs, ys, label = "Data") | ||
lines!(xs, ŷ, label = "LBDN") | ||
axislegend(ax) | ||
display(f1) | ||
|
||
# Print out lower-bound on Lipschitz constant | ||
Empirical_Lipschitz = maximum(abs.(diff(model(xs'),dims=2)))/dx | ||
println("Empirical lower Lipschitz bound: ", round(Empirical_Lipschitz; digits=2)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.