forked from lava-nc/lava
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SigmaS4Delta Neuronmodel and Layer with Unittests (lava-nc#830)
* first wokring version * S4D model cleaned * update license * fix imports * linting * incorporate reviews * update docstring
- Loading branch information
Showing
9 changed files
with
713 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
import numpy as np | ||
from typing import Any, Dict | ||
from lava.proc.sdn.models import AbstractSigmaDeltaModel | ||
from lava.magma.core.decorator import implements, requires, tag | ||
from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol | ||
from lava.proc.s4d.process import SigmaS4dDelta, SigmaS4dDeltaLayer | ||
from lava.magma.core.resources import CPU | ||
from lava.magma.core.model.py.ports import PyInPort, PyOutPort | ||
from lava.magma.core.model.py.type import LavaPyType | ||
from lava.magma.core.model.sub.model import AbstractSubProcessModel | ||
from lava.proc.sparse.process import Sparse | ||
|
||
|
||
class AbstractSigmaS4dDeltaModel(AbstractSigmaDeltaModel): | ||
a_in = None | ||
s_out = None | ||
|
||
# SigmaDelta Variables | ||
vth = None | ||
sigma = None | ||
act = None | ||
residue = None | ||
error = None | ||
state_exp = None | ||
bias = None | ||
|
||
# S4 Variables | ||
a = None | ||
b = None | ||
c = None | ||
s4_state = None | ||
s4_exp = None | ||
|
||
def __init__(self, proc_params: Dict[str, Any]) -> None: | ||
""" | ||
Sigma delta neuron model that implements S4D | ||
(as described by Gu et al., 2022) dynamics as its activation function. | ||
Relevant parameters in proc_params | ||
-------------------------- | ||
a: np.ndarray | ||
Diagonal elements of the state matrix of the S4D model. | ||
b: np.ndarray | ||
Diagonal elements of the input matrix of the S4D model. | ||
c: np.ndarray | ||
Diagonal elements of the output matrix of the S4D model. | ||
s4_state: np.ndarray | ||
State vector of the S4D model. | ||
""" | ||
super().__init__(proc_params) | ||
self.a = self.proc_params['a'] | ||
self.b = self.proc_params['b'] | ||
self.c = self.proc_params['c'] | ||
self.s4_state = self.proc_params['s4_state'] | ||
|
||
def activation_dynamics(self, sigma_data: np.ndarray) -> np.ndarray: | ||
"""Sigma Delta activation dynamics. Performs S4D dynamics. | ||
This function simulates the behavior of a linear time-invariant system | ||
with diagonalized state-space representation. | ||
(For reference see Gu et al., 2022) | ||
The state-space equations are given by: | ||
s4_state_{k+1} = A * s4_state_k + B * input_k | ||
act_k = C * s4_state_k | ||
where: | ||
- s4_state_k is the state vector at time step k, | ||
- input_k is the input vector at time step k, | ||
- act_k is the output vector at time step k, | ||
- A is the diagonal state matrix, | ||
- B is the diagonal input matrix, | ||
- C is the diagonal output matrix. | ||
The function computes the next output step of the | ||
system for the given input signal. | ||
Parameters | ||
---------- | ||
sigma_data: np.ndarray | ||
sigma decoded data | ||
Returns | ||
------- | ||
np.ndarray | ||
activation output | ||
""" | ||
|
||
self.s4_state = self.s4_state * self.a + sigma_data * self.b | ||
act = self.c * self.s4_state * 2 | ||
return act | ||
|
||
|
||
@implements(proc=SigmaS4dDelta, protocol=LoihiProtocol) | ||
@requires(CPU) | ||
@tag('floating_pt') | ||
class PySigmaS4dDeltaModelFloat(AbstractSigmaS4dDeltaModel): | ||
"""Floating point implementation of SigmaS4dDelta neuron.""" | ||
a_in = LavaPyType(PyInPort.VEC_DENSE, float) | ||
s_out = LavaPyType(PyOutPort.VEC_DENSE, float) | ||
|
||
vth: np.ndarray = LavaPyType(np.ndarray, float) | ||
sigma: np.ndarray = LavaPyType(np.ndarray, float) | ||
act: np.ndarray = LavaPyType(np.ndarray, float) | ||
residue: np.ndarray = LavaPyType(np.ndarray, float) | ||
error: np.ndarray = LavaPyType(np.ndarray, float) | ||
bias: np.ndarray = LavaPyType(np.ndarray, float) | ||
|
||
state_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) | ||
cum_error: np.ndarray = LavaPyType(np.ndarray, bool, precision=1) | ||
spike_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) | ||
s4_exp: np.ndarray = LavaPyType(np.ndarray, np.int32, precision=3) | ||
|
||
# S4 vaiables | ||
s4_state: np.ndarray = LavaPyType(np.ndarray, float) | ||
a: np.ndarray = LavaPyType(np.ndarray, float) | ||
b: np.ndarray = LavaPyType(np.ndarray, float) | ||
c: np.ndarray = LavaPyType(np.ndarray, float) | ||
|
||
def run_spk(self) -> None: | ||
# Receive synaptic input | ||
a_in_data = self.a_in.recv() | ||
s_out = self.dynamics(a_in_data) | ||
self.s_out.send(s_out) | ||
|
||
|
||
@implements(proc=SigmaS4dDeltaLayer, protocol=LoihiProtocol) | ||
class SubDenseLayerModel(AbstractSubProcessModel): | ||
def __init__(self, proc): | ||
"""Builds (Sparse -> S4D -> Sparse) connection of the process.""" | ||
conn_weights = proc.proc_params.get("conn_weights") | ||
shape = proc.proc_params.get("shape") | ||
state_exp = proc.proc_params.get("state_exp") | ||
num_message_bits = proc.proc_params.get("num_message_bits") | ||
s4_exp = proc.proc_params.get("s4_exp") | ||
d_states = proc.proc_params.get("d_states") | ||
a = proc.proc_params.get("a") | ||
b = proc.proc_params.get("b") | ||
c = proc.proc_params.get("c") | ||
vth = proc.proc_params.get("vth") | ||
|
||
# Instantiate processes | ||
self.sparse1 = Sparse(weights=conn_weights.T, weight_exp=state_exp, | ||
num_message_bits=num_message_bits) | ||
self.sigma_S4d_delta = SigmaS4dDelta(shape=(shape[0] * d_states,), | ||
vth=vth, | ||
state_exp=state_exp, | ||
s4_exp=s4_exp, | ||
a=a, | ||
b=b, | ||
c=c) | ||
self.sparse2 = Sparse(weights=conn_weights, weight_exp=state_exp, | ||
num_message_bits=num_message_bits) | ||
|
||
# Make connections Sparse -> SigmaS4Delta -> Sparse | ||
proc.in_ports.s_in.connect(self.sparse1.in_ports.s_in) | ||
self.sparse1.out_ports.a_out.connect(self.sigma_S4d_delta.in_ports.a_in) | ||
self.sigma_S4d_delta.out_ports.s_out.connect(self.sparse2.s_in) | ||
self.sparse2.out_ports.a_out.connect(proc.out_ports.a_out) | ||
|
||
# Set aliases | ||
proc.vars.a.alias(self.sigma_S4d_delta.vars.a) | ||
proc.vars.b.alias(self.sigma_S4d_delta.vars.b) | ||
proc.vars.c.alias(self.sigma_S4d_delta.vars.c) | ||
proc.vars.s4_state.alias(self.sigma_S4d_delta.vars.s4_state) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
# Copyright (C) 2024 Intel Corporation | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# See: https://spdx.org/licenses/ | ||
|
||
import typing as ty | ||
import numpy as np | ||
from lava.magma.core.process.process import AbstractProcess | ||
from lava.magma.core.process.variable import Var | ||
from lava.magma.core.process.ports.ports import InPort, OutPort | ||
from lava.proc.sdn.process import ActivationMode, SigmaDelta | ||
|
||
|
||
class SigmaS4dDelta(SigmaDelta, AbstractProcess): | ||
def __init__( | ||
self, | ||
shape: ty.Tuple[int, ...], | ||
vth: ty.Union[int, float], | ||
a: float, | ||
b: float, | ||
c: float, | ||
state_exp: ty.Optional[int] = 0, | ||
s4_exp: ty.Optional[int] = 0) -> None: | ||
""" | ||
Sigma delta neuron process that implements S4D (described by | ||
Gu et al., 2022) dynamics as its activation function. | ||
This process simulates the behavior of a linear time-invariant system | ||
with diagonal state-space representation. | ||
The state-space equations are given by: | ||
s4_state_{k+1} = A * s4_state_k + B * inp_k | ||
act_k = C * s4_state_k | ||
where: | ||
- s4_state_k is the state vector at time step k, | ||
- inp_k is the input vector at time step k, | ||
- act_k is the output vector at time step k, | ||
- A is the diagonal state matrix, | ||
- B is the diagonal input matrix, | ||
- C is the diagonal output matrix. | ||
Parameters | ||
---------- | ||
shape: Tuple | ||
Shape of the sigma process. | ||
vth: int or float | ||
Threshold of the delta encoder. | ||
a: np.ndarray | ||
Diagonal elements of the state matrix of the S4D model. | ||
b: np.ndarray | ||
Diagonal elements of the input matrix of the S4D model. | ||
c: np.ndarray | ||
Diagonal elements of the output matrix of the S4D model. | ||
state_exp: int | ||
Scaling exponent with base 2 for the reconstructed sigma variables. | ||
Note: This should only be used for nc models. | ||
Default is 0. | ||
s4_exp: int | ||
Scaling exponent with base 2 for the S4 state variables. | ||
Note: This should only be used for nc models. | ||
Default is 0. | ||
""" | ||
|
||
super().__init__(shape=shape, | ||
vth=vth, | ||
a=a, | ||
b=b, | ||
c=c, | ||
s4_state=0, | ||
state_exp=state_exp, | ||
s4_exp=s4_exp) | ||
|
||
# Variables for S4 | ||
self.a = Var(shape=shape, init=a) | ||
self.b = Var(shape=shape, init=b) | ||
self.c = Var(shape=shape, init=c) | ||
self.s4_state = Var(shape=shape, init=0) | ||
self.s4_exp = Var(shape=(1,), init=s4_exp) | ||
|
||
|
||
class SigmaS4dDeltaLayer(AbstractProcess): | ||
def __init__( | ||
self, | ||
shape: ty.Tuple[int, ...], | ||
vth: ty.Union[int, float], | ||
a: float, | ||
b: float, | ||
c: float, | ||
d_states: ty.Optional[int] = 1, | ||
s4_exp: ty.Optional[int] = 0, | ||
num_message_bits: ty.Optional[int] = 24, | ||
state_exp: ty.Optional[int] = 0) -> None: | ||
""" | ||
Combines S4D neuron with Sparse Processes that allow for multiple | ||
d_states. | ||
Connectivity: Sparse -> SigmaS4dDelta -> Sparse. | ||
Relieves user from computing required connection weights for multiple | ||
d_states. | ||
Parameters | ||
---------- | ||
shape: Tuple | ||
Shape of the sigma process. | ||
vth: int or float | ||
Threshold of the delta encoder. | ||
a: np.ndarray | ||
Diagonal elements of the state matrix of the S4D model. | ||
b: np.ndarray | ||
Diagonal elements of the input matrix of the S4D model. | ||
c: np.ndarray | ||
Diagonal elements of the output matrix of the S4D model. | ||
d_states: int | ||
Number of hidden states of the S4D model. | ||
Default is 1. | ||
state_exp: int | ||
Scaling exponent with base 2 for the reconstructed sigma variables. | ||
Note: Only relevant for nc model. | ||
Default is 0. | ||
num_message_bits: int | ||
Number of message bits to be used in Sparse connection processes. | ||
Note: Only relevant for nc model. | ||
s4_exp: int | ||
Scaling exponent with base 2 for the S4 state variables. | ||
Note: Only relevant for nc model. | ||
Default is 0. | ||
""" | ||
|
||
# Automatically takes care of expansion and reduction of dimensionality | ||
# for multiple hidden states (d_states) | ||
conn_weights = np.kron(np.eye(shape[0]), np.ones(d_states)) | ||
s4_state = 0 | ||
super().__init__(shape=shape, | ||
vth=vth, | ||
a=a, | ||
b=b, | ||
c=c, | ||
s4_exp=s4_exp, | ||
s4_state=s4_state, | ||
conn_weights=conn_weights, | ||
num_message_bits=num_message_bits, | ||
d_states=d_states, | ||
state_exp=state_exp, | ||
act_mode=ActivationMode.UNIT) | ||
|
||
# Ports | ||
self.s_in = InPort(shape=shape) | ||
self.a_out = OutPort(shape=shape) | ||
|
||
# General variables | ||
self.state_exp = Var(shape=(1,), init=state_exp) | ||
|
||
# Variables for S4 | ||
self.a = Var(shape=(shape[0] * d_states,), init=a) | ||
self.b = Var(shape=(shape[0] * d_states,), init=b) | ||
self.c = Var(shape=(shape[0] * d_states,), init=c) | ||
self.s4_state = Var(shape=(shape[0] * d_states,), init=0) | ||
self.S4_exp = Var(shape=(1,), init=s4_exp) | ||
|
||
# Variables for connecting Dense processes | ||
# Project input_dim to input_dim * d_states | ||
self.conn_weights = Var(shape=shape, init=conn_weights) | ||
self.num_message_bits = Var(shape=(1,), init=num_message_bits) | ||
|
||
@property | ||
def shape(self) -> ty.Tuple[int, ...]: | ||
"""Return shape of the Process.""" | ||
return self.proc_params['shape'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
Oops, something went wrong.