Merge pull request #1 from acfr/feature/basic-functionality
Feature/basic functionality
Showing 13 changed files with 958 additions and 34 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,44 +3,94 @@ | |
## Status | ||
[![Build Status](https://github.com/nic-barbara/RecurrentEquilibriumNetworks.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/nic-barbara/RecurrentEquilibriumNetworks.jl/actions/workflows/CI.yml?query=branch%3Amain) | ||
|
||
## Description | ||
## Package Description | ||
|
||
Julia package for Recurrent Equilibrium Networks. | ||
|
||
[NOTE] This package is a work-in-progress. For now, you may find the following links useful: | ||
- Tutorial on [developing Julia packages](https://julialang.org/contribute/developing_package/) by Chris Rackauckas (MIT) | ||
- Documentation on [managing Julia packages](https://pkgdocs.julialang.org/v1/managing-packages/) and developing unregistered packages with `Pkg.jl` | ||
|
||
So far, the package only contains a couple of test functions. | ||
|
||
## How to use | ||
## Installation for Development | ||
|
||
- Clone the repository into your Julia dev folder: | ||
- For Linux/Mac, use: `git clone git@github.com:acfr/RecurrentEquilibriumNetworks.jl.git RecurrentEquilibriumNetworks` inside your `~/.julia/dev/` directory. | ||
- Note that the repo is `RecurrentEquilibriumNetworks.jl`, but the folder is `RecurrentEquilibriumNetworks`. This is convention for Julia packages. | ||
- Navigate to the repo directory, start the Julia REPL, and type `] activate .` to activate the package. | ||
- Try using the demo functions: | ||
- Type `using RecurrentEquilibriumNetworks` in the REPL to add the package to your current session. | ||
- Test out `test_ren_package()`. It should print `"Hello RecurrentEquilibriumNetworks.jl!"` to your screen. | ||
To install the package for development, clone the repository into your Julia dev folder: | ||
- For Linux/Mac, use `git clone git@github.com:acfr/RecurrentEquilibriumNetworks.jl.git RecurrentEquilibriumNetworks` inside your `~/.julia/dev/` directory. | ||
- Note that the repo is `RecurrentEquilibriumNetworks.jl`, but the folder is `RecurrentEquilibriumNetworks`. This is convention for Julia packages. | ||
|
||
## To contribute to the package: | ||
Navigate to the repository directory, start the Julia REPL, and type `] activate .` to activate the package. You can now test out some basic functionality: | ||
|
||
```julia | ||
using Random | ||
using RecurrentEquilibriumNetworks | ||
|
||
- Edit source files in `src/`: | ||
- `RecurrentEquilibriumNetworks.jl` is the main file. Include dependencies and export types/functions here. | ||
- Add other source files for new functionality (eg: `src/functions.jl`). | ||
- Write testing scripts for the package in `test/`: | ||
- See [`Test.jl`](https://docs.julialang.org/en/v1/stdlib/Test/) documentation for writing tests | ||
- Run tests with `] test` | ||
- Use git to pull/push changes to the package as normal | ||
batches = 50 | ||
nu, nx, nv, ny = 4, 10, 20, 2 | ||
|
||
## To use the package in a separate Julia workspace: | ||
contracting_ren_ps = ContractingRENParams{Float64}(nu, nx, nv, ny) | ||
contracting_ren = REN(contracting_ren_ps) | ||
|
||
x0 = init_states(contracting_ren, batches) | ||
u0 = randn(contracting_ren.nu, batches) | ||
|
||
- Add development version of the package with: `] dev git@github.com:acfr/RecurrentEquilibriumNetworks.jl.git` | ||
x1, y1 = contracting_ren(x0, u0) # Evaluates the REN over one timestep | ||
|
||
println(x1) | ||
println(y1) | ||
``` | ||
|
||
|
||
## Contributing to the Package | ||
|
||
The main file is `src/RecurrentEquilibriumNetworks.jl`. This imports all relevant packages, defines abstract types, includes code from other files, and exports the necessary components of our package. All `using PackageName` statements should be included in this file. As a general guide: | ||
- Only import packages you really need | ||
- If you only need one function from a package, import it explicitly (not the whole package) | ||
|
||
When including files in our `src/` folder, the order often matters. I have tried to structure the `include` statements in `RecurrentEquilibriumNetworks.jl` so that we only ever have to include code once, in the main file. Please follow the conventions outlined in the comments. | ||
|
||
The source files for our package are all in the `src/` folder, and are split between `src/Base/` and `src/ParameterTypes/`. The `Base/` folder should contain code relevant to the core functionality of this package. The `ParameterTypes/` folder is where to add different versions of REN (eg: contracting REN, Lipschitz-bounded REN, etc.). See `src/ParameterTypes/general_ren.jl` for an example. | ||
|
||
Once you have written any code for this package, be sure to test it thoroughly. Write testing scripts for the package in `test/`: | ||
- See [`Test.jl`](https://docs.julialang.org/en/v1/stdlib/Test/) documentation for writing tests | ||
- Run all tests for the package with `] test` | ||
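
As a rough illustration of the testing advice above, a test script using the `Test` standard library might look like the sketch below. By Julia convention the entry point is `test/runtests.jl`. The function `double_it` is a hypothetical stand-in for whatever the package exports; in a real test file it would come from `using RecurrentEquilibriumNetworks` instead.

```julia
using Test

# Hypothetical function under test. In the package's own tests this would be
# replaced by functions loaded with `using RecurrentEquilibriumNetworks`.
double_it(x) = 2 .* x

# Group related checks in a named test set so failures are reported together
@testset "double_it" begin
    @test double_it(2) == 4
    @test double_it([1, 2]) == [2, 4]
    @test double_it(1.5) ≈ 3.0
end
```

Running `] test` from the activated package executes `test/runtests.jl` and reports a pass/fail summary per test set.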
|
||
Use git to pull/push changes to the package as normal while developing it. | ||
|
||
|
||
## Using the package | ||
|
||
Once the package is functional, it can be used in other Julia workspaces like any normal package by following these instructions: | ||
|
||
- Add a development version of the package with: `] dev git@github.com:acfr/RecurrentEquilibriumNetworks.jl.git` | ||
- This is instead of the usual `] add` command. We also have to use the git link not the package name because it is an unregistered Julia package. | ||
- Whenever you use the package, it will access the latest version in your `.julia/dev/` folder rather than the stable release in the `main` branch. This is easiest for development while we frequently change the package. | ||
- To use the code on the current main branch of the repo (latest stable release), instead type `] add git@github.com:acfr/RecurrentEquilibriumNetworks.jl.git`. You will have to manually update the package as normal with `] update RecurrentEquilibriumNetworks`. | ||
|
||
### Contact | ||
Nic Barbara ([email protected]) for any questions/concerns. | ||
|
||
## Some Early Documentation | ||
The package is structured around the `REN` type. An object of type `REN` has the following attributes: | ||
- explicit model struct | ||
- in/out, state/nl sizes | ||
- nonlinearity | ||
|
||
and functions to build/use it as follows: | ||
- A constructor | ||
- A self-call method | ||
- A function to initialise a state vector, `x0 = init_states(ren, batches)` | ||
- A function to set the output to zero, `set_output_zero!(ren)` | ||
|
||
Each `REN` is constructed from a direct (implicit) parameterisation of the REN architecture. Each variation of REN (eg: contracting, passive, Lipschitz bounded) is a subtype of `AbstractRENParams`, an abstract type. This encodes all information required to build a `REN` satisfying some set of behavioural constraints. Each subtype must include the following attributes: | ||
- in/out, state/nl sizes `nu, ny, nx, nv` | ||
- output layer of type `OutputLayer` | ||
- direct (implicit) parameters of type `DirectParams` | ||
- Any other attributes relevant to the parameterisation. Eg: `Q, S, R, alpha_bar` for a general REN | ||
|
||
The output layer and implicit parameters are structs defined in `src/Base/output_layer.jl` and `src/Base/direct_params.jl` (respectively). Each subtype of `AbstractRENParams` must also have the following methods: | ||
- A constructor | ||
- A definition of `Flux.trainable()` specifying the trainable parameters | ||
- A definition of `direct_to_explicit()` to convert the direct parameterisation to its explicit form | ||
|
||
See `src/ParameterTypes/general_ren.jl` for an example. | ||
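
The checklist above could be sketched roughly as follows. This is a self-contained toy, not the package's code: `ToyDirectParams`, `ToyOutputLayer`, `ToyRENParams`, and all field names are hypothetical stand-ins, the abstract type is redeclared locally, and `trainable` is a plain local function where the package would add a method to `Flux.trainable`.

```julia
# Local stand-in for the package's abstract type
abstract type AbstractRENParams end

# Hypothetical stand-ins for the package's `DirectParams` and `OutputLayer`
struct ToyDirectParams{T}
    B2::Matrix{T}
    bx::Vector{T}
end

struct ToyOutputLayer{T}
    C2::Matrix{T}
end

# A parameterisation carrying the required attributes: sizes, output layer,
# and direct parameters (field names are assumptions)
struct ToyRENParams{T} <: AbstractRENParams
    nu::Int; nx::Int; nv::Int; ny::Int
    output::ToyOutputLayer{T}
    direct::ToyDirectParams{T}
end

# 1. A constructor taking only the model sizes
function ToyRENParams{T}(nu::Int, nx::Int, nv::Int, ny::Int) where T
    direct = ToyDirectParams{T}(zeros(T, nx, nu), zeros(T, nx))
    output = ToyOutputLayer{T}(zeros(T, ny, nx))
    return ToyRENParams{T}(nu, nx, nv, ny, output, direct)
end

# 2. The trainable parameters (a `Flux.trainable` method in the real package)
trainable(ps::ToyRENParams) = (ps.direct.B2, ps.direct.bx, ps.output.C2)

# 3. Direct-to-explicit conversion (placeholder: passes parameters through)
direct_to_explicit(ps::ToyRENParams) = (B2 = ps.direct.B2, bx = ps.direct.bx)

ps = ToyRENParams{Float64}(4, 10, 20, 2)
println(size(direct_to_explicit(ps).B2))
```

A real parameterisation would do actual work in `direct_to_explicit`, mapping the direct parameters to an explicit model that satisfies its behavioural constraints.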
|
||
## Contact | ||
Nic Barbara ([email protected]) for any questions/concerns. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Package Architecture | ||
|
||
This document is a work-in-progress. Use it to start our documentation of the `RecurrentEquilibriumNetworks.jl` package. I'm using this document to keep track of how I'm writing the package. It is not necessarily up-to-date or correct. | ||
|
||
|
||
|
||
|
||
|
||
## TODO Lists (and other useful things) | ||
|
||
### General: | ||
- Write direct_to_explicit functions for and `GeneralRENParams` | ||
- Add documentation and improve speed for `Base/acyclic_ren_solver.jl` code taken from Max's work | ||
|
||
### Changes from previous code: | ||
- Currently no method to construct a REN without specifying type. Seems good to force it | ||
- Treatment of `D22`: | ||
- `D22` can be included directly as a trainable parameter in the `OutputLayer`, but is **NOT** by default. Have to set `D22_trainable = true` on construction | ||
- `D22` can be parameterised by free parameters `(X3,Y3,Z3)` by setting `D22_free = false` in `DirectParams` (the default) | ||
- Added `alpha_bar` to `ContractingRENParams` (as an example). Will use this to set contraction rate in explicit params construction | ||
|
||
### Bugs: | ||
- Cholesky initialisation of `DirectParams` | ||
|
||
### Functionality | ||
- How to include different initialisation methods? | ||
- Write full support for GPU/CUDA arrays | ||
|
||
### Conversion questions: | ||
- `D22` terms in `output.jl` initialised as 0. Good idea in general? | ||
- `sample_ff_ren` had `randn(2nx + nv, 2nx + nv) / sqrt(2nx + nv)`, not `.../ sqrt(2*(2nx + nv))`. Why? | ||
- Why have `bx_scale` and `bv_scale` in constructor for `DirectParams`? | ||
- Constructor for `implicit_ff_cell` has zeros for some params, random for others. Why? | ||
- Construction of S (now `Y1`) in `ffREN.jl` line 130/136 divides by 2 an extra time. Not necessary, right? Does this change anything? |
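
On the `sample_ff_ren` scaling question above, one possible reading (an assumption, not confirmed anywhere in this commit) is that the old code was already using Glorot-normal scaling. Under the usual Glorot convention the standard deviation is `sqrt(2 / (fan_in + fan_out))`, which for a square `m × m` matrix reduces to `1 / sqrt(m)`. The helper `glorot_std` below is hypothetical and only illustrates the arithmetic.

```julia
# Standard deviation used by the usual Glorot (Xavier) normal initialiser.
# `glorot_std` is a hypothetical helper, not part of the package.
glorot_std(fan_out::Int, fan_in::Int) = sqrt(2 / (fan_in + fan_out))

nx, nv = 10, 20
m = 2nx + nv                      # size of the square matrix in question

println(glorot_std(m, m))         # Glorot scale for an m × m matrix
println(1 / sqrt(m))              # scale used by `randn(m, m) / sqrt(m)`
println(1 / sqrt(2m))             # the alternative scale asked about
```

The first two printed values coincide, so dividing by `sqrt(2nx + nv)` matches the Glorot convention for a square matrix, while dividing by `sqrt(2*(2nx + nv))` would shrink the entries by a further factor of `sqrt(2)`.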
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
""" | ||
solve_tril_layer(ϕ, W::Matrix, b::VecOrMat) | ||
Solves z = ϕ.(W*z .+ b) for lower-triangular W, where | ||
ϕ is either a ReLU or tanh activation function. | ||
""" | ||
function solve_tril_layer(ϕ::Union{typeof(Flux.relu), typeof(Flux.tanh)}, W::Matrix, b::VecOrMat) | ||
z_eq = typeof(b)(zeros(size(b))) | ||
for i in 1:size(b,1) | ||
Wi = @view W[i:i, 1:i - 1] | ||
zi = @view z_eq[1:i-1,:] | ||
bi = @view b[i:i, :] | ||
z_eq[i:i,:] .= ϕ.(Wi * zi .+ bi) | ||
end | ||
return z_eq | ||
end | ||
|
||
""" | ||
solve_tril_layer(ϕ, W::Matrix, b::VecOrMat) | ||
Solves z = ϕ.(W*z .+ b) for lower-triangular W, where | ||
ϕ is a generic static nonlinearity. | ||
""" | ||
function solve_tril_layer(ϕ, W::Matrix, b::VecOrMat) | ||
|
||
# Slower to not specify typeof(ϕ), which is why this is separate | ||
println("Using non-ReLU/tanh version of solve_tril_layer()") | ||
z_eq = typeof(b)(zeros(size(b))) | ||
for i in 1:size(b,1) | ||
Wi = @view W[i:i, 1:i - 1] | ||
zi = @view z_eq[1:i-1,:] | ||
bi = @view b[i:i, :] | ||
z_eq[i:i,:] .= ϕ.(Wi * zi .+ bi) | ||
end | ||
return z_eq | ||
end | ||
|
||
# TODO: Add documentation for everything below here. This is all from Max's code, untouched so far | ||
# TODO: Can also speed things up by specifying function argument types | ||
|
||
""" | ||
tril_layer_calculate_gradient(Δz, ϕ, W, b, zeq; tol=1E-9) | ||
Calculate gradients for solving lower-triangular equilibrium | ||
network layer. | ||
""" | ||
function tril_layer_calculate_gradient(Δz, ϕ, W, b, zeq; tol=1E-9) | ||
one_vec = typeof(b)(ones(size(b))) | ||
v = W * zeq + b | ||
j = Zygote.pullback(z -> ϕ.(z), v)[2](one_vec)[1] | ||
# J = Diagonal(j[:]) | ||
|
||
eval_grad(t) = (I - (j[:, t] .* W))' \ Δz[:, t] | ||
gn = reduce(hcat, eval_grad(t) for t in 1:size(b, 2)) | ||
|
||
return nothing, nothing, nothing, gn | ||
end | ||
tril_layer_backward(ϕ, W, b, zeq) = zeq | ||
|
||
@adjoint solve_tril_layer(ϕ, W, b) = solve_tril_layer(ϕ, W, b), Δz -> (nothing, nothing, nothing) | ||
@adjoint tril_layer_backward(ϕ, W, b, zeq) = tril_layer_backward(ϕ, W, b, zeq), Δz -> tril_layer_calculate_gradient(Δz, ϕ, W, b, zeq) | ||
|
||
function tril_eq_layer(ϕ, W, b) | ||
weq = solve_tril_layer(ϕ, W, b) | ||
# TODO: return weq if not differentiating anything | ||
weq1 = ϕ.(W * weq + b) # Run forward and track grads | ||
return tril_layer_backward(ϕ, W, b, weq1) | ||
end |
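
The forward solve and the implicit gradient used in this file can be checked numerically. The following is a minimal, self-contained sketch for a single sample (vector `b`) with `tanh`, assuming a strictly lower-triangular `W`. The names `solve_tril_sketch` and `implicit_grad_b` are illustrative, not functions from the package. It solves `z = tanh.(W*z .+ b)` by forward substitution, then compares the implicit gradient `J * ((I - J*W)' \ Δz)` against central finite differences.

```julia
using LinearAlgebra, Random

# Forward substitution: row i of a strictly lower-triangular W only touches
# z[1:i-1], which has already been computed.
function solve_tril_sketch(ϕ, W::AbstractMatrix, b::AbstractVector)
    z = zeros(eltype(b), length(b))
    for i in eachindex(b)
        z[i] = ϕ(dot(W[i, 1:i-1], z[1:i-1]) + b[i])
    end
    return z
end

# Gradient of a scalar loss w.r.t. b, given Δz = ∂loss/∂z at the equilibrium z.
# Assumes ϕ = tanh, so ϕ'(v) = 1 - tanh(v)^2 = 1 - z^2.
function implicit_grad_b(W, z, Δz)
    J = Diagonal(1 .- z .^ 2)
    return J * ((I - J * W)' \ Δz)
end

rng = MersenneTwister(0)
n = 5
W = tril(randn(rng, n, n), -1)     # strictly lower-triangular weights
b = randn(rng, n)
c = randn(rng, n)                  # loss(z) = dot(c, z), so Δz = c

z = solve_tril_sketch(tanh, W, b)
g = implicit_grad_b(W, z, c)

# Central finite differences for comparison
h = 1e-6
g_fd = map(1:n) do k
    e = zeros(n); e[k] = h
    (dot(c, solve_tril_sketch(tanh, W, b .+ e)) -
     dot(c, solve_tril_sketch(tanh, W, b .- e))) / (2h)
end

println(maximum(abs.(g .- g_fd)))  # should be close to zero
```

Because `J*W` is strictly lower-triangular, `I - J*W` is unit lower-triangular and always invertible, so the linear solve in the backward pass is well-posed for this class of `W`.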
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
""" | ||
$(TYPEDEF) | ||
Direct (implicit) parameters used to construct a REN. | ||
""" | ||
mutable struct DirectParams{T} | ||
ρ::Union{Vector{T},CuVector{T}} # used in polar param | ||
V::Union{Matrix{T},CuMatrix{T}} | ||
Y1::Union{Matrix{T},CuMatrix{T}} | ||
X3::Union{Matrix{T},CuMatrix{T}} | ||
Y3::Union{Matrix{T},CuMatrix{T}} | ||
Z3::Union{Matrix{T},CuMatrix{T}} | ||
B2::Union{Matrix{T},CuMatrix{T}} | ||
D12::Union{Matrix{T},CuMatrix{T}} | ||
bx::Union{Vector{T},CuVector{T}} | ||
bv::Union{Vector{T},CuVector{T}} | ||
ϵ::T | ||
polar_param::Bool # Whether or not to use polar param | ||
D22_free::Bool # Is D22 free or parameterised by (X3,Y3,Z3)? | ||
end | ||
|
||
""" | ||
DirectParams{T}(nu, nx, nv, ny; ...) | ||
Constructor for `DirectParams` struct. Allows for the following | ||
initialisation methods, specified as symbols by `init` argument: | ||
- `:random`: Random sampling for all parameters | ||
- `:cholesky`: Compute `V` with cholesky factorisation of `H`, sets `E,F,P = I` | ||
Option `D22_free` specifies whether or not to include parameters X3, Y3, and Z3 | ||
as trainable parameters used in the explicit construction of D22. If `D22_free == true` | ||
then `(X3,Y3,Z3)` are empty and not trainable. | ||
Note that `D22_free = false` by default. | ||
""" | ||
function DirectParams{T}( | ||
nu::Int, nx::Int, nv::Int, ny::Int; | ||
init = :random, | ||
ϵ = T(0.001), | ||
bx_scale = T(0), | ||
bv_scale = T(1), | ||
polar_param = false, | ||
D22_free = false, | ||
rng = Random.GLOBAL_RNG | ||
) where T | ||
|
||
# Random sampling | ||
if init == :random | ||
|
||
B2 = glorot_normal(nx, nu; T=T, rng=rng) | ||
D12 = glorot_normal(nv, nu; T=T, rng=rng) | ||
|
||
ρ = zeros(T, 1) | ||
|
||
# Make orthogonal V | ||
V = glorot_normal(2nx + nv, 2nx + nv; T=T, rng=rng) | ||
V = Matrix(qr(V).Q) | ||
|
||
# Specify H and compute V | ||
elseif init == :cholesky | ||
|
||
E = Matrix{T}(I, nx, nx) | ||
F = Matrix{T}(I, nx, nx) | ||
P = Matrix{T}(I, nx, nx) | ||
|
||
B1 = zeros(T, nx, nv) | ||
B2 = glorot_normal(nx, nu; T=T, rng=rng) | ||
|
||
C1 = zeros(T, nv, nx) | ||
D11 = glorot_normal(nv, nv; T=T, rng=rng) | ||
D12 = zeros(T, nv, nu) | ||
|
||
# TODO: This is prone to errors. Needs a bugfix! | ||
Λ = 2*I | ||
H22 = 2Λ - D11 - D11' | ||
Htild = [(E + E' - P) -C1' F'; | ||
-C1 H22 B1' | ||
F B1 P] + ϵ * I | ||
|
||
ρ = zeros(T, 1) | ||
V = Matrix{T}(cholesky(Htild).U) # H = V'*V | ||
|
||
else | ||
error("Undefined initialisation method ", init) | ||
end | ||
|
||
# Free parameter for E | ||
Y1 = glorot_normal(nx, nx; T=T, rng=rng) | ||
|
||
# Parameters for D22 in output layer | ||
if D22_free | ||
X3 = zeros(T, 0, 0) | ||
Y3 = zeros(T, 0, 0) | ||
Z3 = zeros(T, 0, 0) | ||
else | ||
X3 = glorot_normal(nu, nu; T=T, rng=rng) | ||
Y3 = glorot_normal(nu, nu; T=T, rng=rng) | ||
Z3 = glorot_normal(abs(ny - nu), ny; T=T, rng=rng) | ||
end | ||
|
||
# Bias terms | ||
bv = T(bv_scale) * glorot_normal(nv; T=T, rng=rng) | ||
bx = T(bx_scale) * glorot_normal(nx; T=T, rng=rng) | ||
|
||
return DirectParams( | ||
ρ ,V, | ||
Y1, X3, Y3, Z3, | ||
B2, D12, bx, bv, T(ϵ), | ||
polar_param, D22_free | ||
) | ||
end | ||
|
||
""" | ||
Flux.trainable(L::DirectParams) | ||
Define trainable parameters for `DirectParams` type. | ||
Filter empty ones (handy when nx=0) | ||
""" | ||
function Flux.trainable(L::DirectParams) | ||
if L.D22_free | ||
return filter( | ||
p -> length(p) !=0, | ||
[L.ρ, L.V, L.Y1, L.B2, L.D12, L.bx, L.bv] | ||
) | ||
end | ||
return filter( | ||
p -> length(p) !=0, | ||
[L.ρ, L.V, L.Y1, L.X3, L.Y3, L.Z3, L.B2, L.D12, L.bx, L.bv] | ||
) | ||
end | ||
|
||
""" | ||
Flux.gpu(M::DirectParams{T}) where T | ||
Add GPU compatibility for `DirectParams` type | ||
""" | ||
function Flux.gpu(M::DirectParams{T}) where T | ||
if T != Float32 | ||
println("Moving type: ", T, " to gpu may not be supported. Try Float32!") | ||
end | ||
return DirectParams{T}( | ||
gpu(M.ρ), gpu(M.V), gpu(M.Y1), gpu(M.X3), gpu(M.Y3), | ||
gpu(M.Z3), gpu(M.B2), gpu(M.D12), gpu(M.bx), | ||
gpu(M.bv), M.ϵ, M.polar_param, M.D22_free | ||
) | ||
end | ||
|
||
""" | ||
Flux.cpu(M::DirectParams{T}) where T | ||
Add CPU compatibility for `DirectParams` type | ||
""" | ||
function Flux.cpu(M::DirectParams{T}) where T | ||
return DirectParams{T}( | ||
cpu(M.ρ), cpu(M.V), cpu(M.Y1), cpu(M.X3), cpu(M.Y3), | ||
cpu(M.Z3), cpu(M.B2), cpu(M.D12), cpu(M.bx), | ||
cpu(M.bv), M.ϵ, M.polar_param, M.D22_free | ||
) | ||
end |
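
The two initialisation paths in the `DirectParams` constructor above both produce a square matrix `V`: the `:random` path orthogonalises a random matrix with QR, and the `:cholesky` path factorises a matrix `H` so that `H = V'*V`. The sketch below shows both on a small example using only the standard library. The matrix `H` here is built to be positive definite on purpose, and the `isposdef` guard is only a suggestion for the error-prone path flagged in the TODO, not the author's fix.

```julia
using LinearAlgebra, Random

rng = MersenneTwister(0)
n = 6

# `:random`-style: orthogonalise a random square matrix with QR
V_rand = Matrix(qr(randn(rng, n, n)).Q)

# `:cholesky`-style: choose a positive-definite H, then factorise H = V'V
A = randn(rng, n, n)
H = Symmetric(A' * A + 1e-3 * I)
isposdef(H) || error("H is not positive definite; cholesky would fail")
V_chol = Matrix(cholesky(H).U)

println(norm(V_rand' * V_rand - I))   # orthogonality residual
println(norm(V_chol' * V_chol - H))   # factorisation residual
```

Checking `isposdef` first turns a failed factorisation into an explicit, early error message rather than an exception thrown from inside `cholesky`.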