Merge pull request #1 from acfr/feature/basic-functionality
Feature/basic functionality
nic-barbara authored Sep 28, 2022
2 parents 3f030aa + a33be6b commit 9b61781
Showing 13 changed files with 958 additions and 34 deletions.
10 changes: 10 additions & 0 deletions Project.toml
@@ -3,6 +3,16 @@ uuid = "a1f18e6b-8af1-433f-a85d-2e1ee636a2b8"
authors = ["Nicholas Barbara", "Ruigang Wang", "Jing Cheng", "Jerome Justin", "Ian Manchester"]
version = "0.1.0"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixEquations = "99c1a7ee-ab34-5fd5-8076-27c950a045f4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
julia = "1"

94 changes: 72 additions & 22 deletions README.md
@@ -3,44 +3,94 @@
## Status
[![Build Status](https://github.com/nic-barbara/RecurrentEquilibriumNetworks.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/nic-barbara/RecurrentEquilibriumNetworks.jl/actions/workflows/CI.yml?query=branch%3Amain)

## Package Description

Julia package for Recurrent Equilibrium Networks.

## Installation for Development

To install the package for development, clone the repository into your Julia dev folder:
- For Linux/Mac, use `git clone [email protected]:acfr/RecurrentEquilibriumNetworks.jl.git RecurrentEquilibriumNetworks` inside your `~/.julia/dev/` directory.
- Note that the repo is `RecurrentEquilibriumNetworks.jl`, but the folder is `RecurrentEquilibriumNetworks`. This is the convention for Julia packages.

Navigate to the repository directory, start the Julia REPL, and type `] activate .` to activate the package. You can now test out some basic functionality:

```julia
using Random
using RecurrentEquilibriumNetworks

batches = 50
nu, nx, nv, ny = 4, 10, 20, 2

contracting_ren_ps = ContractingRENParams{Float64}(nu, nx, nv, ny)
contracting_ren = REN(contracting_ren_ps)

x0 = init_states(contracting_ren, batches)
u0 = randn(contracting_ren.nu, batches)

x1, y1 = contracting_ren(x0, u0) # Evaluates the REN over one timestep

println(x1)
println(y1)
```


## Contributing to the Package

The main file is `src/RecurrentEquilibriumNetworks.jl`. This imports all relevant packages, defines abstract types, includes code from other files, and exports the necessary components of our package. All `using PackageName` statements should be included in this file. As a general guide:
- Only import packages you really need
- If you only need one function from a package, import it explicitly (not the whole package)

When including files in our `src/` folder, the order often matters. I have tried to structure the `include` statements in `RecurrentEquilibriumNetworks.jl` so that we only ever have to include code once, in the main file. Please follow the convention outlined in the comments; a sketch of the layout is given below.
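For illustration, here is a sketch of how the main file might be laid out (the file list and exports are illustrative, not a copy of the real file):

```julia
module RecurrentEquilibriumNetworks

# All `using` statements live here; single names are imported explicitly
using Flux
using LinearAlgebra: cholesky, qr

# Order matters: Base/ definitions come before the parameter types that use them
include("Base/output_layer.jl")
include("Base/direct_params.jl")
include("ParameterTypes/general_ren.jl")

export REN, DirectParams, OutputLayer

end # module
```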

The source files for our package are all in the `src/` folder, and are split between `src/Base/` and `src/ParameterTypes/`. The `Base/` folder should contain code relevant to the core functionality of this package. The `ParameterTypes/` folder is where to add different versions of REN (eg: contracting REN, Lipschitz-bounded REN, etc.). See `src/ParameterTypes/general_ren.jl` for an example.

Once you have written any code for this package, be sure to test it thoroughly. Write testing scripts for the package in `test/` (a minimal example is sketched after this list):
- See [`Test.jl`](https://docs.julialang.org/en/v1/stdlib/Test/) documentation for writing tests
- Run all tests for the package with `] test`
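A minimal sketch of such a test script (the constructor and functions follow the usage example above):

```julia
using Test
using RecurrentEquilibriumNetworks

@testset "REN construction" begin
    nu, nx, nv, ny = 4, 10, 20, 2
    ren_ps = ContractingRENParams{Float64}(nu, nx, nv, ny)
    ren = REN(ren_ps)

    x0 = init_states(ren, 5)
    x1, y1 = ren(x0, randn(nu, 5))

    @test size(x1) == (nx, 5)
    @test size(y1) == (ny, 5)
end
```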

Use git to pull/push changes to the package as normal while developing it.


## Using the Package

Once the package is functional, it can be used in other Julia workspaces like any normal package by following these instructions:

- Add a development version of the package with: `] dev [email protected]:acfr/RecurrentEquilibriumNetworks.jl.git`
- This is instead of the usual `] add` command. We also have to use the git link, not the package name, because this is an unregistered Julia package.
- Whenever you use the package, it will access the latest version in your `.julia/dev/` folder rather than the stable release in the `main` branch. This is easiest for development while we frequently change the package.
- To use the code on the current main branch of the repo (latest stable release), instead type `] add [email protected]:acfr/RecurrentEquilibriumNetworks.jl.git`. You will have to manually update the package as normal with `] update RecurrentEquilibriumNetworks`.
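The same can also be done from the Pkg API rather than the REPL, e.g. (a sketch):

```julia
using Pkg

# Track the local clone in ~/.julia/dev/ (equivalent to `] dev ...`)
Pkg.develop(url="[email protected]:acfr/RecurrentEquilibriumNetworks.jl.git")

# Or track the repo's main branch directly (equivalent to `] add ...`)
# Pkg.add(url="[email protected]:acfr/RecurrentEquilibriumNetworks.jl.git")
```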


## Some Early Documentation
The package is structured around the `REN` type. An object of type `REN` has the following attributes (see the sketch after these lists):
- An explicit model struct
- The in/out and state/nonlinearity sizes
- A static nonlinearity

and functions to build/use it as follows:
- A constructor
- A self-call method
- A function to initialise a state vector, `x0 = init_states(ren, batches)`
- A function to set the output to zero, `set_output_zero!(ren)`
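As a rough sketch (the field names are illustrative, not the package's actual definitions):

```julia
struct REN
    explicit          # explicit model struct, built from the direct parameters
    nu::Int           # number of inputs
    nx::Int           # number of states
    nv::Int           # number of nonlinear units
    ny::Int           # number of outputs
    nl                # static nonlinearity, eg: Flux.relu
end

# Self-call method: evaluate the REN over one timestep
# (m::REN)(xt, ut) -> (xt1, yt)
```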

Each `REN` is constructed from a direct (implicit) parameterisation of the REN architecture. Each variation of REN (eg: contracting, passive, Lipschitz-bounded) is a subtype of `AbstractRENParams`, an abstract type. This encodes all information required to build a `REN` satisfying some set of behavioural constraints. Each subtype must include the following attributes:
- in/out, state/nl sizes `nu, ny, nx, nv`
- output layer of type `OutputLayer`
- direct (implicit) parameters of type `DirectParams`
- Any other attributes relevant to the parameterisation. Eg: `Q, S, R, alpha_bar` for a general REN

The output layer and implicit parameters are structs defined in `src/Base/output_layer.jl` and `src/Base/direct_params.jl` (respectively). Each subtype of `AbstractRENParams` must also have the following methods:
- A constructor
- A definition of `Flux.trainable()` specifying the trainable parameters
- A definition of `direct_to_explicit()` to convert the direct parameterisation to its explicit form

See `src/ParameterTypes/general_ren.jl` for an example.
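Putting these requirements together, a skeleton for a new parameter type might look like the following (a sketch only, with illustrative names):

```julia
mutable struct MyRENParams{T} <: AbstractRENParams
    nu::Int; nx::Int; nv::Int; ny::Int   # in/out, state/nl sizes
    direct::DirectParams{T}              # direct (implicit) parameters
    output::OutputLayer{T}               # output layer
    # ... any parameterisation-specific attributes (eg: Q, S, R)
end

function MyRENParams{T}(nu, nx, nv, ny; kwargs...) where T
    # ... construct the direct parameters and output layer
end

Flux.trainable(m::MyRENParams) =
    [Flux.trainable(m.direct)..., Flux.trainable(m.output)...]

function direct_to_explicit(m::MyRENParams)
    # ... map the direct parameterisation to its explicit form
end
```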

## Contact
Nic Barbara ([email protected]) for any questions/concerns.
34 changes: 34 additions & 0 deletions docs/nic_notes.md
@@ -0,0 +1,34 @@
# Package Architecture

This document is a work-in-progress. Use it to start our documentation of the `RecurrentEquilibriumNetworks.jl` package. I'm using this document to keep track of how I'm writing the package. It is not necessarily up-to-date or correct.

## TODO Lists (and other useful things)

### General:
- Write `direct_to_explicit` functions for `GeneralRENParams`
- Add documentation and improve speed for `Base/acyclic_ren_solver.jl` code taken from Max's work

### Changes from previous code:
- Currently no method to construct a REN without specifying type. Seems good to force it
- Treatment of `D22`:
- `D22` can be included directly as a trainable parameter in the `OutputLayer`, but is **NOT** by default. Have to set `D22_trainable = true` on construction
- `D22` can be parameterised by free parameters `(X3,Y3,Z3)` by setting `D22_free` in `DirectParams`
- Added `alpha_bar` to `ContractingRENParams` (as an example). Will use this to set the contraction rate in the explicit params construction

### Bugs:
- Cholesky initialisation of `DirectParams`

### Functionality
- How to include different initialisation methods?
- Write full support for GPU/CUDA arrays

### Conversion questions:
- `D22` terms in `output.jl` initialised as 0. Good idea in general?
- `sample_ff_ren` had `randn(2nx + nv, 2nx + nv) / sqrt(2nx + nv)`, not `.../ sqrt(2*(2nx + nv))`. Why?
- Why have `bx_scale` and `bv_scale` in constructor for `DirectParams`?
- Constructor for `implicit_ff_cell` has zeros for some params, random for others. Why?
- Construction of S (now `Y1`) in `ffREN.jl` line 130/136 divides by 2 an extra time. Not necessary, right? Does this change anything?
68 changes: 68 additions & 0 deletions src/Base/acyclic_ren_solver.jl
@@ -0,0 +1,68 @@
"""
solve_tril_layer(ϕ, W::Matrix, b::VecOrMat)
Solves z = ϕ.(W*z .+ b) for lower-triangular W, where
ϕ is a either a ReLU or tanh activation function.
"""
function solve_tril_layer::Union{typeof(Flux.relu), typeof(Flux.tanh)}, W::Matrix, b::VecOrMat)
z_eq = typeof(b)(zeros(size(b)))
for i in 1:size(b,1)
Wi = @view W[i:i, 1:i - 1]
zi = @view z_eq[1:i-1,:]
bi = @view b[i:i, :]
z_eq[i:i,:] .= ϕ.(Wi * zi .+ bi)
end
return z_eq
end

"""
solve_tril_layer(ϕ, W::Matrix, b::VecOrMat)
Solves z = ϕ.(W*z .+ b) for lower-triangular W, where
ϕ is a generic static nonlinearity.
"""
function solve_tril_layer(ϕ, W::Matrix, b::VecOrMat)

# Slower to not specify typeof(ϕ), which is why this is separate
println("Using non-ReLU/tanh version of solve_tril_layer()")
z_eq = typeof(b)(zeros(size(b)))
for i in 1:size(b,1)
Wi = @view W[i:i, 1:i - 1]
zi = @view z_eq[1:i-1,:]
bi = @view b[i:i, :]
z_eq[i:i,:] .= ϕ.(Wi * zi .+ bi)
end
return z_eq
end

# TODO: Add documentation for everything below here. This is all from Max's code, untouched so far
# TODO: Can also speed things up by specifying function argument types

"""
tril_layer_calculate_gradient(Δz, ϕ, W, b, zeq; tol=1E-9)
Calculate gradients for solving lower-triangular equilibirum
network layer.
"""
function tril_layer_calculate_gradient(Δz, ϕ, W, b, zeq; tol=1E-9)
    one_vec = typeof(b)(ones(size(b)))
    v = W * zeq + b
    j = Zygote.pullback(z -> ϕ.(z), v)[2](one_vec)[1]
    # J = Diagonal(j[:])

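    # Equilibrium-layer pullback: with J = diag(ϕ') evaluated at v, the fixed
    # point z = ϕ.(W*z .+ b) has sensitivity (I - J*W)⁻¹, so each batch column
    # below solves the transposed system (I - J*W)' x = Δz.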
    eval_grad(t) = (I - (j[:, t] .* W))' \ Δz[:, t]
    gn = reduce(hcat, eval_grad(t) for t in 1:size(b, 2))

    return nothing, nothing, nothing, gn
end
tril_layer_backward(ϕ, W, b, zeq) = zeq

@adjoint solve_tril_layer(ϕ, W, b) = solve_tril_layer(ϕ, W, b), Δz -> (nothing, nothing, nothing)
@adjoint tril_layer_backward(ϕ, W, b, zeq) = tril_layer_backward(ϕ, W, b, zeq), Δz -> tril_layer_calculate_gradient(Δz, ϕ, W, b, zeq)

function tril_eq_layer(ϕ, W, b)
    weq = solve_tril_layer(ϕ, W, b)
    # TODO: return weq if not differentiating anything
    weq1 = ϕ.(W * weq + b) # Run forward and track grads
    return tril_layer_backward(ϕ, W, b, weq1)
end
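
# Example (a sketch, not from the package's tests): for strictly
# lower-triangular W, z[i] depends only on z[1:i-1], so a single
# forward sweep solves z = ϕ.(W*z .+ b):
#
#   W = LinearAlgebra.tril(randn(5, 5), -1)
#   b = randn(5, 10)  # 10 batch columns
#   z = tril_eq_layer(Flux.relu, W, b)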
158 changes: 158 additions & 0 deletions src/Base/direct_params.jl
@@ -0,0 +1,158 @@
"""
$(TYPEDEF)
Direct (implicit) parameters used to construct a REN.
"""
mutable struct DirectParams{T}
    ρ::Union{Vector{T},CuVector{T}}   # used in polar param
    V::Union{Matrix{T},CuMatrix{T}}
    Y1::Union{Matrix{T},CuMatrix{T}}
    X3::Union{Matrix{T},CuMatrix{T}}
    Y3::Union{Matrix{T},CuMatrix{T}}
    Z3::Union{Matrix{T},CuMatrix{T}}
    B2::Union{Matrix{T},CuMatrix{T}}
    D12::Union{Matrix{T},CuMatrix{T}}
    bx::Union{Vector{T},CuVector{T}}
    bv::Union{Vector{T},CuVector{T}}
    ϵ::T
    polar_param::Bool                 # whether or not to use polar param
    D22_free::Bool                    # is D22 free or parameterised by (X3,Y3,Z3)?
end

"""
DirectParams{T}(nu, nx, nv; ...)
Constructor for `DirectParams` struct. Allows for the following
initialisation methods, specified as symbols by `init` argument:
- `:random`: Random sampling for all parameters
- `:cholesky`: Compute `V` with cholesky factorisation of `H`, sets `E,F,P = I`
Option `D22_free` specifies whether or not to include parameters X3, Y3, and Z3
as trainable parameters used in the explicit construction of D22. If `D22_free == true`
then `(X3,Y3,Z3)` are empty and not trainable.
Note that `D22_free = false` by default.
"""
function DirectParams{T}(
    nu::Int, nx::Int, nv::Int, ny::Int;
    init = :random,
    ϵ = T(0.001),
    bx_scale = T(0),
    bv_scale = T(1),
    polar_param = false,
    D22_free = false,
    rng = Random.GLOBAL_RNG
) where T

    # Random sampling
    if init == :random

        B2 = glorot_normal(nx, nu; T=T, rng=rng)
        D12 = glorot_normal(nv, nu; T=T, rng=rng)

        ρ = zeros(T, 1)

        # Make orthogonal V
        V = glorot_normal(2nx + nv, 2nx + nv; T=T, rng=rng)
        V = Matrix(qr(V).Q)

    # Specify H and compute V
    elseif init == :cholesky

        E = Matrix{T}(I, nx, nx)
        F = Matrix{T}(I, nx, nx)
        P = Matrix{T}(I, nx, nx)

        B1 = zeros(T, nx, nv)
        B2 = glorot_normal(nx, nu; T=T, rng=rng)

        C1 = zeros(T, nv, nx)
        D11 = glorot_normal(nv, nv; T=T, rng=rng)
        D12 = zeros(T, nv, nu)

        # TODO: This is prone to errors. Needs a bugfix!
        Λ = 2*I
        H22 = 2Λ - D11 - D11'
        Htild = [(E + E' - P) -C1' F';
                 -C1 H22 B1';
                 F B1 P] + ϵ * I

        ρ = zeros(T, 1)
        V = Matrix{T}(cholesky(Htild).U) # H = V'*V

    else
        error("Undefined initialisation method ", init)
    end

    # Free parameter for E
    Y1 = glorot_normal(nx, nx; T=T, rng=rng)

    # Parameters for D22 in output layer
    if D22_free
        X3 = zeros(T, 0, 0)
        Y3 = zeros(T, 0, 0)
        Z3 = zeros(T, 0, 0)
    else
        X3 = glorot_normal(nu, nu; T=T, rng=rng)
        Y3 = glorot_normal(nu, nu; T=T, rng=rng)
        Z3 = glorot_normal(abs(ny - nu), ny; T=T, rng=rng)
    end

    # Bias terms
    bv = T(bv_scale) * glorot_normal(nv; T=T, rng=rng)
    bx = T(bx_scale) * glorot_normal(nx; T=T, rng=rng)

    return DirectParams(
        ρ, V,
        Y1, X3, Y3, Z3,
        B2, D12, bx, bv, T(ϵ),
        polar_param, D22_free
    )
end
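
# Example (a sketch, not from the package's tests): direct parameters for
# a REN with 4 inputs, 10 states, 20 nonlinear units and 2 outputs:
#
#   ps = DirectParams{Float64}(4, 10, 20, 2; init=:random, polar_param=true)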

"""
Flux.trainable(L::DirectParams)
Define trainable parameters for `DirectParams` type.
Filter empty ones (handy when nx=0)
"""
function Flux.trainable(L::DirectParams)
    if L.D22_free
        return filter(
            p -> length(p) != 0,
            [L.ρ, L.V, L.Y1, L.B2, L.D12, L.bx, L.bv]
        )
    end
    return filter(
        p -> length(p) != 0,
        [L.ρ, L.V, L.Y1, L.X3, L.Y3, L.Z3, L.B2, L.D12, L.bx, L.bv]
    )
end

"""
Flux.gpu(M::DirectParams{T}) where T
Add GPU compatibility for `DirectParams` type
"""
function Flux.gpu(M::DirectParams{T}) where T
if T != Float32
println("Moving type: ", T, " to gpu may not be supported. Try Float32!")
end
return DirectParams{T}(
gpu(M.V), gpu(M.Y1), gpu(M.X3), gpu(M.Y3),
gpu(M.Z3), gpu(M.B2), gpu(M.D12), gpu(M.bx),
gpu(M.bv), M.ϵ, M.polar_param
)
end

"""
Flux.cpu(M::DirectParams{T}) where T
Add CPU compatibility for `DirectParams` type
"""
function Flux.cpu(M::DirectParams{T}) where T
return DirectParams{T}(
cpu(M.V), cpu(M.Y1), cpu(M.X3), cpu(M.Y3),
cpu(M.Z3), cpu(M.B2), cpu(M.D12), cpu(M.bx),
cpu(M.bv), M.ϵ, M.polar_param
)
end