Refactor/flux interface #62

Merged · 10 commits · May 12, 2023
2 changes: 0 additions & 2 deletions Project.toml
@@ -4,15 +4,13 @@ authors = ["Nicholas H. Barbara", "Max Revay", "Ruigang Wang", "Jing Cheng", "Je
version = "0.1.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixEquations = "99c1a7ee-ab34-5fd5-8076-27c950a045f4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "3, 4"
Flux = "0.13"
MatrixEquations = "2"
Zygote = "0.6"
33 changes: 11 additions & 22 deletions examples/src/test-lbdn/lbdn_test.jl
@@ -12,55 +12,44 @@ Random.seed!(0)

# Set up model
nu, ny = 1, 1
# nh = [10,5,5,15]
nh = fill(90,8)
nh = [10,5,5,15]
γ = 1
model_ps = DenseLBDNParams{Float64}(nu, nh, ny, γ)
ps = Flux.params(model_ps)

# Function to estimate
f(x) = sin(x)+(1/N)*sin(N*x)

# Training data
N = 5
dx = 0.12
dx = 0.1
xs = 0:dx:2π
ys = f.(xs)
T = length(xs)
data = zip(xs,ys)

# Loss function
function loss(x, y)
function loss(model_ps, x, y)
model = LBDN(model_ps)
return Flux.mse(model([x]),[y])
end

# Callback function to show results while training
function evalcb(α)
function evalcb(model_ps, α)
model = LBDN(model_ps)
fit_error = sqrt(sum(loss.(xs, ys)) / length(xs))
fit_error = sqrt(sum(loss.((model_ps,), xs, ys)) / length(xs))
slope = maximum(abs.(diff(model(xs'),dims=2)))/dx
@show α fit_error slope
println()
end


# Set up training loop
num_epochs = 100
lrs = [2e-3, 8e-4, 5e-4]#, 1e-4, 5e-5]
num_epochs = [400, 200]
lrs = [2e-4, 5e-5]
for k in eachindex(lrs)
opt = NADAM(lrs[k])
for i in 1:num_epochs

# Flux.train!(loss, ps, data, opt)

for d in data
J, back = pullback(() -> loss(d[1],d[2]), ps)
∇J = back(one(J))
Flux.update!(opt, ps, ∇J)
end

(i % 2 == 0) && evalcb(lrs[k])
opt_state = Flux.setup(Adam(lrs[k]), model_ps)
for i in 1:num_epochs[k]
Flux.train!(loss, model_ps, data, opt_state)
(i % 10 == 0) && evalcb(model_ps, lrs[k])
end
end
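
For readers migrating similar scripts, a minimal sketch of what the new explicit-style call does under the hood, assuming Flux >= 0.13 with the Optimisers.jl-backed interface and reusing model_ps, loss, and data from the script above (the learning rate here is a placeholder):

using Flux

opt_state = Flux.setup(Adam(2e-4), model_ps)   # optimiser state mirrors the trainable fields
for (x, y) in data
    # Gradient w.r.t. the first argument: this is what Flux.train!(loss, model_ps, data, opt_state) does per sample
    ∇J = Flux.gradient(m -> loss(m, x, y), model_ps)[1]
    Flux.update!(opt_state, model_ps, ∇J)
end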

16 changes: 8 additions & 8 deletions examples/src/test-lbdn/lbdn_test_diff.jl
@@ -15,10 +15,10 @@ nh = [10,5,5,15]
γ = 1
model_ps = DenseLBDNParams{Float64}(nu, nh, ny, γ)
model = DiffLBDN(model_ps)
ps = Flux.params(model)

# Function to estimate
f(x) = sin(x)+(1/N)*sin(N*x)
# f(x) = (x < π/2 || (x > π && x < 3π/2)) ? 1 : 0

# Training data
N = 5
@@ -29,24 +29,24 @@ T = length(xs)
data = zip(xs,ys)

# Loss function
loss(x,y) = Flux.mse(model([x]),[y])
loss(model,x,y) = Flux.mse(model([x]),[y])

# Callback function to show results while training
function evalcb(α)
fit_error = sqrt(sum(loss.(xs, ys)) / length(xs))
function evalcb(model, α)
fit_error = sqrt(sum(loss.((model,), xs, ys)) / length(xs))
slope = maximum(abs.(diff(model(xs'),dims=2)))/dx
@show α fit_error slope
println()
end

# Training loop
# Training loop (could be improved with batches...)
num_epochs = [400, 200]
lrs = [2e-4, 5e-5]
for k in eachindex(lrs)
opt = ADAM(lrs[k])
opt_state = Flux.setup(Adam(lrs[k]), model)
for i in 1:num_epochs[k]
Flux.train!(loss, ps, data, opt)
(i % 10 == 0) && evalcb(lrs[k])
Flux.train!(loss, model, data, opt_state)
(i % 10 == 0) && evalcb(model, lrs[k])
end
end
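
Acting on the "could be improved with batches" note above, one possible batched setup, sketched with an assumed batch size and helper names not present in this PR (DiffLBDN already accepts matrix inputs, so the samples can be stacked column-wise):

X = reshape(collect(xs), 1, :)      # 1×T input matrix
Y = reshape(ys, 1, :)               # 1×T target matrix
batches = Flux.DataLoader((X, Y); batchsize=32, shuffle=true)

batch_loss(model, x, y) = Flux.mse(model(x), y)
Flux.train!(batch_loss, model, batches, opt_state)   # one epoch over mini-batches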

2 changes: 1 addition & 1 deletion src/Base/lbdn_params.jl
@@ -72,7 +72,7 @@ function DirectLBDNParams{T}(

end

Flux.trainable(m::DirectLBDNParams) = (m.XY, m.α, m.d, m.b)
Flux.@functor DirectLBDNParams (XY, α, d, b)


# TODO: Should add compatibility for layer-wise options
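
For context on the restricted @functor form used above, a toy illustration (the type and fields are made up, not the real DirectLBDNParams): listing fields after the type exposes only those fields to Flux, so gradients and optimiser state are built for them and any remaining fields are treated as constants.

using Flux

struct ToyLBDNParams
    XY; α; d; b
    γ                       # hyperparameter, should stay fixed
end
Flux.@functor ToyLBDNParams (XY, α, d, b)

p = ToyLBDNParams(randn(4, 4), randn(4), randn(4), randn(4), 1.0)
Flux.trainable(p)           # NamedTuple with XY, α, d, b; γ is excluded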
96 changes: 49 additions & 47 deletions src/Base/ren_params.jl
@@ -29,39 +29,39 @@ b_x \\ b_v \\ b_y
See [Revay et al. (2021)](https://arxiv.org/abs/2104.05942) for more details on explicit parameterisations of REN.
"""
mutable struct ExplicitRENParams{T}
A::Matrix{T}
B1::Matrix{T}
B2::Matrix{T}
C1::Matrix{T}
C2::Matrix{T}
D11::Matrix{T}
D12::Matrix{T}
D21::Matrix{T}
D22::Matrix{T}
bx::Vector{T}
bv::Vector{T}
by::Vector{T}
A ::AbstractMatrix{T}
B1 ::AbstractMatrix{T}
B2 ::AbstractMatrix{T}
C1 ::AbstractMatrix{T}
C2 ::AbstractMatrix{T}
D11::AbstractMatrix{T}
D12::AbstractMatrix{T}
D21::AbstractMatrix{T}
D22::AbstractMatrix{T}
bx ::AbstractVector{T}
bv ::AbstractVector{T}
by ::AbstractVector{T}
end

mutable struct DirectRENParams{T}
ρ::Union{Vector{T},CuVector{T}} # used in polar param
X::Union{Matrix{T},CuMatrix{T}}
Y1::Union{Matrix{T},CuMatrix{T}}
X3::Union{Matrix{T},CuMatrix{T}}
Y3::Union{Matrix{T},CuMatrix{T}}
Z3::Union{Matrix{T},CuMatrix{T}}
B2::Union{Matrix{T},CuMatrix{T}}
C2::Union{Matrix{T},CuMatrix{T}}
D12::Union{Matrix{T},CuMatrix{T}}
D21::Union{Matrix{T},CuMatrix{T}}
D22::Union{Matrix{T},CuMatrix{T}}
bx::Union{Vector{T},CuVector{T}}
bv::Union{Vector{T},CuVector{T}}
by::Union{Vector{T},CuVector{T}}
ϵ::T
polar_param::Bool # Whether or not to use polar param
D22_free::Bool # Is D22 free or parameterised by (X3,Y3,Z3)?
D22_zero::Bool # Option to remove feedthrough.
X ::AbstractMatrix{T}
Y1 ::AbstractMatrix{T}
X3 ::AbstractMatrix{T}
Y3 ::AbstractMatrix{T}
Z3 ::AbstractMatrix{T}
B2 ::AbstractMatrix{T}
C2 ::AbstractMatrix{T}
D12::AbstractMatrix{T}
D21::AbstractMatrix{T}
D22::AbstractMatrix{T}
bx ::AbstractVector{T}
bv ::AbstractVector{T}
by ::AbstractVector{T}
ϵ ::T
ρ ::AbstractVector{T} # Used in polar param (if specified)
polar_param::Bool # Whether or not to use polar parameterisation
D22_free ::Bool # Is D22 free or parameterised by (X3,Y3,Z3)?
D22_zero ::Bool # Option to remove feedthrough.
end
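
Switching these fields from Union{Matrix{T},CuMatrix{T}} to AbstractMatrix{T}/AbstractVector{T} is what allows the CUDA dependency to be dropped from Project.toml: an abstractly-typed field accepts whatever array type Flux's cpu/gpu machinery produces. A toy sketch of the idea (names are illustrative only):

struct ToyDirect{T}
    X ::AbstractMatrix{T}
    bx::AbstractVector{T}
end
Flux.@functor ToyDirect

d = ToyDirect(randn(Float32, 3, 3), randn(Float32, 3))
# With a GPU available, Flux.gpu(d) returns a ToyDirect whose fields are CuArrays,
# without this package ever importing CUDA itself.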

"""
@@ -157,7 +157,7 @@ function DirectRENParams{T}(
end

# Polar parameter
ρ = [norm(X)]
ρ = polar_param ? [norm(X)] : zeros(T,0)

# Free parameter for E
Y1 = glorot_normal(nx, nx; T=T, rng=rng)
@@ -185,35 +185,37 @@ by = glorot_normal(ny; rng=rng)
by = glorot_normal(ny; rng=rng)

return DirectRENParams(
ρ ,X,
X,
Y1, X3, Y3, Z3,
B2, C2, D12, D21, D22,
bx, bv, by, T(ϵ),
bx, bv, by, T(ϵ), ρ,
polar_param, D22_free, D22_zero
)
end

function Flux.trainable(L::DirectRENParams)
Flux.@functor DirectRENParams

# Different cases for D22 free/zero
if L.D22_free
if L.D22_zero
ps = [L.ρ, L.X, L.Y1, L.B2, L.C2,
L.D12, L.D21, L.bx, L.bv, L.by]
function Flux.trainable(m::DirectRENParams)

# Field names of trainable params, exclude ρ if needed
if m.D22_free
if m.D22_zero
fs = [:X, :Y1, :B2, :C2, :D12, :D21, :bx, :bv, :by, :ρ]
else
ps = [L.ρ, L.X, L.Y1, L.B2, L.C2,
L.D12, L.D21, L.D22, L.bx, L.bv, L.by]
fs = [:X, :Y1, :B2, :C2, :D12, :D21, :D22, :bx, :bv, :by, :ρ]
end
else
ps = [L.ρ, L.X, L.Y1, L.X3, L.Y3, L.Z3, L.B2,
L.C2, L.D12, L.D21, L.bx, L.bv, L.by]
fs = [:X, :Y1, :X3, :Y3, :Z3, :B2, :C2, :D12, :D21, :bx, :bv, :by, :ρ]
end
!(m.polar_param) && pop!(fs)

# Don't need ρ if not polar param
!(L.polar_param) && popfirst!(ps)
# Get params, ignore empty ones (eg: when nx=0)
ps = [getproperty(m, f) for f in fs]
indx = length.(ps) .!= 0
ps, fs = ps[indx], fs[indx]

# Removes empty params, useful when nx=0
return filter(p -> length(p) !=0, ps)
# Flux.trainable() must return a NamedTuple
return NamedTuple{tuple(fs...)}(ps)
end
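
Flux's explicit-parameter interface expects trainable to return a NamedTuple keyed by field names, which is why the field symbols and values are zipped together at the end. A standalone illustration of that pattern with placeholder values:

fs = [:X, :Y1, :bx]
ps = [randn(2, 2), randn(2, 2), randn(2)]
NamedTuple{tuple(fs...)}(Tuple(ps))    # (X = 2×2 Matrix, Y1 = 2×2 Matrix, bx = 2-element Vector)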

function Flux.gpu(M::DirectRENParams{T}) where T
2 changes: 1 addition & 1 deletion src/ParameterTypes/contracting_ren.jl
@@ -148,7 +148,7 @@ function ContractingRENParams(

end

Flux.trainable(m::ContractingRENParams) = Flux.trainable(m.direct)
Flux.@functor ContractingRENParams (direct, )

function Flux.gpu(m::ContractingRENParams{T}) where T
# TODO: Test and complete this
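
The same one-line change is repeated for the other parameter types below. A toy sketch (names are stand-ins, not the real types) of why @functor with (direct,) replaces the old Flux.trainable(m) = Flux.trainable(m.direct) delegation: marking direct as the only child makes Flux recurse into it and collect its trainables automatically.

using Flux

struct ToyDirectREN; W; b; end
struct ToyWrapper;   direct; γ; end        # γ is a fixed hyperparameter

Flux.@functor ToyDirectREN                 # all fields are children
Flux.@functor ToyWrapper (direct,)         # recurse into direct, ignore γ

w = ToyWrapper(ToyDirectREN(randn(2, 2), randn(2)), 5.0)
Flux.setup(Adam(1e-3), w)                  # optimiser state is built for W and b only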
2 changes: 1 addition & 1 deletion src/ParameterTypes/dense_lbdn.jl
@@ -40,7 +40,7 @@ function DenseLBDNParams{T}(

end

Flux.trainable(m::DenseLBDNParams) = Flux.trainable(m.direct)
Flux.@functor DenseLBDNParams (direct, )

function direct_to_explicit(ps::DenseLBDNParams{T}) where T

2 changes: 1 addition & 1 deletion src/ParameterTypes/general_ren.jl
@@ -80,7 +80,7 @@ function GeneralRENParams{T}(

end

Flux.trainable(m::GeneralRENParams) = Flux.trainable(m.direct)
Flux.@functor GeneralRENParams (direct, )

function Flux.gpu(m::GeneralRENParams{T}) where T
# TODO: Test and complete this
2 changes: 1 addition & 1 deletion src/ParameterTypes/lipschitz_ren.jl
@@ -60,7 +60,7 @@ function LipschitzRENParams{T}(

end

Flux.trainable(m::LipschitzRENParams) = Flux.trainable(m.direct)
Flux.@functor LipschitzRENParams (direct,)

function Flux.gpu(m::LipschitzRENParams{T}) where T
# TODO: Test and complete this
8 changes: 1 addition & 7 deletions src/ParameterTypes/passive_ren.jl
@@ -62,13 +62,7 @@ function PassiveRENParams{T}(

end

function passive_trainable(L::DirectRENParams)
ps = [L.ρ, L.X, L.Y1, L.X3, L.Y3, L.Z3, L.B2, L.C2, L.D12, L.D21, L.bx, L.bv, L.by]
!(L.polar_param) && popfirst!(ps)
return filter(p -> length(p) !=0, ps)
end

Flux.trainable(m::PassiveRENParams) = passive_trainable(m.direct)
Flux.@functor PassiveRENParams (direct, )

function Flux.gpu(m::PassiveRENParams{T}) where T
# TODO: Test and complete this
1 change: 0 additions & 1 deletion src/RobustNeuralNetworks.jl
@@ -2,7 +2,6 @@ module RobustNeuralNetworks

############ Package dependencies ############

using CUDA: CuVector, CuMatrix
using Flux
using LinearAlgebra
using MatrixEquations: lyapd, plyapd
2 changes: 1 addition & 1 deletion src/Wrappers/LBDN/diff_lbdn.jl
@@ -28,4 +28,4 @@ function (m::DiffLBDN)(u::AbstractVecOrMat)
return m(u, explicit)
end

Flux.trainable(m::DiffLBDN) = Flux.trainable(m.params)
Flux.@functor DiffLBDN (params, )
6 changes: 3 additions & 3 deletions src/Wrappers/REN/diff_ren.jl
@@ -1,10 +1,10 @@
mutable struct DiffREN{T} <: AbstractREN{T}
nl
nl::Function
nu::Int
nx::Int
nv::Int
ny::Int
params::AbstractRENParams
params::AbstractRENParams{T}
end
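
Annotating the field as AbstractRENParams{T} (and nl::Function) ties the parameter struct's element type to the wrapper's own T, so mixing precisions is caught at construction time. A minimal toy version of the idea (types are hypothetical):

abstract type ToyAbstractParams{T} end
struct ToyParams64 <: ToyAbstractParams{Float64}
    W::Matrix{Float64}
end

struct ToyREN{T}
    nl::Function
    params::ToyAbstractParams{T}
end

ToyREN{Float64}(tanh, ToyParams64(randn(2, 2)))     # OK
# ToyREN{Float32}(tanh, ToyParams64(randn(2, 2)))   # errors: element types disagree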

"""
@@ -22,7 +22,7 @@ function DiffREN(ps::AbstractRENParams{T}) where T
return DiffREN{T}(ps.nl, ps.nu, ps.nx, ps.nv, ps.ny, ps)
end

Flux.trainable(m::DiffREN) = Flux.trainable(m.params)
Flux.@functor DiffREN (params, )

function (m::DiffREN)(xt::AbstractVecOrMat, ut::AbstractVecOrMat)
explicit = direct_to_explicit(m.params)
2 changes: 1 addition & 1 deletion src/Wrappers/REN/ren.jl
@@ -4,7 +4,7 @@ mutable struct REN{T} <: AbstractREN{T}
nx::Int
nv::Int
ny::Int
explicit::ExplicitRENParams
explicit::ExplicitRENParams{T}
end

"""
8 changes: 4 additions & 4 deletions src/Wrappers/REN/wrap_ren.jl
@@ -1,11 +1,11 @@
mutable struct WrapREN{T} <: AbstractREN{T}
nl
nl::Function
nu::Int
nx::Int
nv::Int
ny::Int
explicit::ExplicitRENParams
params::AbstractRENParams
explicit::ExplicitRENParams{T}
params::AbstractRENParams{T}
end

"""
@@ -74,4 +74,4 @@ function update_explicit!(m::WrapREN)
return nothing
end

Flux.trainable(m::WrapREN) = Flux.params(m.params)
Flux.@functor WrapREN (params, )