Commit

Merge pull request #21 from acfr/feature/dev-wrappers

Feature/dev wrappers

nic-barbara authored Oct 21, 2022
2 parents 4e9122e + 6c1813d commit a93a067
Showing 12 changed files with 247 additions and 43 deletions.
37 changes: 37 additions & 0 deletions src/Base/direct_params.jl
@@ -154,3 +154,40 @@ function Flux.cpu(M::DirectParams{T}) where T
cpu(M.bv), M.ϵ, M.polar_param
)
end

"""
==(ps1::DirectParams, ps2::DirectParams)
Define equality for two objects of type `DirectParams`
"""
function ==(ps1::DirectParams, ps2::DirectParams)

# Compare the options
(ps1.D22_free != ps2.D22_free) && (return false)
(ps1.polar_param != ps2.polar_param) && (return false)

c = fill(false, 11)

# Check implicit parameters
c[1] = ps1.X == ps2.X
c[2] = ps1.Y1 == ps2.Y1

c[3] = ps1.B2 == ps2.B2
c[4] = ps1.D12 == ps2.D12

c[5] = ps1.bx == ps2.bx
c[6] = ps1.bv == ps2.bv

c[7] = ps1.ϵ == ps2.ϵ
c[8] = ps1.polar_param ? (ps1.ρ == ps2.ρ) : true

if !ps1.D22_free
c[9] = ps1.X3 == ps2.X3
c[10] = ps1.Y3 == ps2.Y3
c[11] = ps1.Z3 == ps2.Z3
else
c[9], c[10], c[11] = true, true, true
end

return all(c)
end
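For context, this field-wise equality is what lets `WrapREN2` (added below) detect whether the direct parameters have changed since its last call, by comparing against a `deepcopy` snapshot. A minimal sketch, assuming `ps` is an already-constructed `DirectParams`:

# Field-wise `==` means an independent deep copy compares equal until edited.
ps_snapshot = deepcopy(ps)
@assert ps_snapshot == ps      # true: every field matches element-wise

ps.bx .+= 1.0                  # edit a direct parameter in place
@assert !(ps_snapshot == ps)   # the edit is now detected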
22 changes: 22 additions & 0 deletions src/Base/output_layer.jl
@@ -61,3 +61,25 @@ Add CPU compatibility for `OutputLayer` type
function Flux.cpu(layer::OutputLayer{T}) where T
return OutputLayer{T}(cpu(layer.C2), cpu(layer.D21), cpu(layer.D22), cpu(layer.by))
end


"""
==(o1::OutputLayer, o2::OutputLayer)
Define equality for two objects of type `OutputLayer`
"""
function ==(o1::OutputLayer, o2::OutputLayer)

# Compare the options
(o1.D22_trainable != o2.D22_trainable) && (return false)

c = fill(false, 4)

# Check output layer
c[1] = o1.C2 == o2.C2
c[2] = o1.D21 == o2.D21
c[3] = o1.by == o2.by
c[4] = o1.D22 == o2.D22

return all(c)
end
20 changes: 15 additions & 5 deletions src/RecurrentEquilibriumNetworks.jl
@@ -8,11 +8,12 @@ using Flux
using LinearAlgebra
using MatrixEquations: lyapd, plyapd
using Random
-using Zygote
+using Zygote: @adjoint

+import Base.:(==)
import Flux.gpu, Flux.cpu


############ Abstract types ############

"""
@@ -36,7 +37,11 @@ include("Base/acyclic_ren_solver.jl")
include("Base/direct_params.jl")
include("Base/output_layer.jl")
include("Base/ren.jl")
-include("Base/wrapren.jl")

+# Wrappers
+include("Wrappers/diff_ren.jl")
+include("Wrappers/wrap_ren.jl")
+include("Wrappers/wrap_ren_2.jl")

# Variations of REN
include("ParameterTypes/utils.jl")
@@ -52,14 +57,19 @@ include("ParameterTypes/lipschitz_ren.jl")
# Types
export AbstractREN
export AbstractRENParams
-export ContractingRENParams

export DirectParams
export ExplicitParams
-export GeneralRENParams
-export LipschitzRENParams
export OutputLayer
export REN

+export ContractingRENParams
+export GeneralRENParams
+export LipschitzRENParams

+export DiffREN
+export WrapREN
+export WrapREN2

# Functions
export init_states
57 changes: 57 additions & 0 deletions src/Wrappers/diff_ren.jl
@@ -0,0 +1,57 @@
"""
$(TYPEDEF)
Wrapper for Recurrent Equilibrium Network type which
automatically re-computes explicit parameters every
time the model is called.
Compatible with Flux.jl
"""
mutable struct DiffREN <: AbstractREN
nl
nu::Int
nx::Int
nv::Int
ny::Int
params::AbstractRENParams
T::DataType
end

"""
DiffREN(ps::AbstractRENParams)
Construct DiffREN wrapper from direct parameterisation
"""
function DiffREN(ps::AbstractRENParams{T}) where T
return DiffREN(ps.nl, ps.nu, ps.nx, ps.nv, ps.ny, ps, T)
end

"""
Flux.trainable(m::DiffREN)
Define trainable parameters for `DiffREN` type
"""
Flux.trainable(m::DiffREN) = Flux.trainable(m.params)

"""
(m::DiffREN)(xt::VecOrMat, ut::VecOrMat)
Call the REN given internal states xt and inputs ut. If
function arguments are matrices, each column must be a
vector of states or inputs (allows batch simulations).
Computes explicit parameterisation each time. This may
be slow if called many times!
"""
function (m::DiffREN)(xt::VecOrMat, ut::VecOrMat)

explicit = direct_to_explicit(m.params)

b = explicit.C1 * xt + explicit.D12 * ut .+ explicit.bv
wt = tril_eq_layer(m.nl, explicit.D11, b)
xt1 = explicit.A * xt + explicit.B1 * wt + explicit.B2 * ut .+ explicit.bx
yt = explicit.C2 * xt + explicit.D21 * wt + explicit.D22 * ut .+ explicit.by

return xt1, yt

end
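A usage sketch for the wrapper, reusing the dimensions, the `GeneralRENParams` constructor, and `init_states` exactly as they appear in the test file further below (all values illustrative; `init_states` is assumed to accept any `AbstractREN`):

using LinearAlgebra, RecurrentEquilibriumNetworks

nu, nx, nv, ny = 4, 5, 10, 2
Q = Matrix{Float64}(-I(ny))
R = 0.1^2 * Matrix{Float64}(I(nu))
S = zeros(Float64, nu, ny)

ps = GeneralRENParams{Float64}(nu, nx, nv, ny, Q, S, R)
model = DiffREN(ps)            # no manual update of explicit params needed

x0 = init_states(model, 10)    # 10 columns = 10 batch elements
u0 = randn(nu, 10)
x1, y1 = model(x0, u0)         # explicit params are rebuilt inside this call

Since `Flux.trainable(m::DiffREN)` forwards to the parameterisation, gradients taken through this call flow to the direct parameters, which is what makes the wrapper Flux-compatible.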
9 changes: 5 additions & 4 deletions src/Base/wrapren.jl → src/Wrappers/wrap_ren.jl
@@ -2,7 +2,10 @@
$(TYPEDEF)
Wrapper for Recurrent Equilibrium Network type combining
-direct parameters and explicit model into one type
+direct parameters and explicit model into one type.
+Requires user to manually update explicit params when
+direct params are changed. Not compatible with Flux.jl
"""
mutable struct WrapREN <: AbstractREN
nl
@@ -40,6 +43,4 @@ end
Define trainable parameters for `WrapREN` type
"""
-Flux.trainable(m::WrapREN) = [
-    Flux.trainable(m.params.direct)..., Flux.trainable(m.params.output)...
-]
+Flux.trainable(m::WrapREN) = Flux.params(m.params)
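In contrast to `DiffREN`, editing a `WrapREN`'s direct parameters has no effect on its outputs until the explicit model is manually refreshed, as in this sketch (assuming a constructed `ren::WrapREN` with states `x0` and inputs `u0`; `update_explicit!` is exercised in the test file further below):

ren.params.direct.B2 .*= 0.5   # edit the direct parameterisation in place
update_explicit!(ren)          # required: rebuild ren.explicit from ren.params
x1, y1 = ren(x0, u0)           # now evaluates the refreshed explicit model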
67 changes: 67 additions & 0 deletions src/Wrappers/wrap_ren_2.jl
@@ -0,0 +1,67 @@
"""
$(TYPEDEF)
Wrapper for Recurrent Equilibrium Network type which
automatically re-computes explicit parameters whenever
the direct parameters are edited.
Not compatible with Flux.jl
"""
mutable struct WrapREN2 <: AbstractREN
nl
nu::Int
nx::Int
nv::Int
ny::Int
params::AbstractRENParams
explicit::ExplicitParams
old_params::AbstractRENParams
T::DataType
end

"""
WrapREN2(ps::AbstractRENParams)
Construct WrapREN2 wrapper from direct parameterisation
"""
function WrapREN2(ps::AbstractRENParams{T}) where T
explicit = direct_to_explicit(ps)
old_ps = deepcopy(ps)
return WrapREN2(ps.nl, ps.nu, ps.nx, ps.nv, ps.ny, ps, explicit, old_ps, T)
end

"""
Flux.trainable(m::WrapREN2)
Define trainable parameters for `WrapREN2` type
"""
Flux.trainable(m::WrapREN2) = Flux.trainable(m.params)

"""
(m::WrapREN2)(xt::VecOrMat, ut::VecOrMat)
Call the REN given internal states xt and inputs ut. If
function arguments are matrices, each column must be a
vector of states or inputs (allows batch simulations).
Updates the explicit parameterisation if direct parameters
have been updated.
"""
function (m::WrapREN2)(xt::VecOrMat, ut::VecOrMat)

# Compute explicit parameterisation
if !( (m.params.direct == m.old_params.direct) &&
(m.params.output == m.old_params.output) )
m.explicit = direct_to_explicit(m.params)
m.old_params = deepcopy(m.params)
end

# Compute update
b = m.explicit.C1 * xt + m.explicit.D12 * ut .+ m.explicit.bv
wt = tril_eq_layer(m.nl, m.explicit.D11, b)
xt1 = m.explicit.A * xt + m.explicit.B1 * wt + m.explicit.B2 * ut .+ m.explicit.bx
yt = m.explicit.C2 * xt + m.explicit.D21 * wt + m.explicit.D22 * ut .+ m.explicit.by

return xt1, yt

end
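The same edit with `WrapREN2` needs no manual refresh: the snapshot comparison above rebuilds the explicit model on the next call. A sketch, assuming a constructed `ren2::WrapREN2`:

ren2.params.direct.B2 .*= 0.5  # edit direct params in place
x1, y1 = ren2(x0, u0)          # the == check against old_params fails,
                               # so explicit params are rebuilt first

Note that the `==` comparison (and a `deepcopy` whenever the parameters have changed) runs inside every call; that overhead is the price of the automatic update.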
test/contraction.jl → test/ParameterTypes/contraction.jl: file renamed without changes.
test/general_behavioural_constrains.jl → test/ParameterTypes/general_behavioural_constrains.jl: file renamed without changes.
test/lipschitz_bound.jl → test/ParameterTypes/lipschitz_bound.jl: file renamed without changes.
40 changes: 40 additions & 0 deletions test/Wrappers/wrap_rens.jl
@@ -0,0 +1,40 @@
using LinearAlgebra
using Random
using RecurrentEquilibriumNetworks
using Test

"""
Test REN wrapper with General REN params
"""
batches = 20
nu, nx, nv, ny = 4, 5, 10, 2

Q = Matrix{Float64}(-I(ny))
R = 0.1^2 * Matrix{Float64}(I(nu))
S = zeros(Float64, nu, ny)

ren_ps = GeneralRENParams{Float64}(nu, nx, nv, ny, Q, S, R)
ren1 = WrapREN(ren_ps)
ren2 = WrapREN2(deepcopy(ren_ps))

x0 = init_states(ren1, batches)
u0 = randn(nu, batches)

# Update the model after changing a parameter
old_B2 = deepcopy(ren1.explicit.B2)
ren1.params.direct.B2 .*= rand(size(ren1.params.direct.B2)...)

x1, y1 = ren1(x0, u0)
update_explicit!(ren1)

new_B2 = deepcopy(ren1.explicit.B2)
@test old_B2 != new_B2

# Test auto-update
old_B2 = deepcopy(ren2.explicit.B2)
ren2.params.direct.B2 .*= rand(size(ren2.params.direct.B2)...)

x1, y1 = ren2(x0, u0)

new_B2 = deepcopy(ren2.explicit.B2)
@test old_B2 != new_B2
8 changes: 4 additions & 4 deletions test/runtests.jl
@@ -4,11 +4,11 @@ using Test
@testset "RecurrentEquilibriumNetworks.jl" begin

# Test a basic example from the README
-include("wrap_ren.jl")
+include("Wrappers/wrap_rens.jl")

# Test for desired behaviour
-include("contraction.jl")
-include("general_behavioural_constrains.jl")
-include("lipschitz_bound.jl")
+include("ParameterTypes/contraction.jl")
+include("ParameterTypes/general_behavioural_constrains.jl")
+include("ParameterTypes/lipschitz_bound.jl")

end
30 changes: 0 additions & 30 deletions test/wrap_ren.jl

This file was deleted.
