-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #21 from acfr/feature/dev-wrappers
Feature/dev wrappers
- Loading branch information
Showing
12 changed files
with
247 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
$(TYPEDEF) | ||
Wrapper for Recurrent Equilibrium Network type which | ||
automatically re-computes explicit parameters every | ||
time the model is called. | ||
Compatible with Flux.jl | ||
""" | ||
mutable struct DiffREN <: AbstractREN | ||
nl | ||
nu::Int | ||
nx::Int | ||
nv::Int | ||
ny::Int | ||
params::AbstractRENParams | ||
T::DataType | ||
end | ||
|
||
""" | ||
DiffREN(ps::AbstractRENParams) | ||
Construct DiffREN wrapper from direct parameterisation | ||
""" | ||
function DiffREN(ps::AbstractRENParams{T}) where T | ||
return DiffREN(ps.nl, ps.nu, ps.nx, ps.nv, ps.ny, ps, T) | ||
end | ||
|
||
""" | ||
Flux.trainable(m::DiffREN) | ||
Define trainable parameters for `DiffREN` type | ||
""" | ||
Flux.trainable(m::DiffREN) = Flux.trainable(m.params) | ||
|
||
""" | ||
(m::DiffREN)(xt::VecOrMat, ut::VecOrMat) | ||
Call the REN given internal states xt and inputs ut. If | ||
function arguments are matrices, each column must be a | ||
vector of states or inputs (allows batch simulations). | ||
Computes explicit parameterisation each time. This may | ||
be slow if called many times! | ||
""" | ||
function (m::DiffREN)(xt::VecOrMat, ut::VecOrMat) | ||
|
||
explicit = direct_to_explicit(m.params) | ||
|
||
b = explicit.C1 * xt + explicit.D12 * ut .+ explicit.bv | ||
wt = tril_eq_layer(m.nl, explicit.D11, b) | ||
xt1 = explicit.A * xt + explicit.B1 * wt + explicit.B2 * ut .+ explicit.bx | ||
yt = explicit.C2 * xt + explicit.D21 * wt + explicit.D22 * ut .+ explicit.by | ||
|
||
return xt1, yt | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
""" | ||
$(TYPEDEF) | ||
Wrapper for Recurrent Equilibrium Network type which | ||
automatically re-computes explicit parameters whenever | ||
the direct parameters are edited. | ||
Not compatible with Flux.jl | ||
""" | ||
mutable struct WrapREN2 <: AbstractREN | ||
nl | ||
nu::Int | ||
nx::Int | ||
nv::Int | ||
ny::Int | ||
params::AbstractRENParams | ||
explicit::ExplicitParams | ||
old_params::AbstractRENParams | ||
T::DataType | ||
end | ||
|
||
""" | ||
WrapREN2(ps::AbstractRENParams) | ||
Construct WrapREN2 wrapper from direct parameterisation | ||
""" | ||
function WrapREN2(ps::AbstractRENParams{T}) where T | ||
explicit = direct_to_explicit(ps) | ||
old_ps = deepcopy(ps) | ||
return WrapREN2(ps.nl, ps.nu, ps.nx, ps.nv, ps.ny, ps, explicit, old_ps, T) | ||
end | ||
|
||
""" | ||
Flux.trainable(m::WrapREN2) | ||
Define trainable parameters for `WrapREN2` type | ||
""" | ||
Flux.trainable(m::WrapREN2) = Flux.trainable(m.params) | ||
|
||
""" | ||
(m::WrapREN2)(xt::VecOrMat, ut::VecOrMat) | ||
Call the REN given internal states xt and inputs ut. If | ||
function arguments are matrices, each column must be a | ||
vector of states or inputs (allows batch simulations). | ||
Updates the explicit parameterisation if direct parameters | ||
have been updated. | ||
""" | ||
function (m::WrapREN2)(xt::VecOrMat, ut::VecOrMat) | ||
|
||
# Compute explicit parameterisation | ||
if !( (m.params.direct == m.old_params.direct) && | ||
(m.params.output == m.old_params.output) ) | ||
m.explicit = direct_to_explicit(m.params) | ||
m.old_params = deepcopy(m.params) | ||
end | ||
|
||
# Compute update | ||
b = m.explicit.C1 * xt + m.explicit.D12 * ut .+ m.explicit.bv | ||
wt = tril_eq_layer(m.nl, m.explicit.D11, b) | ||
xt1 = m.explicit.A * xt + m.explicit.B1 * wt + m.explicit.B2 * ut .+ m.explicit.bx | ||
yt = m.explicit.C2 * xt + m.explicit.D21 * wt + m.explicit.D22 * ut .+ m.explicit.by | ||
|
||
return xt1, yt | ||
|
||
end |
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
using LinearAlgebra | ||
using Random | ||
using RecurrentEquilibriumNetworks | ||
using Test | ||
|
||
""" | ||
Test REN wrapper with General REN params | ||
""" | ||
batches = 20 | ||
nu, nx, nv, ny = 4, 5, 10, 2 | ||
|
||
Q = Matrix{Float64}(-I(ny)) | ||
R = 0.1^2 * Matrix{Float64}(I(nu)) | ||
S = zeros(Float64, nu, ny) | ||
|
||
ren_ps = GeneralRENParams{Float64}(nu, nx, nv, ny, Q, S, R) | ||
ren1 = WrapREN(ren_ps) | ||
ren2 = WrapREN2(deepcopy(ren_ps)) | ||
|
||
x0 = init_states(ren1, batches) | ||
u0 = randn(nu, batches) | ||
|
||
# Update the model after changing a parameter | ||
old_B2 = deepcopy(ren1.explicit.B2) | ||
ren1.params.direct.B2 .*= rand(size(ren1.params.direct.B2)...) | ||
|
||
x1, y1 = ren1(x0, u0) | ||
update_explicit!(ren1) | ||
|
||
new_B2 = deepcopy(ren1.explicit.B2) | ||
@test old_B2 != new_B2 | ||
|
||
# Test auto-update | ||
old_B2 = deepcopy(ren2.explicit.B2) | ||
ren2.params.direct.B2 .*= rand(size(ren2.params.direct.B2)...) | ||
|
||
x1, y1 = ren2(x0, u0) | ||
|
||
new_B2 = deepcopy(ren2.explicit.B2) | ||
@test old_B2 != new_B2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.