From dac8e7ef1c72d6a516322c76fefaf4b1f9b94ddc Mon Sep 17 00:00:00 2001 From: nic-barbara Date: Fri, 21 Oct 2022 17:00:46 +1100 Subject: [PATCH 1/4] Added wrapper for REN which auto-updates explicit model when required. --- src/Base/direct_params.jl | 37 ++++++++++ src/Base/output_layer.jl | 22 ++++++ src/{Base/wrapren.jl => Wrappers/wrap_ren.jl} | 9 +-- src/Wrappers/wrap_ren_2.jl | 67 +++++++++++++++++++ 4 files changed, 131 insertions(+), 4 deletions(-) rename src/{Base/wrapren.jl => Wrappers/wrap_ren.jl} (79%) create mode 100644 src/Wrappers/wrap_ren_2.jl diff --git a/src/Base/direct_params.jl b/src/Base/direct_params.jl index 0855006..cc6590c 100644 --- a/src/Base/direct_params.jl +++ b/src/Base/direct_params.jl @@ -154,3 +154,40 @@ function Flux.cpu(M::DirectParams{T}) where T cpu(M.bv), M.ϵ, M.polar_param ) end + +""" + ==(ps1::DirectParams, ps2::DirectParams) + +Define equality for two objects of type `DirectParams` +""" +function ==(ps1::DirectParams, ps2::DirectParams) + + # Compare the options + (ps1.D22_free != ps2.D22_free) && (return false) + (ps1.polar_param != ps2.polar_param) && (return false) + + c = fill(false, 11) + + # Check implicit parameters + c[1] = ps1.X == ps2.X + c[2] = ps1.Y1 == ps2.Y1 + + c[3] = ps1.B2 == ps2.B2 + c[4] = ps1.D12 == ps2.D12 + + c[5] = ps1.bx == ps2.bx + c[6] = ps1.bv == ps2.bv + + c[7] = ps1.ϵ == ps2.ϵ + c[8] = ps1.polar_param ? (ps1.ρ == ps2.ρ) : true + + if !ps1.D22_free + c[9] = ps1.X3 == ps2.X3 + c[10] = ps1.Y3 == ps2.Y3 + c[11] = ps1.Z3 == ps2.Z3 + else + c[9], c[10], c[11] = true, true, true + end + + return all(c) +end diff --git a/src/Base/output_layer.jl b/src/Base/output_layer.jl index 841891f..9f2dd09 100644 --- a/src/Base/output_layer.jl +++ b/src/Base/output_layer.jl @@ -61,3 +61,25 @@ Add CPU compatibility for `OutputLayer` type function Flux.cpu(layer::OutputLayer{T}) where T return OutputLayer{T}(cpu(layer.C2), cpu(layer.D21), cpu(layer.D22), cpu(layer.by)) end + + +""" + ==(o1::OutputLayer, o2::OutputLayer) + +Define equality for two objects of type `OutputLayer` +""" +function ==(o1::OutputLayer, o2::OutputLayer) + + # Compare the options + (o1.D22_trainable != o2.D22_trainable) && (return false) + + c = fill(false, 4) + + # Check output layer + c[1] = o1.C2 == o2.C2 + c[2] = o1.D21 == o2.D21 + c[3] = o1.by == o2.by + c[4] = o1.D22 == o2.D22 + + return all(c) +end diff --git a/src/Base/wrapren.jl b/src/Wrappers/wrap_ren.jl similarity index 79% rename from src/Base/wrapren.jl rename to src/Wrappers/wrap_ren.jl index c065c9c..f6d4834 100644 --- a/src/Base/wrapren.jl +++ b/src/Wrappers/wrap_ren.jl @@ -2,7 +2,10 @@ $(TYPEDEF) Wrapper for Recurrent Equilibrium Network type combining -direct parameters and explicit model into one type +direct parameters and explicit model into one type. + +Requires user to manually update explicit params when +direct params are changed. Not compatible with Flux.jl """ mutable struct WrapREN <: AbstractREN nl @@ -40,6 +43,4 @@ end Define trainable parameters for `WrapREN` type """ -Flux.trainable(m::WrapREN) = [ - Flux.trainable(m.params.direct)..., Flux.trainable(m.params.output)... -] +Flux.trainable(m::WrapREN) = Flux.params(m.params) diff --git a/src/Wrappers/wrap_ren_2.jl b/src/Wrappers/wrap_ren_2.jl new file mode 100644 index 0000000..f0d2be7 --- /dev/null +++ b/src/Wrappers/wrap_ren_2.jl @@ -0,0 +1,67 @@ +""" +$(TYPEDEF) + +Wrapper for Recurrent Equilibrium Network type which +automatically re-computes explicit parameters whenever +the direct parameters are edited. + +Not compatible with Flux.jl +""" +mutable struct WrapREN2 <: AbstractREN + nl + nu::Int + nx::Int + nv::Int + ny::Int + params::AbstractRENParams + explicit::ExplicitParams + old_params::AbstractRENParams + T::DataType +end + +""" + WrapREN2(ps::AbstractRENParams) + +Construct WrapREN2 wrapper from direct parameterisation +""" +function WrapREN2(ps::AbstractRENParams{T}) where T + explicit = direct_to_explicit(ps) + old_ps = deepcopy(ps) + return WrapREN2(ps.nl, ps.nu, ps.nx, ps.nv, ps.ny, ps, explicit, old_ps, T) +end + +""" + Flux.trainable(m::WrapREN2) + +Define trainable parameters for `WrapREN2` type +""" +Flux.trainable(m::WrapREN2) = Flux.trainable(m.params) + +""" + (m::WrapREN2)(xt::VecOrMat, ut::VecOrMat) + +Call the REN given internal states xt and inputs ut. If +function arguments are matrices, each column must be a +vector of states or inputs (allows batch simulations). + +Updates the explicit parameterisation if direct parameters +have been updated. +""" +function (m::WrapREN2)(xt::VecOrMat, ut::VecOrMat) + + # Compute explicit parameterisation + if !( (m.params.direct == m.old_params.direct) && + (m.params.output == m.old_params.output) ) + m.explicit = direct_to_explicit(m.params) + m.old_params = deepcopy(m.params) + end + + # Compute update + b = m.explicit.C1 * xt + m.explicit.D12 * ut .+ m.explicit.bv + wt = tril_eq_layer(m.nl, m.explicit.D11, b) + xt1 = m.explicit.A * xt + m.explicit.B1 * wt + m.explicit.B2 * ut .+ m.explicit.bx + yt = m.explicit.C2 * xt + m.explicit.D21 * wt + m.explicit.D22 * ut .+ m.explicit.by + + return xt1, yt + +end \ No newline at end of file From 8078d0cb4dfdc7120a6ef253dad37d33d5bb133a Mon Sep 17 00:00:00 2001 From: nic-barbara Date: Fri, 21 Oct 2022 17:01:52 +1100 Subject: [PATCH 2/4] Added flux-compatible wrapper for REN --- src/RecurrentEquilibriumNetworks.jl | 20 +++++++--- src/Wrappers/diff_ren.jl | 57 +++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 src/Wrappers/diff_ren.jl diff --git a/src/RecurrentEquilibriumNetworks.jl b/src/RecurrentEquilibriumNetworks.jl index 529bfa4..667d737 100644 --- a/src/RecurrentEquilibriumNetworks.jl +++ b/src/RecurrentEquilibriumNetworks.jl @@ -8,11 +8,12 @@ using Flux using LinearAlgebra using MatrixEquations: lyapd, plyapd using Random +using Zygote using Zygote: @adjoint +import Base.:(==) import Flux.gpu, Flux.cpu - ############ Abstract types ############ """ @@ -36,7 +37,11 @@ include("Base/acyclic_ren_solver.jl") include("Base/direct_params.jl") include("Base/output_layer.jl") include("Base/ren.jl") -include("Base/wrapren.jl") + +# Wrappers +include("Wrappers/diff_ren.jl") +include("Wrappers/wrap_ren.jl") +include("Wrappers/wrap_ren_2.jl") # Variations of REN include("ParameterTypes/utils.jl") @@ -52,14 +57,19 @@ include("ParameterTypes/lipschitz_ren.jl") # Types export AbstractREN export AbstractRENParams -export ContractingRENParams + export DirectParams export ExplicitParams -export GeneralRENParams -export LipschitzRENParams export OutputLayer export REN + +export ContractingRENParams +export GeneralRENParams +export LipschitzRENParams + +export DiffREN export WrapREN +export WrapREN2 # Functions export init_states diff --git a/src/Wrappers/diff_ren.jl b/src/Wrappers/diff_ren.jl new file mode 100644 index 0000000..ddf3316 --- /dev/null +++ b/src/Wrappers/diff_ren.jl @@ -0,0 +1,57 @@ +""" +$(TYPEDEF) + +Wrapper for Recurrent Equilibrium Network type which +automatically re-computes explicit parameters every +time the model is called. + +Compatible with Flux.jl +""" +mutable struct DiffREN <: AbstractREN + nl + nu::Int + nx::Int + nv::Int + ny::Int + params::AbstractRENParams + T::DataType +end + +""" + DiffREN(ps::AbstractRENParams) + +Construct DiffREN wrapper from direct parameterisation +""" +function DiffREN(ps::AbstractRENParams{T}) where T + return DiffREN(ps.nl, ps.nu, ps.nx, ps.nv, ps.ny, ps, T) +end + +""" + Flux.trainable(m::DiffREN) + +Define trainable parameters for `DiffREN` type +""" +Flux.trainable(m::DiffREN) = Flux.trainable(m.params) + +""" + (m::DiffREN)(xt::VecOrMat, ut::VecOrMat) + +Call the REN given internal states xt and inputs ut. If +function arguments are matrices, each column must be a +vector of states or inputs (allows batch simulations). + +Computes explicit parameterisation each time. This may +be slow if called many times! +""" +function (m::DiffREN)(xt::VecOrMat, ut::VecOrMat) + + explicit = direct_to_explicit(m.params) + + b = explicit.C1 * xt + explicit.D12 * ut .+ explicit.bv + wt = tril_eq_layer(m.nl, explicit.D11, b) + xt1 = explicit.A * xt + explicit.B1 * wt + explicit.B2 * ut .+ explicit.bx + yt = explicit.C2 * xt + explicit.D21 * wt + explicit.D22 * ut .+ explicit.by + + return xt1, yt + +end \ No newline at end of file From 09aec12f0bc55fff747deafc385f8c89848a24af Mon Sep 17 00:00:00 2001 From: nic-barbara Date: Fri, 21 Oct 2022 17:02:25 +1100 Subject: [PATCH 3/4] Wrote tests for new types and restructured test scripts --- test/{ => ParameterTypes}/contraction.jl | 0 .../general_behavioural_constrains.jl | 0 test/{ => ParameterTypes}/lipschitz_bound.jl | 0 test/Wrappers/wrap_rens.jl | 40 +++++++++++++++++++ test/runtests.jl | 9 +++-- test/wrap_ren.jl | 30 -------------- 6 files changed, 45 insertions(+), 34 deletions(-) rename test/{ => ParameterTypes}/contraction.jl (100%) rename test/{ => ParameterTypes}/general_behavioural_constrains.jl (100%) rename test/{ => ParameterTypes}/lipschitz_bound.jl (100%) create mode 100644 test/Wrappers/wrap_rens.jl delete mode 100644 test/wrap_ren.jl diff --git a/test/contraction.jl b/test/ParameterTypes/contraction.jl similarity index 100% rename from test/contraction.jl rename to test/ParameterTypes/contraction.jl diff --git a/test/general_behavioural_constrains.jl b/test/ParameterTypes/general_behavioural_constrains.jl similarity index 100% rename from test/general_behavioural_constrains.jl rename to test/ParameterTypes/general_behavioural_constrains.jl diff --git a/test/lipschitz_bound.jl b/test/ParameterTypes/lipschitz_bound.jl similarity index 100% rename from test/lipschitz_bound.jl rename to test/ParameterTypes/lipschitz_bound.jl diff --git a/test/Wrappers/wrap_rens.jl b/test/Wrappers/wrap_rens.jl new file mode 100644 index 0000000..9b10b5e --- /dev/null +++ b/test/Wrappers/wrap_rens.jl @@ -0,0 +1,40 @@ +using LinearAlgebra +using Random +using RecurrentEquilibriumNetworks +using Test + +""" +Test REN wrapper with General REN params +""" +batches = 20 +nu, nx, nv, ny = 4, 5, 10, 2 + +Q = Matrix{Float64}(-I(ny)) +R = 0.1^2 * Matrix{Float64}(I(nu)) +S = zeros(Float64, nu, ny) + +ren_ps = GeneralRENParams{Float64}(nu, nx, nv, ny, Q, S, R) +ren1 = WrapREN(ren_ps) +ren2 = WrapREN2(deepcopy(ren_ps)) + +x0 = init_states(ren1, batches) +u0 = randn(nu, batches) + +# Update the model after changing a parameter +old_B2 = deepcopy(ren1.explicit.B2) +ren1.params.direct.B2 .*= rand(size(ren1.params.direct.B2)...) + +x1, y1 = ren1(x0, u0) +update_explicit!(ren1) + +new_B2 = deepcopy(ren1.explicit.B2) +@test old_B2 != new_B2 + +# Test auto-update +old_B2 = deepcopy(ren2.explicit.B2) +ren2.params.direct.B2 .*= rand(size(ren2.params.direct.B2)...) + +x1, y1 = ren2(x0, u0) + +new_B2 = deepcopy(ren2.explicit.B2) +@test old_B2 != new_B2 \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e595db8..4249d93 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,11 +4,12 @@ using Test @testset "RecurrentEquilibriumNetworks.jl" begin # Test a basic example from the README - include("wrap_ren.jl") + include("Wrappers/wrap_ren.jl") + # include("Wrappers/diff_ren.jl") # Test for desired behaviour - include("contraction.jl") - include("general_behavioural_constrains.jl") - include("lipschitz_bound.jl") + include("ParameterTypes/contraction.jl") + include("ParameterTypes/general_behavioural_constrains.jl") + include("ParameterTypes/lipschitz_bound.jl") end diff --git a/test/wrap_ren.jl b/test/wrap_ren.jl deleted file mode 100644 index 9314adf..0000000 --- a/test/wrap_ren.jl +++ /dev/null @@ -1,30 +0,0 @@ -using LinearAlgebra -using Random -using RecurrentEquilibriumNetworks -using Test - -""" -Test REN wrapper with General REN params -""" -batches = 20 -nu, nx, nv, ny = 4, 5, 10, 2 - -Q = Matrix{Float64}(-I(ny)) -R = 0.1^2 * Matrix{Float64}(I(nu)) -S = zeros(Float64, nu, ny) - -ren_ps = GeneralRENParams{Float64}(nu, nx, nv, ny, Q, S, R) -ren = WrapREN(ren_ps) - -x0 = init_states(ren, batches) -u0 = randn(ren.nu, batches) - -x1, y1 = ren(x0, u0) # Evaluates the REN over one timestep - -# Update the model after changing a parameter -old_B2 = deepcopy(ren.explicit.B2) -ren.params.direct.B2 .*= rand(size(ren.params.direct.B2)...) -update_explicit!(ren) -new_B2 = deepcopy(ren.explicit.B2) - -@test old_B2 != new_B2 \ No newline at end of file From 6c1813d3c9eff2be4dbe836ae33c226a0f1437fe Mon Sep 17 00:00:00 2001 From: nic-barbara Date: Fri, 21 Oct 2022 17:09:07 +1100 Subject: [PATCH 4/4] Fixed typo --- test/runtests.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4249d93..63c7e70 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,8 +4,7 @@ using Test @testset "RecurrentEquilibriumNetworks.jl" begin # Test a basic example from the README - include("Wrappers/wrap_ren.jl") - # include("Wrappers/diff_ren.jl") + include("Wrappers/wrap_rens.jl") # Test for desired behaviour include("ParameterTypes/contraction.jl")